lsh-index
Version:
Locality-Sensitive Hashing implementation for indexing vectors using random projections
106 lines (105 loc) • 4.29 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.LSH = void 0;
const crypto_1 = require("crypto");
/**
 * Locality-Sensitive Hashing index over fixed-length numeric vectors.
 *
 * Indexing:
 * - Every vector is projected onto `numProjections` random unit directions.
 * - The projections are split into `numBands` equal bands.
 * - Each band's projections are quantized and SHA-256 hashed, and the vector's
 *   id is stored in that band's bucket.
 *
 * Querying:
 * - A query collects every id that shares at least one band bucket with the
 *   query vector.
 * - It then keeps only ids whose stored vector lies within `maxDistance` under
 *   `distanceMetric`.
 */
class LSH {
    /**
     * @param {object} options
     * @param {number} options.dimensions - Required length of every indexed/query vector.
     * @param {number} options.numProjections - Number of random projection directions.
     * @param {number} options.numBands - Number of bands; must divide numProjections evenly.
     * @param {number} [options.bucketSize=4] - Quantization scale. Each projection p is
     *   mapped to Math.floor(p * bucketSize), so buckets are 1 / bucketSize wide.
     *   Falls back to 4 when omitted or falsy.
     * @param {Function} [options.distanceMetric] - (a, b) => number used to filter query
     *   candidates; defaults to Euclidean distance.
     * @throws {Error} If numProjections is not a multiple of numBands.
     */
    constructor(options) {
        this.options = options;
        if (options.numProjections % options.numBands !== 0) {
            throw new Error(`Number of projections (${options.numProjections}) must be a multiple of number of bands (${options.numBands})`);
        }
        // `||` rather than `??`: a bucketSize of 0 would floor every projection to 0,
        // putting all vectors in a single bucket.
        this.bucketSize = options.bucketSize || 4;
        this.rowsPerBand = Math.floor(options.numProjections / options.numBands);
        this.projectionVectors = this.initializeProjections();
        // One Map per band: bucket hash -> Set of ids.
        this.buckets = Array(options.numBands)
            .fill(null)
            .map(() => new Map());
        // id -> the index's private copy of the inserted vector.
        this.vectors = new Map();
        this.distanceMetric = options.distanceMetric || this.euclideanDistance;
    }

    /**
     * Draw `numProjections` random directions in `dimensions`-space.
     *
     * Components are standard-normal samples. A normalized Gaussian vector is
     * uniformly distributed over directions on the unit sphere. Normalizing
     * uniform [-1, 1) samples instead would bias directions toward the
     * diagonals of the hypercube.
     *
     * @returns {number[][]} numProjections unit vectors, each of length dimensions.
     */
    initializeProjections() {
        const { numProjections, dimensions } = this.options;
        return Array(numProjections)
            .fill(0)
            .map(() => {
                let vector;
                let magnitude;
                // Resample the (practically impossible) all-zero draw, which cannot be normalized.
                do {
                    vector = Array(dimensions)
                        .fill(0)
                        .map(() => this.sampleStandardNormal());
                    magnitude = Math.sqrt(vector.reduce((sum, val) => sum + val * val, 0));
                } while (magnitude === 0 && dimensions > 0);
                return vector.map((v) => v / magnitude);
            });
    }

    /**
     * Sample N(0, 1) using the Box–Muller transform.
     * Built on Math.random(), so it is not cryptographically secure.
     * @returns {number}
     */
    sampleStandardNormal() {
        // 1 - Math.random() lies in (0, 1], so Math.log never receives 0.
        const u1 = 1 - Math.random();
        const u2 = Math.random();
        return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
    }

    /**
     * Dot product of `vector` with every projection direction.
     * @param {ArrayLike<number>} vector - Vector of length `dimensions`.
     * @returns {number[]} One projection value per projection direction.
     */
    projectVector(vector) {
        return this.projectionVectors.map((projVector) => vector.reduce((sum, val, idx) => sum + val * projVector[idx], 0));
    }

    /**
     * Hash the quantized projections belonging to one band.
     * @param {number[]} projections - Output of projectVector.
     * @param {number} bandIndex - Band number in [0, numBands).
     * @returns {string} Hex SHA-256 digest identifying the bucket within that band.
     */
    getBucketHash(projections, bandIndex) {
        const start = bandIndex * this.rowsPerBand;
        const bandProjections = projections
            .slice(start, start + this.rowsPerBand)
            // Quantize: nearby projection values fall into the same integer cell.
            .map((p) => Math.floor(p * this.bucketSize));
        const hash = (0, crypto_1.createHash)("sha256");
        hash.update(bandProjections.join(":"));
        return hash.digest("hex");
    }

    /**
     * Remove `id` from every band bucket that `vector` hashes to.
     * Drops buckets that become empty.
     * @param {*} id - Identifier to un-index.
     * @param {ArrayLike<number>} vector - The vector `id` was indexed with.
     */
    removeFromBuckets(id, vector) {
        const projections = this.projectVector(vector);
        for (let i = 0; i < this.options.numBands; i++) {
            const bucketHash = this.getBucketHash(projections, i);
            const bucket = this.buckets[i].get(bucketHash);
            if (bucket) {
                bucket.delete(id);
                if (bucket.size === 0) {
                    this.buckets[i].delete(bucketHash);
                }
            }
        }
    }

    /**
     * Index a vector under `id`, replacing any vector previously stored under that id.
     * @param {{ id: *, vector: ArrayLike<number> }} params
     * @throws {Error} If vector.length !== dimensions.
     */
    insert(params) {
        const { id, vector } = params;
        if (vector.length !== this.options.dimensions) {
            throw new Error(`Vector must have ${this.options.dimensions} dimensions`);
        }
        // Re-inserting an id must not leave it behind in the old vector's buckets.
        if (this.vectors.has(id)) {
            this.removeFromBuckets(id, this.vectors.get(id));
        }
        // Keep a private copy so later caller mutations cannot desync the stored
        // vector from the buckets it was hashed into.
        const stored = vector.slice();
        this.vectors.set(id, stored);
        const projections = this.projectVector(stored);
        for (let i = 0; i < this.options.numBands; i++) {
            const bucketHash = this.getBucketHash(projections, i);
            if (!this.buckets[i].has(bucketHash)) {
                this.buckets[i].set(bucketHash, new Set());
            }
            this.buckets[i].get(bucketHash).add(id);
        }
    }

    /**
     * Find indexed ids near `vector`.
     *
     * Candidates are ids sharing at least one band bucket with `vector`. They are
     * kept only when distanceMetric(vector, stored) <= maxDistance. If maxDistance
     * is undefined, that comparison is false and the result is empty.
     *
     * @param {{ vector: ArrayLike<number>, maxDistance: number }} params
     * @returns {Array} Matching ids, in the order first encountered (band 0 first).
     * @throws {Error} If vector.length !== dimensions.
     */
    query(params) {
        const { vector, maxDistance } = params;
        if (vector.length !== this.options.dimensions) {
            throw new Error(`Query vector must have ${this.options.dimensions} dimensions`);
        }
        const projections = this.projectVector(vector);
        const candidates = new Set();
        for (let i = 0; i < this.options.numBands; i++) {
            const bucketHash = this.getBucketHash(projections, i);
            const bucket = this.buckets[i].get(bucketHash);
            if (bucket) {
                bucket.forEach((id) => candidates.add(id));
            }
        }
        // Bucket collisions are only approximate; confirm with the real distance.
        return Array.from(candidates).filter((id) => {
            const candidateVector = this.vectors.get(id);
            return this.distanceMetric(vector, candidateVector) <= maxDistance;
        });
    }

    /**
     * Euclidean (L2) distance between two equal-length vectors.
     * @param {ArrayLike<number>} v1
     * @param {ArrayLike<number>} v2
     * @returns {number}
     */
    euclideanDistance(v1, v2) {
        return Math.sqrt(v1.reduce((sum, val, idx) => sum + Math.pow(val - v2[idx], 2), 0));
    }

    /**
     * Remove all indexed vectors and buckets.
     * Projection directions are kept, so later inserts hash the same way as before.
     */
    clear() {
        this.buckets = Array(this.options.numBands)
            .fill(null)
            .map(() => new Map());
        this.vectors.clear();
    }

    /**
     * Dump the index as plain arrays/objects.
     *
     * `options`, `projectionVectors` and the stored vectors are the index's live
     * objects, not copies. A custom `options.distanceMetric` function is not
     * JSON-serializable.
     *
     * @returns {{ options: object, projectionVectors: number[][], buckets: Array, vectors: Array }}
     */
    export() {
        return {
            options: this.options,
            projectionVectors: this.projectionVectors,
            buckets: this.buckets.map((bucket) => Array.from(bucket.entries()).map(([key, value]) => [
                key,
                Array.from(value),
            ])),
            vectors: Array.from(this.vectors.entries()),
        };
    }
}
// Publish the class as this module's named CommonJS export `LSH`.
exports.LSH = LSH;