lsh-index
Version:
Locality-Sensitive Hashing implementation for indexing vectors using random projections
144 lines (125 loc) • 4.72 kB
text/typescript
import { createHash } from "crypto";
/**
 * Configuration accepted by the LSH index constructor.
 */
export interface LSHOptions {
/**
 * Length every inserted or queried vector must have; vectors of any other
 * length are rejected with an error.
 */
dimensions: number;
/**
 * Number of random projection vectors generated at construction time.
 * Must be a multiple of `numBands`, otherwise the constructor throws.
 */
numProjections: number;
/**
 * Number of bands the projections are split into. Each band keeps its own
 * bucket table, and a stored vector becomes a query candidate when it shares
 * a bucket with the query in at least one band.
 */
numBands: number;
/**
 * Quantization scale: each projection is multiplied by this value and then
 * floored, so a larger value produces finer (narrower) buckets.
 * Defaults to 4.
 */
bucketSize?: number;
/**
 * Distance function used to filter candidates against `maxDistance` in
 * queries. Defaults to Euclidean distance.
 */
distanceMetric?: (v1: number[], v2: number[]) => number;
}
/**
 * In-memory locality-sensitive hashing (LSH) index for fixed-length numeric
 * vectors.
 *
 * Every vector is projected onto `numProjections` random unit directions. The
 * projections are split into `numBands` consecutive bands, each band is
 * quantized and hashed to a bucket key, and the vector id is recorded in that
 * bucket. A query collects the ids sharing a bucket with the query vector in
 * at least one band and then filters them with the real distance metric.
 */
export class LSH {
  /** Private copy of the constructor options (later caller mutations are ignored). */
  private readonly options: LSHOptions;
  /** One unit-length direction per projection, each of length `dimensions`. */
  private readonly projectionVectors: number[][];
  /** Per band: bucket key -> ids of the vectors currently hashed to it. */
  private readonly buckets: Map<string, Set<string>>[];
  /** id -> private copy of the vector stored under that id. */
  private readonly vectors: Map<string, number[]>;
  private readonly bucketSize: number;
  private readonly rowsPerBand: number;
  private readonly distanceMetric: (v1: number[], v2: number[]) => number;

  /**
   * @param options Index configuration.
   * @throws RangeError if `dimensions`, `numProjections` or `numBands` is not
   *   a positive integer.
   * @throws Error if `numProjections` is not a multiple of `numBands`.
   */
  constructor(options: LSHOptions) {
    LSH.assertPositiveInteger("dimensions", options.dimensions);
    LSH.assertPositiveInteger("numProjections", options.numProjections);
    LSH.assertPositiveInteger("numBands", options.numBands);
    if (options.numProjections % options.numBands !== 0) {
      throw new Error(
        `Number of projections (${options.numProjections}) must be a multiple of number of bands (${options.numBands})`,
      );
    }

    // Copy so that mutating the caller's object later cannot desynchronise
    // the index from the sizes it was built with.
    this.options = { ...options };
    // `||` is deliberate: undefined, 0 and NaN all fall back to the default,
    // because a zero scale would put every vector into the same bucket.
    this.bucketSize = options.bucketSize || 4;
    this.rowsPerBand = options.numProjections / options.numBands;
    this.projectionVectors = this.initializeProjections();
    this.buckets = Array.from(
      { length: options.numBands },
      () => new Map<string, Set<string>>(),
    );
    this.vectors = new Map();
    this.distanceMetric = options.distanceMetric ?? LSH.euclideanDistance;
  }

  /** Throws a RangeError unless `value` is an integer greater than zero. */
  private static assertPositiveInteger(name: string, value: number): void {
    if (!Number.isInteger(value) || value <= 0) {
      throw new RangeError(`${name} must be a positive integer (got ${value})`);
    }
  }

  /** Draws one standard-normal sample using the Box–Muller transform. */
  private static gaussian(): number {
    // Math.random() is in [0, 1), so 1 - Math.random() is in (0, 1] and the
    // logarithm below is always finite.
    const u = 1 - Math.random();
    const v = Math.random();
    return Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v);
  }

  /** Euclidean (L2) distance between two equally long vectors. */
  private static euclideanDistance(v1: number[], v2: number[]): number {
    return Math.sqrt(
      v1.reduce((sum, val, idx) => sum + Math.pow(val - v2[idx], 2), 0),
    );
  }

  /**
   * Generates `numProjections` random unit vectors of length `dimensions`.
   *
   * Components are drawn from a normal distribution before normalising, which
   * makes the resulting directions uniformly distributed on the unit sphere.
   * Normalising uniform samples from a cube would instead favour the cube's
   * diagonals.
   */
  private initializeProjections(): number[][] {
    const { numProjections, dimensions } = this.options;
    return Array.from({ length: numProjections }, () => {
      for (;;) {
        const direction = Array.from({ length: dimensions }, () =>
          LSH.gaussian(),
        );
        const magnitude = Math.sqrt(
          direction.reduce((sum, c) => sum + c * c, 0),
        );
        // Resample in the (practically impossible) all-zero case instead of
        // dividing by zero and poisoning the index with NaN.
        if (magnitude > 0) {
          return direction.map((c) => c / magnitude);
        }
      }
    });
  }

  /** Returns the dot product of `vector` with every projection vector. */
  private projectVector(vector: number[]): number[] {
    return this.projectionVectors.map((projVector) =>
      vector.reduce((sum, val, idx) => sum + val * projVector[idx], 0),
    );
  }

  /**
   * Computes the bucket key of one band.
   *
   * @param projections All projections of a vector, in projection order.
   * @param bandIndex Zero-based band whose slice of projections is hashed.
   * @returns Hex SHA-256 digest of the band's quantized projections.
   */
  private getBucketHash(projections: number[], bandIndex: number): string {
    const start = bandIndex * this.rowsPerBand;
    const quantized = projections
      .slice(start, start + this.rowsPerBand)
      // Scale then floor to turn each projection into a discrete cell index.
      .map((p) => Math.floor(p * this.bucketSize));
    return createHash("sha256").update(quantized.join(":")).digest("hex");
  }

  /**
   * Removes `id` from the bucket that `vector` hashes to in every band and
   * drops buckets that become empty.
   */
  private removeFromBuckets(id: string, vector: number[]): void {
    const projections = this.projectVector(vector);
    this.buckets.forEach((band, bandIndex) => {
      const key = this.getBucketHash(projections, bandIndex);
      const members = band.get(key);
      if (members === undefined) {
        return;
      }
      members.delete(id);
      if (members.size === 0) {
        band.delete(key);
      }
    });
  }

  /**
   * Adds a vector to the index, replacing any vector previously stored under
   * the same id.
   *
   * @param params.id Identifier returned by `query` for this vector.
   * @param params.vector Vector of exactly `dimensions` numbers; it is copied,
   *   so later changes to the caller's array do not affect the index.
   * @throws Error if the vector has the wrong number of dimensions.
   */
  insert(params: { id: string; vector: number[] }): void {
    const { id, vector } = params;
    if (vector.length !== this.options.dimensions) {
      throw new Error(`Vector must have ${this.options.dimensions} dimensions`);
    }

    // Re-inserting an id must first undo its old bucket memberships,
    // otherwise stale entries pile up in the buckets and in export().
    const previous = this.vectors.get(id);
    if (previous !== undefined) {
      this.removeFromBuckets(id, previous);
    }

    const stored = [...vector];
    this.vectors.set(id, stored);

    const projections = this.projectVector(stored);
    this.buckets.forEach((band, bandIndex) => {
      const key = this.getBucketHash(projections, bandIndex);
      let members = band.get(key);
      if (members === undefined) {
        members = new Set();
        band.set(key, members);
      }
      members.add(id);
    });
  }

  /**
   * Finds stored vectors close to `vector`.
   *
   * Only vectors that share a bucket with the query in at least one band are
   * considered, so the result is approximate: a vector within `maxDistance`
   * is missed if it collides with the query in no band.
   *
   * @param params.vector Query vector of exactly `dimensions` numbers.
   * @param params.maxDistance Inclusive upper bound on the distance metric.
   * @returns Ids of the matching candidates.
   * @throws Error if the query vector has the wrong number of dimensions.
   */
  query(params: { vector: number[]; maxDistance: number }): string[] {
    const { vector, maxDistance } = params;
    if (vector.length !== this.options.dimensions) {
      throw new Error(
        `Query vector must have ${this.options.dimensions} dimensions`,
      );
    }

    const projections = this.projectVector(vector);
    const candidates = new Set<string>();
    this.buckets.forEach((band, bandIndex) => {
      const members = band.get(this.getBucketHash(projections, bandIndex));
      if (members !== undefined) {
        members.forEach((id) => candidates.add(id));
      }
    });

    // Bucket collisions only nominate candidates; the real metric decides.
    return [...candidates].filter((id) => {
      const candidate = this.vectors.get(id);
      return (
        candidate !== undefined &&
        this.distanceMetric(vector, candidate) <= maxDistance
      );
    });
  }

  /** Removes every stored vector and bucket; the projections are kept. */
  clear(): void {
    this.buckets.forEach((band) => band.clear());
    this.vectors.clear();
  }

  /**
   * Returns a snapshot of the index state.
   *
   * The arrays in the snapshot are copies, so editing them does not alter the
   * index. `options` may contain the `distanceMetric` function, which
   * `JSON.stringify` omits.
   */
  export(): {
    options: LSHOptions;
    projectionVectors: number[][];
    buckets: [string, string[]][][];
    vectors: [string, number[]][];
  } {
    return {
      options: { ...this.options },
      projectionVectors: this.projectionVectors.map((p) => [...p]),
      buckets: this.buckets.map((band) =>
        Array.from(
          band.entries(),
          ([key, members]): [string, string[]] => [key, [...members]],
        ),
      ),
      vectors: Array.from(
        this.vectors.entries(),
        ([id, stored]): [string, number[]] => [id, [...stored]],
      ),
    };
  }
}