// @llm-dev-ops/shield-sdk
// Version: (not specified)
// Enterprise-grade SDK for securing Large Language Model applications
// 318 lines • 10.4 kB
// JavaScript
import { PromptInjectionScanner } from './scanners/prompt-injection.js';
import { SecretsScanner } from './scanners/secrets.js';
import { PIIScanner } from './scanners/pii.js';
import { ToxicityScanner } from './scanners/toxicity.js';
/**
 * Shield - Main security facade for LLM applications.
 *
 * Runs a list of input scanners over prompts and a list of output scanners
 * over model responses, then merges the per-scanner results into one result:
 * the maximum risk score, the maximum severity, all risk factors, and the
 * de-duplicated entities.
 *
 * @example
 * ```typescript
 * // Using presets
 * const shield = Shield.standard();
 * const result = await shield.scanPrompt("User input");
 *
 * // Custom configuration
 * const shield = Shield.builder()
 *   .addInputScanner(new SecretsScanner())
 *   .addInputScanner(new PIIScanner())
 *   .withParallelExecution(true)
 *   .build();
 * ```
 */
export class Shield {
  inputScanners;
  outputScanners;
  config;

  /**
   * @param {object} config - Execution settings: `preset`, `shortCircuitThreshold`
   *   (defaults to 0.9 when absent), `parallelExecution`, `maxConcurrent`
   *   (defaults to 4 when absent).
   * @param {Array} inputScanners - Scanners applied by `scanPrompt`.
   * @param {Array} outputScanners - Scanners applied by `scanOutput`.
   */
  constructor(config, inputScanners, outputScanners) {
    this.config = config;
    this.inputScanners = inputScanners;
    this.outputScanners = outputScanners;
  }

  /**
   * Create a Shield with strict security settings.
   * Maximum security for regulated industries (banking, healthcare).
   * Scanners run sequentially and stop early at a risk score of 0.7.
   */
  static strict() {
    return new Shield({
      preset: 'strict',
      shortCircuitThreshold: 0.7,
      parallelExecution: false,
    }, [
      new PromptInjectionScanner(),
      new SecretsScanner(),
      new PIIScanner(),
      new ToxicityScanner(),
    ], [
      new SecretsScanner(),
      new PIIScanner(),
    ]);
  }

  /**
   * Create a Shield with standard security settings.
   * Balanced security for general-purpose applications (recommended).
   */
  static standard() {
    return new Shield({
      preset: 'standard',
      shortCircuitThreshold: 0.9,
      parallelExecution: true,
      maxConcurrent: 4,
    }, [
      new PromptInjectionScanner(),
      new SecretsScanner(),
      new PIIScanner({ piiTypes: ['email', 'ssn', 'credit-card'] }),
    ], [
      new SecretsScanner(),
      new PIIScanner({ piiTypes: ['email', 'ssn', 'credit-card'] }),
    ]);
  }

  /**
   * Create a Shield with permissive security settings.
   * Minimal security for development/testing. There are no output scanners.
   */
  static permissive() {
    return new Shield({
      preset: 'permissive',
      shortCircuitThreshold: 1.0,
      parallelExecution: true,
    }, [
      new SecretsScanner({ secretTypes: ['aws', 'private-key'] }),
    ], []);
  }

  /**
   * Create a ShieldBuilder for custom configuration.
   */
  static builder() {
    return new ShieldBuilder();
  }

  /**
   * Scan a prompt with the input scanners before sending it to the LLM.
   * @param {string} text - Prompt text.
   * @param {{maxLength?: number}} [options] - A truthy `maxLength` truncates
   *   the text before scanning.
   */
  async scanPrompt(text, options) {
    return this.runScanners(text, this.inputScanners, options);
  }

  /**
   * Scan LLM output with the output scanners before returning it to the user.
   * @param {string} text - Model output.
   * @param {{maxLength?: number}} [options] - See `scanPrompt`.
   */
  async scanOutput(text, options) {
    return this.runScanners(text, this.outputScanners, options);
  }

  /**
   * Scan the prompt, then the output.
   *
   * If the prompt result is invalid and its risk score reaches the
   * short-circuit threshold, the output is not scanned. A placeholder
   * `outputResult` is returned instead, with `isValid: false`,
   * `metadata.skipped === 'true'` and a single explanatory risk factor.
   */
  async scanPromptAndOutput(prompt, output, options) {
    const promptResult = await this.scanPrompt(prompt, options);
    if (!promptResult.isValid && promptResult.riskScore >= this.resolveShortCircuitThreshold()) {
      return {
        promptResult,
        outputResult: {
          isValid: false,
          riskScore: 0,
          sanitizedText: '',
          entities: [],
          riskFactors: [{
            category: 'prompt-injection',
            description: 'Output scan skipped due to invalid prompt',
            severity: 'none',
            confidence: 1.0,
          }],
          severity: 'none',
          metadata: { skipped: 'true' },
          durationMs: 0,
        },
      };
    }
    const outputResult = await this.scanOutput(output, options);
    return { promptResult, outputResult };
  }

  /**
   * Scan multiple prompts. Results are returned in input order.
   *
   * With `parallelExecution`, all prompts are scanned concurrently. Otherwise
   * they are scanned one after another.
   */
  async scanBatch(texts, options) {
    if (this.config.parallelExecution) {
      return Promise.all(texts.map((text) => this.scanPrompt(text, options)));
    }
    const results = [];
    for (const text of texts) {
      results.push(await this.scanPrompt(text, options));
    }
    return results;
  }

  /** Threshold at which a scan stops early; defaults to 0.9 when unset. */
  resolveShortCircuitThreshold() {
    return this.config.shortCircuitThreshold ?? 0.9;
  }

  /** Fresh "nothing found" result for `text` (callers may mutate it). */
  createEmptyResult(text) {
    return {
      isValid: true,
      riskScore: 0,
      sanitizedText: text,
      entities: [],
      riskFactors: [],
      severity: 'none',
      metadata: {},
      durationMs: 0,
    };
  }

  /**
   * Run `scanners` over `text` (optionally truncated), merge their results
   * and record the elapsed wall-clock time in `durationMs`.
   */
  async runScanners(text, scanners, options) {
    if (scanners.length === 0) {
      return this.createEmptyResult(text);
    }
    const startTime = performance.now();
    // A falsy maxLength (undefined or 0) disables truncation.
    const processedText = options?.maxLength ? text.slice(0, options.maxLength) : text;
    const results = this.config.parallelExecution
      ? await this.runParallel(processedText, scanners, this.config.maxConcurrent ?? 4)
      : await this.runSequential(processedText, scanners);
    const mergedResult = this.mergeResults(processedText, results);
    mergedResult.durationMs = performance.now() - startTime;
    return mergedResult;
  }

  /**
   * Run scanners one at a time. Stop after the first result whose risk score
   * reaches the short-circuit threshold.
   */
  async runSequential(text, scanners) {
    const threshold = this.resolveShortCircuitThreshold();
    const results = [];
    for (const scanner of scanners) {
      const result = await scanner.scan(text);
      results.push(result);
      if (result.riskScore >= threshold) {
        break;
      }
    }
    return results;
  }

  /**
   * Run scanners in batches of `maxConcurrent`. After each batch, stop if any
   * result reaches the short-circuit threshold.
   */
  async runParallel(text, scanners, maxConcurrent) {
    // A batch size below 1 would either never advance the loop (0, negative)
    // or exit after an empty batch (NaN), which runs no scanner at all and
    // reports the text as valid. Fall back to one scanner per batch.
    const batchSize = typeof maxConcurrent === 'number' && maxConcurrent >= 1 ? maxConcurrent : 1;
    const threshold = this.resolveShortCircuitThreshold();
    const results = [];
    for (let i = 0; i < scanners.length; i += batchSize) {
      const batch = scanners.slice(i, i + batchSize);
      const batchResults = await Promise.all(batch.map((s) => s.scan(text)));
      results.push(...batchResults);
      const maxRisk = Math.max(...batchResults.map((r) => r.riskScore));
      if (maxRisk >= threshold) {
        break;
      }
    }
    return results;
  }

  /**
   * Combine per-scanner results into one.
   *
   * The merged result is valid only if no scanner reported a risk factor and
   * no scanner explicitly returned `isValid: false`.
   *
   * NOTE(review): every scanner receives the same input text, so each
   * `sanitizedText` is an independent sanitization of it. Only the LAST result
   * whose `sanitizedText` differs from the input is kept. Redactions made by
   * earlier scanners are therefore lost when several scanners modify the text.
   * TODO: merge redactions instead of picking one.
   */
  mergeResults(originalText, results) {
    if (results.length === 0) {
      return this.createEmptyResult(originalText);
    }
    const severityOrder = ['none', 'low', 'medium', 'high', 'critical'];
    const allEntities = [];
    const allRiskFactors = [];
    let maxRiskScore = 0;
    let maxSeverity = 'none';
    let anyInvalid = false;
    let sanitizedText = originalText;
    for (const result of results) {
      allEntities.push(...result.entities);
      allRiskFactors.push(...result.riskFactors);
      if (result.riskScore > maxRiskScore) {
        maxRiskScore = result.riskScore;
      }
      if (severityOrder.indexOf(result.severity) > severityOrder.indexOf(maxSeverity)) {
        maxSeverity = result.severity;
      }
      // An explicit verdict from a scanner must not be overridden merely
      // because it attached no risk factors.
      if (result.isValid === false) {
        anyInvalid = true;
      }
      if (result.sanitizedText !== originalText) {
        sanitizedText = result.sanitizedText;
      }
    }
    return {
      isValid: !anyInvalid && allRiskFactors.length === 0,
      riskScore: maxRiskScore,
      sanitizedText,
      entities: this.deduplicateEntities(allEntities),
      riskFactors: allRiskFactors,
      severity: maxSeverity,
      metadata: {},
      durationMs: 0,
    };
  }

  /** Drop entities sharing the same start, end and entityType (first one wins). */
  deduplicateEntities(entities) {
    const seen = new Set();
    return entities.filter((entity) => {
      const key = `${entity.start}-${entity.end}-${entity.entityType}`;
      if (seen.has(key)) {
        return false;
      }
      seen.add(key);
      return true;
    });
  }
}
/**
 * Builder for creating custom Shield configurations.
 *
 * Defaults: parallel execution enabled, at most 4 concurrent scanners per
 * batch, short-circuit threshold 0.9, and no scanners.
 */
export class ShieldBuilder {
  inputScanners = [];
  outputScanners = [];
  config = {
    parallelExecution: true,
    maxConcurrent: 4,
    shortCircuitThreshold: 0.9,
  };

  /**
   * Add an input (prompt) scanner.
   * @returns {ShieldBuilder} this, for chaining
   */
  addInputScanner(scanner) {
    this.inputScanners.push(scanner);
    return this;
  }

  /**
   * Add an output (LLM response) scanner.
   * @returns {ShieldBuilder} this, for chaining
   */
  addOutputScanner(scanner) {
    this.outputScanners.push(scanner);
    return this;
  }

  /**
   * Set the risk score at which scanning stops early.
   * @returns {ShieldBuilder} this, for chaining
   */
  withShortCircuit(threshold) {
    this.config.shortCircuitThreshold = threshold;
    return this;
  }

  /**
   * Enable or disable parallel execution.
   * @returns {ShieldBuilder} this, for chaining
   */
  withParallelExecution(enabled) {
    this.config.parallelExecution = enabled;
    return this;
  }

  /**
   * Set the maximum number of scanners run concurrently.
   * @param {number} max - At least 1 (`Infinity` runs all scanners in one batch).
   * @returns {ShieldBuilder} this, for chaining
   * @throws {RangeError} if `max` is not a number >= 1
   */
  withMaxConcurrent(max) {
    // Shield runs parallel scanners in batches of `max`. A batch size below 1
    // (0, negative, NaN) would either never advance or skip every scanner.
    if (typeof max !== 'number' || !(max >= 1)) {
      throw new RangeError(`maxConcurrent must be a number >= 1, got ${String(max)}`);
    }
    this.config.maxConcurrent = max;
    return this;
  }

  /**
   * Build the Shield instance.
   *
   * The configuration and scanner lists are copied, so calling builder
   * methods afterwards does not change Shields that were already built.
   * @returns {Shield}
   */
  build() {
    return new Shield(
      { ...this.config },
      [...this.inputScanners],
      [...this.outputScanners],
    );
  }
}
// Install ShieldBuilder#build as a direct construction of Shield. The
// builder's configuration and scanner lists are copied, so calling builder
// methods after build() does not change Shields that were already built, and
// no preset scanners are instantiated just to be discarded.
ShieldBuilder.prototype.build = function build() {
  return new Shield(
    { ...this.config },
    [...this.inputScanners],
    [...this.outputScanners],
  );
};
//# sourceMappingURL=shield.js.map