UNPKG

superminhash

Version:

TypeScript implementation of the SuperMinHash algorithm for Jaccard similarity estimation

265 lines (213 loc) 9.18 kB
import seedrandom from 'seedrandom';

/**
 * A set element. Strings are used as-is; every other value is keyed by its
 * `JSON.stringify` output.
 */
export type HashableElement = unknown;

/** Scratch state used while processing a single element. */
type ElementProcessingContext = {
  /** Partially shuffled permutation of the signature slot indices. */
  positions: number[];
  /** Per-slot marker: `0` once the slot of `positions` has been (re)initialised, `-1` before. */
  processedElements: number[];
  /** Histogram used to derive the early-termination bound. */
  bucketCounts: number[];
  /** Highest index whose bucket count is still non-zero. */
  maxBucketIndex: number;
};

/**
 * Signature-based estimator of the Jaccard similarity between sets.
 *
 * Each signature slot keeps the smallest value produced for it by any element
 * added so far; the similarity of two signatures is the fraction of slots on
 * which they agree. Signatures are only comparable when they were built with
 * the same seed and the same size.
 */
export class SuperMinHash {
  public static readonly DEFAULT_SIGNATURE_SIZE = 256;
  public static readonly DEFAULT_SEED = 42;

  /** Longest serialized element (in UTF-16 code units) accepted by `add`. */
  private static readonly MAX_INPUT_LENGTH = 100000;
  /** Sentinel stored in untouched slots; also the modulus applied to slot values. */
  private static readonly MAX_HASH_VALUE = 0xffffffff;
  /** Serialized header: 4 bytes signature size + 4 bytes seed + 1 byte empty flag. */
  private static readonly METADATA_SIZE = 9;

  private readonly signature: Uint32Array;
  private empty = true;

  /**
   * @param signatureSize Number of slots in the signature; must be a positive integer.
   * @param seed Value mixed into every element's PRNG seed string.
   * @throws Error if `signatureSize` is not a positive integer.
   */
  constructor(
    public readonly signatureSize: number = SuperMinHash.DEFAULT_SIGNATURE_SIZE,
    private readonly seed = SuperMinHash.DEFAULT_SEED,
  ) {
    if (this.signatureSize <= 0 || !Number.isInteger(this.signatureSize)) {
      throw new Error('Signature size must be a positive integer');
    }
    this.signature = new Uint32Array(this.signatureSize).fill(SuperMinHash.MAX_HASH_VALUE);
  }

  /**
   * Builds the string used to seed the per-element PRNG.
   *
   * @throws TypeError if the element has no JSON representation
   *   (`undefined`, functions, symbols).
   * @throws Error if the serialized element is longer than `MAX_INPUT_LENGTH`.
   */
  private generateSeedString(element: HashableElement): string {
    // JSON.stringify yields undefined (despite its declared return type) for
    // values without a JSON form; surface that as an explicit error instead of
    // an opaque "cannot read properties of undefined" failure below.
    const serialized: string | undefined = typeof element === 'string' ? element : JSON.stringify(element);
    if (serialized === undefined) {
      throw new TypeError(`Cannot hash element of type ${typeof element}: it has no JSON representation`);
    }
    if (serialized.length > SuperMinHash.MAX_INPUT_LENGTH) {
      throw new Error(`Input exceeds maximum length of ${SuperMinHash.MAX_INPUT_LENGTH} characters`);
    }
    return `${this.seed}:${serialized}`;
  }

  /**
   * Folds every element of `elements` into the signature.
   *
   * Each element gets its own PRNG (seeded from the instance seed and the
   * element's serialization) and a fresh copy of the processing context.
   */
  public add(elements: Iterable<HashableElement>): void {
    const emptyContext = this.initProcessingContext();
    for (const element of elements) {
      this.empty = false;
      const elementSeedString = this.generateSeedString(element);
      const randomGenerator = seedrandom(elementSeedString);
      const processingContext = this.cloneProcessingContext(emptyContext);
      this.processElementWithContext(processingContext, randomGenerator);
    }
  }

  /** Creates the pristine context: identity permutation, all markers `-1`, all counts in the last bucket. */
  private initProcessingContext(): ElementProcessingContext {
    const m = this.signatureSize;
    return {
      positions: Array.from({ length: m }, (_, i) => i),
      processedElements: new Array(m).fill(-1),
      bucketCounts: [...new Array(m - 1).fill(0), m],
      maxBucketIndex: m - 1,
    };
  }

  /** Returns a copy of `context` whose arrays can be mutated independently. */
  private cloneProcessingContext(context: ElementProcessingContext): ElementProcessingContext {
    return {
      positions: [...context.positions],
      processedElements: [...context.processedElements],
      bucketCounts: [...context.bucketCounts],
      maxBucketIndex: context.maxBucketIndex,
    };
  }

  /**
   * Walks the slots in a pseudo-random order derived from `randomGenerator`
   * and lowers each visited slot to the element's value when it is smaller.
   *
   * NOTE(review): `maxBucketIndex` is a local copy and `add` hands in a fresh
   * context per element, so the early-termination bound is never carried from
   * one element to the next. Confirm against the reference algorithm whether
   * that is intended before changing it; altering it would change signatures.
   */
  private processElementWithContext(context: ElementProcessingContext, randomGenerator: seedrandom.PRNG): void {
    const { positions, processedElements, bucketCounts } = context;
    let { maxBucketIndex } = context;
    let currentPosition = 0;

    while (currentPosition <= maxBucketIndex) {
      // The order of the two PRNG draws determines the signature; keep it stable.
      const randomValue = Math.floor(randomGenerator() * SuperMinHash.MAX_HASH_VALUE);
      const randomPosition = this.selectRandomPosition(currentPosition, this.signatureSize, randomGenerator);

      this.ensurePositionsInitialized(currentPosition, randomPosition, positions, processedElements);
      this.swapPositions(currentPosition, randomPosition, positions);

      const signaturePosition = positions[currentPosition];
      maxBucketIndex = this.updateSignatureIfNeeded(
        currentPosition,
        randomValue,
        signaturePosition,
        bucketCounts,
        maxBucketIndex,
      );
      currentPosition++;
    }
  }

  /** Picks a uniformly distributed index in `[currentPosition, size)`. */
  private selectRandomPosition(currentPosition: number, size: number, randomGenerator: seedrandom.PRNG): number {
    return currentPosition + Math.floor(randomGenerator() * (size - currentPosition));
  }

  /** Resets `positions` at both indices to the identity the first time each index is touched. */
  private ensurePositionsInitialized(
    pos1: number,
    pos2: number,
    positions: number[],
    processedElements: number[],
  ): void {
    if (processedElements[pos1] !== 0) {
      processedElements[pos1] = 0;
      positions[pos1] = pos1;
    }
    if (processedElements[pos2] !== 0) {
      processedElements[pos2] = 0;
      positions[pos2] = pos2;
    }
  }

  /** Exchanges the entries of `positions` at `pos1` and `pos2`. */
  private swapPositions(pos1: number, pos2: number, positions: number[]): void {
    const temp = positions[pos1];
    positions[pos1] = positions[pos2];
    positions[pos2] = temp;
  }

  /**
   * Stores `(randomValue + currentPosition) % MAX_HASH_VALUE` in the slot when
   * it is smaller than the slot's current value, and keeps the bucket
   * histogram in step.
   *
   * @returns The (possibly lowered) early-termination bound.
   */
  private updateSignatureIfNeeded(
    currentPosition: number,
    randomValue: number,
    signaturePosition: number,
    bucketCounts: number[],
    maxBucketIndex: number,
  ): number {
    const newValue = (randomValue + currentPosition) % SuperMinHash.MAX_HASH_VALUE;
    const currentValue = this.signature[signaturePosition];
    if (newValue >= currentValue) {
      return maxBucketIndex;
    }

    const previousBucket = Math.min(currentValue, this.signatureSize - 1);
    this.signature[signaturePosition] = newValue;
    if (currentPosition < previousBucket) {
      bucketCounts[previousBucket]--;
      bucketCounts[currentPosition]++;
      return this.adjustMaxBucketIndex(maxBucketIndex, bucketCounts);
    }
    return maxBucketIndex;
  }

  /** Lowers `currentMax` past trailing buckets whose count is zero, stopping at index 0. */
  private adjustMaxBucketIndex(currentMax: number, bucketCounts: number[]): number {
    let newMax = currentMax;
    while (newMax > 0 && bucketCounts[newMax] === 0) {
      newMax--;
    }
    return newMax;
  }

  /**
   * Estimated Jaccard similarity with `other`.
   *
   * Two empty signatures are considered identical (`1.0`); an empty and a
   * non-empty one are considered disjoint (`0.0`). In both of those cases the
   * seed and size checks of `getJaccardIndex` are not performed.
   *
   * @throws Error if both are non-empty and their seeds or sizes differ.
   */
  public similarity(other: SuperMinHash): number {
    if (this.empty || other.empty) {
      return this.empty && other.empty ? 1.0 : 0.0;
    }
    return this.getJaccardIndex(other);
  }

  /**
   * Fraction of signature slots on which this instance and `other` agree.
   *
   * @throws Error if the seeds or the signature sizes differ.
   */
  public getJaccardIndex(other: SuperMinHash): number {
    if (this.seed !== other.seed) {
      throw new Error('Cannot compare signatures generated with different seeds');
    }
    if (this.signatureSize !== other.signatureSize) {
      throw new Error('Can only compare signatures of the same size');
    }

    let matches = 0;
    for (let index = 0; index < this.signatureSize; index++) {
      if (this.signature[index] === other.signature[index]) {
        matches++;
      }
    }
    return matches / this.signatureSize;
  }

  /** Returns a copy of the signature; mutating it does not affect this instance. */
  public getSignature(): Uint32Array {
    return new Uint32Array(this.signature);
  }

  /** `true` until at least one element has been added. */
  public isEmpty(): boolean {
    return this.empty;
  }

  /**
   * Encodes the instance as little-endian bytes:
   * `uint32 signatureSize | uint32 seed | uint8 flag | uint32[signatureSize] slots`,
   * where the flag byte is `0` for an empty instance and `1` otherwise.
   *
   * NOTE(review): `setUint32` stores the seed modulo 2^32, so a negative,
   * fractional or larger seed will not round-trip to the same number.
   */
  public serialize(): Uint8Array {
    const headerSize = SuperMinHash.METADATA_SIZE;
    const buffer = new ArrayBuffer(headerSize + this.signatureSize * 4);
    const view = new DataView(buffer);

    view.setUint32(0, this.signatureSize, true);
    view.setUint32(4, this.seed, true);
    view.setUint8(8, this.empty ? 0 : 1);

    for (let position = 0; position < this.signatureSize; position++) {
      view.setUint32(headerSize + position * 4, this.signature[position], true);
    }
    return new Uint8Array(buffer);
  }

  /**
   * Rebuilds an instance from the bytes produced by `serialize`.
   *
   * @throws Error if the data is shorter than the header, declares a zero
   *   signature size, or its length does not match the declared size.
   */
  public static deserialize(binary: Uint8Array): SuperMinHash {
    const headerSize = SuperMinHash.METADATA_SIZE;
    if (binary.length < headerSize) {
      throw new Error('Invalid binary data: too short');
    }

    // Honour byteOffset/byteLength: `binary` may be a window into a larger
    // ArrayBuffer (e.g. a subarray or a pooled Node.js Buffer).
    const view = new DataView(binary.buffer, binary.byteOffset, binary.byteLength);

    const signatureSize = view.getUint32(0, true);
    if (signatureSize <= 0) {
      throw new Error('Invalid binary data: signature size must be positive');
    }

    // Checked before allocating so a bogus size cannot trigger a huge allocation.
    const expectedLength = headerSize + signatureSize * 4;
    if (binary.length !== expectedLength) {
      throw new Error(`Invalid binary data: expected length ${expectedLength}, got ${binary.length}`);
    }

    const seed = view.getUint32(4, true);
    const empty = view.getUint8(8) === 0;

    const minhash = new SuperMinHash(signatureSize, seed);
    minhash.empty = empty;
    for (let position = 0; position < signatureSize; position++) {
      minhash.signature[position] = view.getUint32(headerSize + position * 4, true);
    }
    return minhash;
  }

  /** Deserializes both inputs and returns their `similarity`. */
  public static compareSerialized(firstSignature: Uint8Array, secondSignature: Uint8Array): number {
    const firstMinHash = SuperMinHash.deserialize(firstSignature);
    const secondMinHash = SuperMinHash.deserialize(secondSignature);
    return firstMinHash.similarity(secondMinHash);
  }

  /**
   * Wraps an existing signature (copied, not aliased) in a new instance.
   *
   * @throws Error if `signature` is empty, since the size must be positive.
   */
  public static fromRawSignature(signature: Uint32Array, seed: number, empty = false): SuperMinHash {
    const minhash = new SuperMinHash(signature.length, seed);
    minhash.signature.set(signature);
    minhash.empty = empty;
    return minhash;
  }

  /** Convenience factory: creates an instance and adds every element of `elements`. */
  public static fromIterable(
    elements: Iterable<HashableElement>,
    signatureSize = SuperMinHash.DEFAULT_SIGNATURE_SIZE,
    seed = SuperMinHash.DEFAULT_SEED,
  ): SuperMinHash {
    const minhash = new SuperMinHash(signatureSize, seed);
    minhash.add(elements);
    return minhash;
  }
}