UNPKG

dcpe-js

Version:

Distance Comparison Preserving Encryption for secure searchable vector embeddings

1,405 lines (1,282 loc) 65.4 kB
import crypto from 'crypto'; import { create, all } from 'mathjs'; // Base class for all exceptions class DCPEError extends Error { /** * Initializes the base error with a default or provided message. * @param {string} message - The error message. */ constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "An error occurred in the SDK"; super(message); this.name = "DCPEError"; } } // Error while loading or with invalid configuration class InvalidConfigurationError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Invalid configuration"; super(`InvalidConfigurationError: ${message}`); this.name = "InvalidConfigurationError"; } } // Error with key used for encryption or decryption class InvalidKeyError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Invalid key"; super(`InvalidKeyError: ${message}`); this.name = "InvalidKeyError"; } } // Error with user-provided input data class InvalidInputError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Invalid input"; super(`InvalidInputError: ${message}`); this.name = "InvalidInputError"; } } // Base class for encryption-related errors class EncryptError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Encryption error"; super(`EncryptError: ${message}`); this.name = "EncryptError"; } } // Base class for decryption-related errors class DecryptError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Decryption error"; super(`DecryptError: ${message}`); this.name = "DecryptError"; } } // Errors specific to vector encryption class VectorEncryptError extends EncryptError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Vector encryption error"; super(message); this.name = "VectorEncryptError"; this.message = `VectorEncryptError: ${message.replace(/^VectorEncryptError: /, '')}`; } } // Errors specific to vector decryption class VectorDecryptError extends DecryptError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Vector decryption error"; super(message); this.name = "VectorDecryptError"; this.message = `VectorDecryptError: ${message.replace(/^VectorDecryptError: /, '')}`; } } // Error due to numerical overflow during encryption class OverflowError extends EncryptError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Embedding or approximation factor too large"; super(message); this.name = "OverflowError"; this.message = `OverflowError: ${message.replace(/^OverflowError: /, '')}`; } } // Error during Protobuf serialization or deserialization class ProtobufError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Protobuf error"; super(`ProtobufError: ${message}`); this.name = "ProtobufError"; } } // Error during a request to an external service (like TSP) class RequestError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Request error"; super(`RequestError: ${message}`); this.name = "RequestError"; } } // Error during JSON serialization or deserialization class SerdeJsonError extends DCPEError { constructor() { let message = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : "Serde JSON error"; super(`SerdeJsonError: ${message}`); this.name = "SerdeJsonError"; } } // Error directly from the Tenant Security Proxy (TSP) /** * Represents a TSP (Third-Party Service Provider) error. * This error extends the DCPEError class and includes additional * details specific to TSP-related issues. * * @class * @extends DCPEError * * @param {string} errorVariant - The error variant, typically a string representation * that categorizes the type of error. * @param {number} httpCode - The HTTP status code associated with the error. * @param {number} tspCode - The TSP-specific error code providing additional context. * @param {string} [message="TSP error"] - A descriptive error message. */ class TspError extends DCPEError { /** * Initializes the TSP error with specific details. * @param {string} errorVariant - The error variant (e.g., string representation). * @param {number} httpCode - The HTTP status code. * @param {number} tspCode - The TSP-specific error code. * @param {string} message - The error message. */ constructor(errorVariant, httpCode, tspCode) { let message = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : "TSP error"; super(`TspError: ${message}, Variant: '${errorVariant}', HTTP Code: ${httpCode}, TSP Code: ${tspCode}`); this.name = "TspError"; this.errorVariant = errorVariant; this.httpCode = httpCode; this.tspCode = tspCode; } } var index$5 = /*#__PURE__*/Object.freeze({ __proto__: null, DCPEError: DCPEError, DecryptError: DecryptError, EncryptError: EncryptError, InvalidConfigurationError: InvalidConfigurationError, InvalidInputError: InvalidInputError, InvalidKeyError: InvalidKeyError, OverflowError: OverflowError, ProtobufError: ProtobufError, RequestError: RequestError, SerdeJsonError: SerdeJsonError, TspError: TspError, VectorDecryptError: VectorDecryptError, VectorEncryptError: VectorEncryptError }); /** * Represents a raw encryption key as bytes. */ class EncryptionKey { /** * @param {Buffer} keyBytes - The raw encryption key as a Buffer. */ constructor(keyBytes) { if (!Buffer.isBuffer(keyBytes)) { throw new TypeError('EncryptionKey must be initialized with a Buffer'); } this.keyBytes = keyBytes; } /** * Returns the raw key bytes. * @returns {Buffer} */ getBytes() { return this.keyBytes; } /** * Checks equality with another EncryptionKey. * @param {EncryptionKey} other * @returns {boolean} */ equals(other) { return other instanceof EncryptionKey && this.keyBytes.equals(other.keyBytes); } /** * String representation of the EncryptionKey. * @returns {string} */ toString() { return `EncryptionKey(bytes of length: ${this.keyBytes.length})`; } } /** * Represents the scaling factor used in vector encryption. */ class ScalingFactor { /** * @param {number} factor - The scaling factor as a float. */ constructor(factor) { if (typeof factor !== 'number') { throw new TypeError('ScalingFactor must be initialized with a number'); } this.factor = factor; } /** * Returns the scaling factor value. * @returns {number} */ getFactor() { return this.factor; } /** * Checks equality with another ScalingFactor. * @param {ScalingFactor} other * @returns {boolean} */ equals(other) { return other instanceof ScalingFactor && this.factor === other.factor; } /** * String representation of the ScalingFactor. * @returns {string} */ toString() { return `ScalingFactor(factor: ${this.factor})`; } } /** * Represents the combined key for vector encryption, including scaling factor and encryption key. */ class VectorEncryptionKey { /** * @param {ScalingFactor} scalingFactor - The scaling factor. * @param {EncryptionKey} key - The encryption key. */ constructor(scalingFactor, key) { if (!(scalingFactor instanceof ScalingFactor)) { throw new TypeError('VectorEncryptionKey scalingFactor must be a ScalingFactor instance'); } if (!(key instanceof EncryptionKey)) { throw new TypeError('VectorEncryptionKey key must be an EncryptionKey instance'); } this.scalingFactor = scalingFactor; this.key = key; } /** * Derives a VectorEncryptionKey from a master secret, tenant ID, and derivation path. * @param {Buffer} secret - The master secret as a Buffer. * @param {string} tenantId - The tenant ID. * @param {string} derivationPath - The derivation path. * @returns {VectorEncryptionKey} */ static deriveFromSecret(secret, tenantId, derivationPath) { if (!Buffer.isBuffer(secret)) { throw new TypeError('Secret must be a Buffer'); } if (typeof tenantId !== 'string') { throw new TypeError('Tenant ID must be a string'); } if (typeof derivationPath !== 'string') { throw new TypeError('Derivation Path must be a string'); } const payload = Buffer.from(`${tenantId}-${derivationPath}`, 'utf-8'); const hashResultBytes = crypto.createHmac('sha512', secret).update(payload).digest(); return this.unsafeBytesToKey(hashResultBytes); } /** * Constructs a VectorEncryptionKey from raw bytes. * @param {Buffer} keyBytes - The raw bytes. * @returns {VectorEncryptionKey} * @throws {InvalidKeyError} If keyBytes is not long enough. */ static unsafeBytesToKey(keyBytes) { if (keyBytes.length < 35) { throw new InvalidKeyError('Key bytes must be at least 35 bytes long'); } const scalingFactorBytes = keyBytes.subarray(0, 3); const keyMaterialBytes = keyBytes.subarray(3, 35); // Add leading zero byte to match Python's behavior const paddedBytes = Buffer.concat([Buffer.from([0]), scalingFactorBytes]); // Use readUInt32BE instead of parseInt for consistent binary representation const scalingFactorU32 = paddedBytes.readUInt32BE(0); const scalingFactor = new ScalingFactor(scalingFactorU32); const encryptionKey = new EncryptionKey(keyMaterialBytes); return new VectorEncryptionKey(scalingFactor, encryptionKey); } /** * Checks equality with another VectorEncryptionKey. * @param {VectorEncryptionKey} other * @returns {boolean} */ equals(other) { return other instanceof VectorEncryptionKey && this.scalingFactor.equals(other.scalingFactor) && this.key.equals(other.key); } /** * String representation of the VectorEncryptionKey. * @returns {string} */ toString() { return `VectorEncryptionKey(scalingFactor=${this.scalingFactor}, key=${this.key})`; } } /** * Generates a cryptographically random EncryptionKey (32 bytes). * @returns {EncryptionKey} */ function generateRandomKey() { return new EncryptionKey(crypto.randomBytes(32)); } /** * Generates encryption keys for use with DCPE * @param {Object} options - Options for key generation * @param {number} [options.approximationFactor=1.0] - Approximation factor for vector encryption * @returns {Promise<Buffer>} - Generated encryption key material */ async function generateEncryptionKeys() { let options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {}; const { approximationFactor = 1.0 } = options; // Generate a random encryption key const keyMaterial = crypto.randomBytes(32); // Create a ScalingFactor instance const scalingFactor = new ScalingFactor(approximationFactor); // Create and return a VectorEncryptionKey return new VectorEncryptionKey(scalingFactor, new EncryptionKey(keyMaterial)); } var index$4 = /*#__PURE__*/Object.freeze({ __proto__: null, EncryptionKey: EncryptionKey, ScalingFactor: ScalingFactor, VectorEncryptionKey: VectorEncryptionKey, generateEncryptionKeys: generateEncryptionKeys, generateRandomKey: generateRandomKey }); const math = create(all); /** * Represents an authentication hash. */ class AuthHash { constructor(hashBytes) { if (!Buffer.isBuffer(hashBytes)) { throw new TypeError("AuthHash must be initialized with a Buffer"); } if (hashBytes.length !== 32) { throw new Error("AuthHash must be 32 bytes long"); } this.hashBytes = hashBytes; } getBytes() { return this.hashBytes; } equals(other) { return other instanceof AuthHash && this.hashBytes.equals(other.hashBytes); } toString() { return `AuthHash(${this.hashBytes.toString('hex')})`; } } /** * Generates a random vector sampled from a multivariate normal distribution. * @param {number} dimensionality - The dimensionality of the vector. * @returns {Array<number>} - The sampled vector. */ function sampleNormalVector(dimensionality) { return Array.from({ length: dimensionality }, () => { // Use crypto instead of Math.random() const u1Bytes = crypto.randomBytes(4); const u2Bytes = crypto.randomBytes(4); const u1 = u1Bytes.readUInt32LE(0) / 0x100000000; const u2 = u2Bytes.readUInt32LE(0) / 0x100000000; const z0 = Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2); return z0; }); } /** * Generates a random uniform point in the range [0, 1). * * This function uses cryptographic randomness to ensure high-quality random values. * It generates 4 random bytes, interprets them as a 32-bit unsigned integer, * and then normalizes the value to a floating-point number in the range [0, 1). * * @returns {number} A random floating-point number in the range [0, 1). */ function sampleUniformPoint() { const bytes = crypto.randomBytes(4); return bytes.readUInt32LE(0) / 0x100000000; } /** * Calculates a uniform point within an n-dimensional ball. * @param {ScalingFactor} scalingFactor - The scaling factor. * @param {number} approximationFactor - The approximation factor. * @param {number} uniformPoint - The sampled uniform point. * @param {number} dimensionality - The dimensionality of the vector. * @returns {number} - The calculated point. */ function calculateUniformPointInBall(scalingFactor, approximationFactor, uniformPoint, dimensionality) { const radius = scalingFactor.getFactor() / 4 * approximationFactor; return radius * Math.pow(uniformPoint, 1 / dimensionality); } /** * Normalizes a sampled vector. * @param {Array<number>} vector - The sampled vector. * @param {number} scale - The scaling factor. * @returns {Array<number>} - The normalized vector. */ function normalizeVector(vector, scale) { const norm = math.norm(vector); return vector.map(val => val * scale / norm); } /** * Generates a normalized noise vector for encryption. * @param {VectorEncryptionKey} key - The encryption key. * @param {Buffer} iv - The initialization vector. * @param {number} approximationFactor - The approximation factor. * @param {number} dimensionality - The dimensionality of the vector. * @returns {Array<number>} - The noise vector. */ function generateNoiseVector(key, iv, approximationFactor, dimensionality) { if (!key) { throw new Error("Key is required for noise vector generation"); } if (!iv || !Buffer.isBuffer(iv)) { throw new Error("IV must be a valid Buffer for noise vector generation"); } if (!Number.isFinite(approximationFactor) || approximationFactor <= 0) { throw new Error("Approximation factor must be a positive number"); } if (!Number.isInteger(dimensionality) || dimensionality <= 0) { throw new Error("Dimensionality must be a positive integer"); } const normalVector = sampleNormalVector(dimensionality); const uniformPoint = sampleUniformPoint(); const scaledPoint = calculateUniformPointInBall(key.scalingFactor, approximationFactor, uniformPoint, dimensionality); return normalizeVector(normalVector, scaledPoint); } /** * Creates a random number generator (RNG) function based on a given cryptographic key. * The RNG function generates pseudo-random numbers in the range [0, 1) using HMAC with SHA-256. * * @param {Object} key - The cryptographic key used to seed the RNG. It must have a `getBytes` method * that returns the key as a byte array. * @returns {Function} A function that generates a pseudo-random number between 0 (inclusive) and 1 (exclusive) * each time it is called. */ function createRngFromKey(key) { return function () { const hmac = crypto.createHmac('sha256', key.getBytes()); hmac.update(Buffer.from([this.counter++ & 0xFF])); const bytes = hmac.digest(); return bytes.readUInt32LE(0) / 0x100000000; }.bind({ counter: 0 }); } /** * Shuffles an array deterministically based on a key. * @param {EncryptionKey} key - The encryption key used for deterministic shuffling. * @param {Array} inputArray - The array to shuffle. * @returns {Array} - The shuffled array. */ function shuffle(key, inputArray) { if (!key || !Array.isArray(inputArray)) { throw new Error("Invalid input to shuffle function"); } // Create a deterministic random number generator based on the key const rng = createRngFromKey(key); // Create an array of indices and shuffle them const indices = inputArray.map((_, index) => index); for (let i = indices.length - 1; i > 0; i--) { const j = Math.floor(rng() * (i + 1)); [indices[i], indices[j]] = [indices[j], indices[i]]; } // Use the shuffled indices to reorder the input array return indices.map(index => inputArray[index]); } /** * Unshuffles an array that was shuffled deterministically based on a key. * @param {EncryptionKey} key - The encryption key used for deterministic shuffling. * @param {Array} shuffledArray - The array to unshuffle. * @returns {Array} - The unshuffled array. */ /** * Reverses the shuffling of an array based on a given key. * * @param {string} key - A string used to seed the deterministic random number generator. * @param {Array} shuffledArray - The array that was previously shuffled and needs to be restored to its original order. * @returns {Array} - The original array restored to its unshuffled order. * @throws {Error} - Throws an error if the key is not provided or if the shuffledArray is not an array. * * @description * This function assumes that the array was shuffled using a deterministic algorithm * based on the same key. It recreates the shuffle permutation using a seeded random * number generator and then reverses the shuffle to restore the original order. * * Note: The function relies on the existence of `createRngFromKey`, which must generate * a deterministic random number generator seeded by the provided key. */ function unshuffle(key, shuffledArray) { if (!key || !Array.isArray(shuffledArray)) { throw new Error("Invalid input to unshuffle function"); } // Create a deterministic random number generator based on the key const rng = createRngFromKey(key); // First recreate the exact same permutation that was used in the shuffle function const indices = Array.from({ length: shuffledArray.length }, (_, i) => i); for (let i = indices.length - 1; i > 0; i--) { const j = Math.floor(rng() * (i + 1)); [indices[i], indices[j]] = [indices[j], indices[i]]; } // Create a mapping from shuffled position to original position const reverseMap = new Array(indices.length); for (let i = 0; i < indices.length; i++) { reverseMap[indices[i]] = i; } // Use the mapping to restore the original order return shuffledArray.map((_, i) => shuffledArray[reverseMap[i]]); } /** * Computes an authentication hash for a vector embedding. * @param {VectorEncryptionKey} key - The encryption key. * @param {number} approximationFactor - The approximation factor. * @param {Buffer} iv - The initialization vector. * @param {Array<number>} encryptedVector - The encrypted vector. * @returns {AuthHash} - The computed authentication hash. */ function computeAuthHash(key, approximationFactor, iv, encryptedVector) { const hmac = crypto.createHmac('sha256', key.key.getBytes()); hmac.update(Buffer.from(Float32Array.of(key.scalingFactor.getFactor()).buffer)); hmac.update(Buffer.from(Float32Array.of(approximationFactor).buffer)); hmac.update(iv); encryptedVector.forEach(val => { hmac.update(Buffer.from(Float32Array.of(val).buffer)); }); return new AuthHash(hmac.digest()); } /** * Encrypts a vector embedding. * @param {VectorEncryptionKey} key - The encryption key. * @param {number} approximationFactor - The approximation factor. * @param {Array<number>} vector - The plaintext vector. * @returns {Object} - The encryption result containing ciphertext, IV, and auth hash. */ function encryptVector(key, approximationFactor, vector) { if (!key || !key.scalingFactor) { throw new InvalidKeyError("Scaling factor is not initialized in the encryption key"); } if (key.scalingFactor.getFactor() === 0) { throw new InvalidKeyError("Scaling factor cannot be zero"); } const iv = crypto.randomBytes(12); const noiseVector = generateNoiseVector(key, iv, approximationFactor, vector.length); const ciphertext = vector.map((val, i) => key.scalingFactor.getFactor() * val + noiseVector[i]); if (!ciphertext.every(val => Number.isFinite(val))) { throw new Error("Overflow error: Embedding or approximation factor too large."); } const authHash = computeAuthHash(key, approximationFactor, iv, ciphertext); return { ciphertext, iv, authHash }; } /** * Decrypts an encrypted vector embedding. * @param {VectorEncryptionKey} key - The encryption key. * @param {number} approximationFactor - The approximation factor. * @param {Object} encryptedResult - The encryption result containing ciphertext, IV, and auth hash. * @returns {Array<number>} - The decrypted vector. */ function decryptVector(key, approximationFactor, encryptedResult) { if (key.scalingFactor.getFactor() === 0) { throw new InvalidKeyError("Scaling factor cannot be zero"); } const { ciphertext, iv, authHash } = encryptedResult; if (!computeAuthHash(key, approximationFactor, iv, ciphertext).equals(authHash)) { throw new DecryptError("Authentication hash mismatch"); } const noiseVector = generateNoiseVector(key, iv, approximationFactor, ciphertext.length); return ciphertext.map((val, i) => (val - noiseVector[i]) / key.scalingFactor.getFactor()); } var index$3 = /*#__PURE__*/Object.freeze({ __proto__: null, AuthHash: AuthHash, computeAuthHash: computeAuthHash, decryptVector: decryptVector, encryptVector: encryptVector, generateNoiseVector: generateNoiseVector, sampleNormalVector: sampleNormalVector, sampleUniformPoint: sampleUniformPoint, shuffle: shuffle, unshuffle: unshuffle }); /** * Enumeration for EDEK Types */ const EdekType = Object.freeze({ STANDALONE: "Standalone", SAAS_SHIELD: "SaasShield", DATA_CONTROL_PLATFORM: "DataControlPlatform" }); /** * Enumeration for Payload Types */ const PayloadType = Object.freeze({ DETERMINISTIC_FIELD: "DeterministicField", VECTOR_METADATA: "VectorMetadata", STANDARD_EDEK: "StandardEdek" }); /** * Represents the Key ID Header */ class KeyIdHeader { /** * @param {number} keyId - The key ID (integer). * @param {string} edekType - The EDEK type (from EdekType). * @param {string} payloadType - The payload type (from PayloadType). */ constructor(keyId, edekType, payloadType) { if (typeof keyId !== 'number') { throw new TypeError("keyId must be a number"); } if (!Object.values(EdekType).includes(edekType)) { throw new TypeError("edekType must be a valid EdekType value"); } if (!Object.values(PayloadType).includes(payloadType)) { throw new TypeError("payloadType must be a valid PayloadType value"); } this.keyId = keyId; this.edekType = edekType; this.payloadType = payloadType; } /** * Creates a KeyIdHeader instance. * @param {string} edekType - The EDEK type. * @param {string} payloadType - The payload type. * @param {number} keyId - The key ID. * @returns {KeyIdHeader} */ static createHeader(edekType, payloadType, keyId) { return new KeyIdHeader(keyId, edekType, payloadType); } /** * Serializes the KeyIdHeader to bytes. * @returns {Buffer} */ writeToBytes() { const buffer = Buffer.alloc(6); buffer.writeUInt32BE(this.keyId, 0); // Write keyId (4 bytes) buffer.writeUInt8(this._encodeTypeByte(), 4); // Write encoded type byte buffer.writeUInt8(0, 5); // Padding byte return buffer; } /** * Parses bytes and reconstructs a KeyIdHeader instance. * @param {Buffer} headerBytes - The serialized header bytes. * @returns {KeyIdHeader} */ static parseFromBytes(headerBytes) { if (headerBytes.length !== 6) { throw new InvalidInputError(`Header bytes must be 6 bytes long, got ${headerBytes.length}`); } const keyId = headerBytes.readUInt32BE(0); // Read keyId (4 bytes) const typeByte = headerBytes.readUInt8(4); // Read type byte const paddingByte = headerBytes.readUInt8(5); // Read padding byte if (paddingByte !== 0) { throw new InvalidInputError(`Padding byte in header is not zero: ${paddingByte}`); } const { edekType, payloadType } = this._decodeTypeByte(typeByte); return new KeyIdHeader(keyId, edekType, payloadType); } /** * Encodes EDEK type and Payload type into a single byte. * @returns {number} */ _encodeTypeByte() { const edekNumeric = Object.values(EdekType).indexOf(this.edekType) << 4; // Shift EDEK type to top 4 bits const payloadNumeric = Object.values(PayloadType).indexOf(this.payloadType); // Payload type in bottom 4 bits return edekNumeric | payloadNumeric; } /** * Decodes the type byte back to EDEK type and Payload type. * @param {number} typeByte - The encoded type byte. * @returns {{ edekType: string, payloadType: string }} */ static _decodeTypeByte(typeByte) { const edekTypeIndex = (typeByte & 0xF0) >> 4; // Extract top 4 bits const payloadTypeIndex = typeByte & 0x0F; // Extract bottom 4 bits const edekType = Object.values(EdekType)[edekTypeIndex]; const payloadType = Object.values(PayloadType)[payloadTypeIndex]; if (!edekType || !payloadType) { throw new InvalidInputError("Invalid type byte encoding"); } return { edekType, payloadType }; } } /** * Represents Vector Metadata, including IV and AuthHash. */ class VectorMetadata { /** * @param {KeyIdHeader} keyIdHeader - The KeyIdHeader instance. * @param {Buffer} iv - The initialization vector. * @param {AuthHash} authHash - The authentication hash. */ constructor(keyIdHeader, iv, authHash) { if (!(keyIdHeader instanceof KeyIdHeader)) { throw new TypeError("keyIdHeader must be an instance of KeyIdHeader"); } if (!Buffer.isBuffer(iv)) { throw new TypeError("iv must be a Buffer"); } if (!(authHash instanceof AuthHash)) { throw new TypeError("authHash must be an instance of AuthHash"); } this.keyIdHeader = keyIdHeader; this.iv = iv; this.authHash = authHash; } } /** * Encodes vector metadata into bytes. * @param {KeyIdHeader} keyIdHeader - The KeyIdHeader instance. * @param {Buffer} iv - The initialization vector. * @param {AuthHash} authHash - The authentication hash. * @returns {Buffer} */ function encodeVectorMetadata(keyIdHeader, iv, authHash) { return Buffer.concat([keyIdHeader.writeToBytes(), iv, authHash.getBytes()]); } /** * Decodes a byte stream with a prefixed KeyIdHeader. * @param {Buffer} valueBytes - The byte stream. * @returns {{ keyIdHeader: KeyIdHeader, remainingBytes: Buffer }} */ function decodeVersionPrefixedValue(valueBytes) { if (valueBytes.length < 6) { throw new InvalidInputError("Value bytes too short to contain KeyIdHeader"); } const headerBytes = valueBytes.subarray(0, 6); const remainingBytes = valueBytes.subarray(6); const keyIdHeader = KeyIdHeader.parseFromBytes(headerBytes); return { keyIdHeader, remainingBytes }; } var index$2 = /*#__PURE__*/Object.freeze({ __proto__: null, EdekType: EdekType, KeyIdHeader: KeyIdHeader, PayloadType: PayloadType, VectorMetadata: VectorMetadata, decodeVersionPrefixedValue: decodeVersionPrefixedValue, encodeVectorMetadata: encodeVectorMetadata }); /** * Abstract KeyProvider class for managing cryptographic keys. */ class KeyProvider { /** * Retrieves a key from the provider. * @param {string} [keyId] - The identifier for the key to retrieve. * @returns {Promise<Buffer>} - The raw key material as a Buffer. * @throws {Error} - If the key is not found or cannot be accessed. */ async getKey(keyId) { throw new Error("getKey method must be implemented by subclasses"); } /** * Stores a key in the provider. * @param {Buffer} keyMaterial - The raw key to store. * @param {string} [keyId] - Optional identifier for the key. * @returns {Promise<string>} - The identifier assigned to the stored key. * @throws {Error} - If the key cannot be stored. */ async storeKey(keyMaterial, keyId) { throw new Error("storeKey method must be implemented by subclasses"); } } /** * ClientKeyProvider class for managing keys in an in-memory store (e.g., Zustand). */ class ClientKeyProvider extends KeyProvider { /** * @param {Object} keyStore - The in-memory key store (e.g., Zustand store). */ constructor(keyStore) { super(); if (typeof keyStore !== "object" || keyStore === null) { throw new TypeError("keyStore must be a valid object"); } this.keyStore = keyStore; } /** * Retrieves a key from the in-memory store. * @param {string} [keyId] - The identifier for the key to retrieve. * @returns {Promise<Buffer>} - The raw key material as a Buffer. * @throws {Error} - If the key is not found. */ async getKey(keyId) { const key = this.keyStore[keyId || "default"]; if (!key) { throw new Error(`Key not found: ${keyId || "default"}`); } return Buffer.from(key, "base64"); } /** * Stores a key in the in-memory store. * @param {Buffer} keyMaterial - The raw key to store. * @param {string} [keyId] - Optional identifier for the key. * @returns {Promise<string>} - The identifier assigned to the stored key. */ async storeKey(keyMaterial, keyId) { const actualKeyId = keyId || "default"; this.keyStore[actualKeyId] = keyMaterial.toString("base64"); return actualKeyId; } } /** * LocalKeyProvider for managing keys locally * This provider stores keys in memory only for the lifetime of the instance */ class LocalKeyProvider extends KeyProvider { /** * @param {Object} config - Configuration options */ constructor() { super(); this.keys = {}; this.currentKey = null; } /** * Retrieves a key from the local store * @param {string} [keyId] - The identifier for the key to retrieve * @returns {Promise<Buffer>} - The raw key material as a Buffer * @throws {Error} - If the key is not found */ async getKey(keyId) { const key = this.keys[keyId || "default"] || this.currentKey; if (!key) { throw new Error(`Key not found: ${keyId || "default"}`); } return key; } /** * Stores a key in the local store * @param {Buffer} keyMaterial - The raw key to store * @param {string} [keyId] - Optional identifier for the key * @returns {Promise<string>} - The identifier assigned to the stored key */ async storeKey(keyMaterial, keyId) { const actualKeyId = keyId || "default"; this.keys[actualKeyId] = keyMaterial; return actualKeyId; } /** * Sets the encryption keys * @param {Object|Buffer} encryptionKeys - Encryption keys to set */ setKeys(encryptionKeys) { if (Buffer.isBuffer(encryptionKeys)) { this.currentKey = encryptionKeys; } else if (encryptionKeys && encryptionKeys instanceof VectorEncryptionKey) { // Preserve VectorEncryptionKey instance directly this.currentKey = encryptionKeys; } else if (encryptionKeys && typeof encryptionKeys === 'object') { // Handle objects that might be serialized VectorEncryptionKey if (encryptionKeys.scalingFactor && encryptionKeys.key) { // Attempt to reconstruct a VectorEncryptionKey const scalingFactor = new ScalingFactor(encryptionKeys.scalingFactor.factor || encryptionKeys.scalingFactor); const key = new EncryptionKey(Buffer.isBuffer(encryptionKeys.key.keyBytes) ? encryptionKeys.key.keyBytes : Buffer.from(encryptionKeys.key.keyBytes || encryptionKeys.key)); this.currentKey = new VectorEncryptionKey(scalingFactor, key); } else { // Fallback to previous behavior this.currentKey = encryptionKeys.key || encryptionKeys; } } else { throw new TypeError("Invalid encryption keys format"); } } /** * Gets the current encryption keys * @returns {Buffer} - The raw key material */ getKeys() { if (!this.currentKey) { throw new Error("No encryption keys have been set"); } return this.currentKey; } } var index$1 = /*#__PURE__*/Object.freeze({ __proto__: null, ClientKeyProvider: ClientKeyProvider, KeyProvider: KeyProvider, LocalKeyProvider: LocalKeyProvider }); /** * HMAC-based Key Derivation Function implementation. * * @param {Buffer} ikm - Input key material * @param {number} length - The desired length of the derived key * @param {Buffer} salt - Optional salt value (recommended) * @param {Buffer} info - Optional context and application specific information * @returns {Buffer} The derived key */ function hkdf(ikm, length, salt, info) { if (!Buffer.isBuffer(ikm)) { throw new TypeError('Input key material must be a Buffer'); } // Default values salt = salt || Buffer.alloc(0); info = info || Buffer.alloc(0); // Step 1: Extract const prk = crypto.createHmac('sha256', salt).update(ikm).digest(); // Step 2: Expand const result = Buffer.alloc(length); let previous = Buffer.alloc(0); let resultPosition = 0; const hashLen = 32; // SHA-256 hash length for (let i = 1; resultPosition < length; i++) { const hmac = crypto.createHmac('sha256', prk); hmac.update(Buffer.concat([previous, info, Buffer.from([i])])); const next = hmac.digest(); const remainder = Math.min(length - resultPosition, hashLen); next.copy(result, resultPosition, 0, remainder); previous = next; resultPosition += remainder; } return result; } /** * RagEncryptionClient: High-level interface for encryption and decryption. */ class RagEncryptionClient { /** * Initializes the RagEncryptionClient with encryption keys. * @param {Buffer|null} encryptionKey - Raw encryption key bytes. * @param {number} approximationFactor - Approximation factor for vector encryption. * @param {KeyProvider} keyProvider - Optional key provider implementation. * @param {string} keyId - Optional key identifier to use with the key provider. * @param {boolean} _skipValidation - Internal flag to skip validation (used by create method). */ constructor() { let encryptionKey = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; let approximationFactor = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 1.0; let keyProvider = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null; let keyId = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : null; let _skipValidation = arguments.length > 4 && arguments[4] !== undefined ? arguments[4] : false; // Skip validation if internal flag is set (used by create method) if (_skipValidation) { return; } // For direct key initialization only - use static create() for key provider if (keyProvider || keyId) { throw new InvalidInputError("For async key provider initialization, use the static create() method"); } if (!encryptionKey) { throw new InvalidInputError("Encryption key must be provided when using constructor directly"); } this._initializeWithKey(encryptionKey, approximationFactor); } /** * Creates an instance of RagEncryptionClient. */ static async create() { let encryptionKey = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; let approximationFactor = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 1.0; let keyProvider = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : null; let keyId = arguments.length > 3 && arguments[3] !== undefined ? arguments[3] : null; // Create a new instance with validation skipped const client = new RagEncryptionClient(null, 1.0, null, null, true); if (keyProvider && keyId) { // Rest remains the same... // Check for the getKey method instead of instanceof if (typeof keyProvider.getKey !== 'function') { throw new TypeError("keyProvider must have a getKey method"); } await client._initializeWithKeyProvider(keyProvider, keyId, approximationFactor); } else if (encryptionKey) { client._initializeWithKey(encryptionKey, approximationFactor); } else { throw new InvalidInputError("Either encryptionKey or (keyProvider and keyId) must be provided"); } return client; } /** * Initialize client with direct key material * @private */ _initializeWithKey(encryptionKey, approximationFactor) { if (!Buffer.isBuffer(encryptionKey) || encryptionKey.length < 32) { throw new InvalidInputError("Encryption key must be a Buffer of at least 32 bytes"); } if (typeof approximationFactor !== 'number') { throw new InvalidInputError("Approximation factor must be a number"); } this.vectorEncryptionKey = new VectorEncryptionKey(new ScalingFactor(approximationFactor), new EncryptionKey(encryptionKey)); this.textEncryptionKey = new EncryptionKey(encryptionKey); this.deterministicEncryptionKey = new EncryptionKey(encryptionKey); this.approximationFactor = approximationFactor; this.keyId = "local-key"; this.keyProvider = null; } /** * Initialize client with key provider * @private */ async _initializeWithKeyProvider(keyProvider, keyId, approximationFactor) { try { const encryptionKey = await keyProvider.getKey(keyId); this._initializeWithKey(encryptionKey, approximationFactor); this.keyId = keyId; this.keyProvider = keyProvider; } catch (error) { throw new InvalidInputError(`Failed to get key from provider: ${error.message}`); } } // Update rotateKey to support key provider /** * Rotate to a new encryption key. * @param {Buffer} newKeyMaterial - New raw encryption key bytes (optional if using key provider). * @param {string} newKeyId - New key identifier to use with the current key provider (optional). */ async rotateKey() { let newKeyMaterial = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; let newKeyId = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : null; // Store old keys for reference this._oldVectorEncryptionKey = this.vectorEncryptionKey; this._oldTextEncryptionKey = this.textEncryptionKey; this._oldDeterministicEncryptionKey = this.deterministicEncryptionKey; // Get new key material let newKey; if (this.keyProvider && newKeyId) { // Get from provider if available try { newKey = await this.keyProvider.getKey(newKeyId); this.keyId = newKeyId; } catch (error) { throw new InvalidInputError(`Failed to get new key from provider: ${error.message}`); } } else if (newKeyMaterial) { // Use directly provided material if (!Buffer.isBuffer(newKeyMaterial) || newKeyMaterial.length < 32) { throw new InvalidInputError("New key material must be a Buffer of at least 32 bytes"); } newKey = newKeyMaterial; } else { throw new InvalidInputError("Either newKeyMaterial or newKeyId must be provided"); } // Update current keys with new material this.vectorEncryptionKey = new VectorEncryptionKey(new ScalingFactor(this.approximationFactor), new EncryptionKey(newKey)); this.textEncryptionKey = new EncryptionKey(newKey); this.deterministicEncryptionKey = new EncryptionKey(newKey); } /** * Encrypts a vector embedding. * @param {Array<number>} plaintextVector - The plaintext vector to encrypt. * @returns {[Array<number>, Buffer]} - A tuple containing the encrypted vector and metadata. */ encryptVector(plaintextVector) { if (!Array.isArray(plaintextVector) || !plaintextVector.every(x => typeof x === 'number')) { throw new InvalidInputError("Plaintext vector must be an array of numbers"); } // Shuffle the plaintext vector const shuffledVector = shuffle(this.textEncryptionKey, plaintextVector); // Encrypt the shuffled vector const encryptResult = encryptVector(this.vectorEncryptionKey, this.approximationFactor, shuffledVector); // Generate metadata const keyIdHeader = new KeyIdHeader((typeof this.keyId === 'string' ? this.keyId.split('').reduce((a, c) => a + c.charCodeAt(0), 0) : 1) % 9999, "Standalone", "VectorMetadata"); const metadata = encodeVectorMetadata(keyIdHeader, encryptResult.iv, encryptResult.authHash); // Return tuple format like Python return [encryptResult.ciphertext, metadata]; } /** * Decrypts an encrypted vector embedding. * @param {Array<number>} encryptedVector - The encrypted vector. * @param {Buffer} pairedIclInfo - The metadata associated with the encrypted vector. * @returns {Array<number>} - The decrypted plaintext vector. */ decryptVector(encryptedVector, pairedIclInfo) { if (!Array.isArray(encryptedVector) || !encryptedVector.every(x => typeof x === 'number')) { throw new InvalidInputError("Encrypted vector must be an array of numbers"); } if (!Buffer.isBuffer(pairedIclInfo)) { throw new InvalidInputError("Metadata must be a Buffer"); } /// Decode metadata const { keyIdHeader, remainingBytes } = decodeVersionPrefixedValue(pairedIclInfo); const iv = remainingBytes.subarray(0, 12); const authHashBytes = remainingBytes.subarray(12); // Convert the Buffer to an AuthHash object const authHash = new AuthHash(authHashBytes); // Decrypt the vector const shuffledVector = decryptVector(this.vectorEncryptionKey, this.approximationFactor, { ciphertext: encryptedVector, iv, authHash }); // Unshuffle the vector return unshuffle(this.textEncryptionKey, shuffledVector); } /** * Encrypts a text string using AES-GCM. * @param {string} plaintext - The plaintext string to encrypt. * @returns {Object} - Encrypted text, IV, and authentication tag. */ encryptText(plaintext) { if (typeof plaintext !== 'string') { throw new InvalidInputError("Plaintext must be a string."); } // Ensure the key is exactly 32 bytes const key = this.textEncryptionKey.getBytes().subarray(0, 32); const iv = crypto.randomBytes(12); const cipher = crypto.createCipheriv('aes-256-gcm', key, iv); const ciphertext = Buffer.concat([cipher.update(plaintext, 'utf8'), cipher.final()]); const tag = cipher.getAuthTag(); return { ciphertext, iv, tag }; } /** * Decrypts an AES-GCM encrypted text. * @param {Buffer} ciphertext - The encrypted text. * @param {Buffer} iv - The initialization vector. * @param {Buffer} tag - The authentication tag. * @returns {string} - The decrypted plaintext string. */ decryptText(ciphertext, iv, tag) { if (!Buffer.isBuffer(ciphertext) || !Buffer.isBuffer(iv) || !Buffer.isBuffer(tag)) { throw new InvalidInputError("Ciphertext, IV, and tag must be Buffers."); } // Ensure the key is exactly 32 bytes const key = this.textEncryptionKey.getBytes().subarray(0, 32); const decipher = crypto.createDecipheriv('aes-256-gcm', key, iv); decipher.setAuthTag(tag); const plaintext = Buffer.concat([decipher.update(ciphertext), decipher.final()]); return plaintext.toString('utf8'); } /** * Encrypts text deterministically using AES-GCM, mirroring Python implementation * with HKDF key derivation and deterministic nonce generation. * @param {string} plaintext - The plaintext string to encrypt. * @returns {Buffer} - The encrypted text with metadata. */ encryptDeterministicText(plaintext) { if (typeof plaintext !== 'string') { throw new InvalidInputError("Plaintext must be a string"); } // 1. Derive key using HKDF similarly to Python implementation const salt = Buffer.from('DCPE-Deterministic'); const info = Buffer.from('deterministic_encryption_key'); const derivedKey = hkdf(this.deterministicEncryptionKey.getBytes(), 32, salt, info); // 2. Create deterministic nonce using HMAC from plaintext const hmac = crypto.createHmac('sha256', derivedKey); hmac.update(Buffer.from(plaintext, 'utf8')); const deterministicNonce = hmac.digest().subarray(0, 12); // 3. Encrypt with AES-GCM using the derived key and deterministic nonce const cipher = crypto.createCipheriv('aes-256-gcm', derivedKey, deterministicNonce); const ciphertext = Buffer.concat([cipher.update(plaintext, 'utf8'), cipher.final()]); const tag = cipher.getAuthTag(); // 4. Match Python's output format: nonce + ciphertext + tag return Buffer.concat([deterministicNonce, ciphertext, tag]); } /** * Decrypts deterministically encrypted text, matching Python implementation. * @param {Buffer} encryptedData - The encrypted text. * @returns {string} - The decrypted plaintext string. */ decryptDeterministicText(encryptedData) { if (!Buffer.isBuffer(encryptedData)) { throw new InvalidInputError("Encrypted data must be a Buffer"); } if (encryptedData.length < 28) { // 12 (nonce) + 16 (min tag size) throw new InvalidInputError("Encrypted data too short"); } // 1. Split components: nonce + ciphertext + tag const nonce = encryptedData.subarray(0, 12); const ciphertext = encryptedData.subarray(12, encryptedData.length - 16); const tag = encryptedData.subarray(encryptedData.length - 16); // 2. Derive the same key used for encryption const salt = Buffer.from('DCPE-Deterministic'); const info = Buffer.from('deterministic_encryption_key'); const derivedKey = hkdf(this.deterministicEncryptionKey.getBytes(), 32, salt, info); // 3. Decrypt with AES-GCM try { const decipher = crypto.createDecipheriv('aes-256-gcm', derivedKey, nonce); decipher.setAuthTag(tag); const plaintext = Buffer.concat([decipher.update(ciphertext), decipher.final()]); return plaintext.toString('utf8'); } catch (e) { throw new DecryptError(`Deterministic text decryption failed: ${e.message}`); } } } // After the RagEncryptionClient class and its export: /** * Encryption client instance for utility functions * @type {RagEncryptionClient} * @private */ let _clientInstance = null; /** * Get or initialize the client instance * @param {Buffer|VectorEncryptionKey} keys - Encryption keys * @returns {RagEncryptionClient} * @private */ function _getClientInstance(keys) { if (!_clientInstance) { // Extract raw key material if keys is a VectorEncryptionKey let keyMaterial; if (keys && typeof keys === 'object' && keys.key && typeof keys.key.getBytes === 'function') { // If it's a VectorEncryptionKey object, extract the underlying key bytes keyMaterial = keys.key.getBytes(); } else if (Buffer.isBuffer(keys)) { // If it's already a Buffer, use it directly keyMaterial = keys; } else { throw new InvalidInputError('Invalid key format: expected Buffer or VectorEncryptionKey'); } _clientInstance = new RagEncryptionClient(keyMaterial); } return _clientInstance; } /** * Encrypts a text string using AES-GCM * @param {string} text - Text to encrypt * @param {Buffer} keys - Encryption keys * @param {Object} options - Encryption options * @returns {Object} - Encrypted text, IV, and authentication tag */ function encryptText(text, keys) { const client = _getClientInstance(keys); return client.encryptText(text); } /** * Decrypts an AES-GCM encrypted text * @param {Object} encryptedText - Encrypted text object with ciphertext, iv, and tag * @param {Buffer} keys - Encryption keys * @param {Object} options - Decryption options * @returns {string} - Decrypted text */ function decryptText(encryptedText, keys) { const client = _getClientInstance(keys); return client.decryptText(encryptedText.ciphertext, encryptedText.iv, encryptedText.tag); } /** * Encrypts text deterministically using AES-GCM * @param {string} value - Value to encrypt * @param {Buffer} keys - Encryption keys * @param {Object} options - Encryption options * @returns {Buffer} - Encrypted value */ function encryptDeterministicText(value, keys) { const client = _getClientInstance(keys); return client.encryptDeterministicText(value); } /** * Decrypts deterministically encrypted text * @param {Buffer} encryptedValue - Encrypted value * @param {Buffer} keys - Encryption keys * @param {Object} options - Decryption options