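/*
 * @leolee9086/hnsw
 * Version:
 * JavaScript HNSW (Hierarchical Navigable Small World) 向量索引库,支持动态操作和泛型搜索
 * (JavaScript HNSW vector index library supporting dynamic insert/delete and generic search.)
 * 734 lines (729 loc) • 24.1 kB
 * JavaScript
 */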
// src/binary-heap.ts
/**
 * Growable binary heap ordered by `compare`: the element that compares
 * smallest is kept at the top (min-heap for an ascending comparator).
 *
 * The backing array may be longer than `size`; only `data[0..size)` is live.
 */
var BinaryHeapGeneric = class {
  /**
   * @param {Array} [data=[]] - Initial elements; the array is copied, then heapified.
   * @param {(a: *, b: *) => number} [compare=defaultCompare] - Returns a negative
   *   number when `a` should be closer to the top than `b`.
   */
  constructor(data = [], compare = defaultCompare) {
    this.data = data.slice();
    this.size = this.data.length;
    this.compare = compare;
    if (this.size > 0) {
      this.heapify();
    }
  }
  /** Number of live elements in the heap. */
  get length() {
    return this.size;
  }
  /** Inserts a single element and sifts it up into place. */
  push(item) {
    const data = this.data;
    const pos = this.size++;
    data[pos] = item;
    this.upHeap(pos, data);
  }
  /**
   * Inserts many elements at once.
   * Small batches (fewer than size / 8 items) are sifted in one at a time;
   * larger batches are appended and the whole heap is rebuilt with heapify().
   */
  pushBulk(items) {
    if (items.length === 0) return;
    const currentSize = this.size;
    const newSize = currentSize + items.length;
    if (newSize > this.data.length) {
      // Grow geometrically so repeated bulk inserts stay amortized.
      const newCapacity = Math.max(newSize, this.data.length * 2);
      const newData = new Array(newCapacity);
      for (let i = 0; i < currentSize; i++) {
        newData[i] = this.data[i];
      }
      this.data = newData;
    }
    if (items.length < currentSize / 8) {
      for (const item of items) {
        this.push(item);
      }
    } else {
      for (let i = 0; i < items.length; i++) {
        this.data[currentSize + i] = items[i];
      }
      this.size = newSize;
      this.heapify();
    }
  }
  /** Removes and returns the top element, or `undefined` when the heap is empty. */
  pop() {
    const size = this.size;
    if (size === 0) return void 0;
    const data = this.data;
    const top = data[0];
    const newSize = size - 1;
    this.size = newSize;
    if (newSize > 0) {
      data[0] = data[newSize];
      this.downHeap(0, data, newSize);
    }
    // The slot at newSize is no longer live; drop its reference so the
    // removed element can be garbage-collected.
    data[newSize] = void 0;
    return top;
  }
  /** Returns the top element without removing it, or `undefined` when empty. */
  peek() {
    return this.size > 0 ? this.data[0] : void 0;
  }
  /** Restores the heap property over `data[0..size)` bottom-up. */
  heapify() {
    const data = this.data;
    const size = this.size;
    for (let i = (size >>> 1) - 1; i >= 0; i--) {
      this.downHeap(i, data, size);
    }
  }
  /** Moves the element at `pos` toward the root until its parent is not greater. */
  upHeap(pos, data) {
    const item = data[pos];
    const compare = this.compare;
    while (pos > 0) {
      const parent = (pos - 1) >>> 1;
      const parentValue = data[parent];
      if (compare(item, parentValue) >= 0) break;
      data[pos] = parentValue;
      pos = parent;
    }
    data[pos] = item;
  }
  /** Moves the element at `pos` toward the leaves within the first `size` slots. */
  downHeap(pos, data, size) {
    const halfLength = size >>> 1;
    const item = data[pos];
    const compare = this.compare;
    while (pos < halfLength) {
      let bestChild = (pos << 1) + 1;
      let bestValue = data[bestChild];
      const rightChild = bestChild + 1;
      if (rightChild < size && compare(data[rightChild], bestValue) < 0) {
        bestChild = rightChild;
        bestValue = data[rightChild];
      }
      if (compare(item, bestValue) <= 0) break;
      data[pos] = bestValue;
      pos = bestChild;
    }
    data[pos] = item;
  }
};
/**
 * Natural-order comparator used when no comparator is supplied.
 * Returns -1 when a < b, 1 when a > b, and 0 otherwise (including
 * incomparable pairs such as those involving NaN).
 */
function defaultCompare(a, b) {
  if (a < b) {
    return -1;
  }
  if (a > b) {
    return 1;
  }
  return 0;
}
// src/midi-heap.ts
/**
 * Fixed-capacity binary heap ordered by `compare` (the smallest element
 * according to `compare` is on top). Used as a bounded "best k" container.
 */
var MidiHeapGeneric = class {
  /**
   * @param {number} capacity - Maximum number of elements; must be > 0.
   * @param {(a: *, b: *) => number} compare - Negative when `a` belongs nearer the top.
   * @param {Array} [initialData] - Optional seed elements; only the first
   *   `capacity` of them are kept.
   * @throws {Error} when capacity <= 0.
   */
  constructor(capacity, compare, initialData) {
    this.size = 0;
    if (capacity <= 0) {
      throw new Error("Heap capacity must be greater than 0");
    }
    this.capacity = capacity;
    this.data = new Array(capacity);
    this.compare = compare;
    if (initialData && initialData.length > 0) {
      const initialSize = Math.min(initialData.length, this.capacity);
      this.size = initialSize;
      for (let i = 0; i < initialSize; i++) {
        this.data[i] = initialData[i];
      }
      this.heapify();
    }
  }
  /** Number of live elements. */
  get length() {
    return this.size;
  }
  /** True once the heap holds `capacity` elements. */
  isFull() {
    return this.size >= this.capacity;
  }
  /**
   * Adds an element to the heap.
   * @note If the heap is full, this call is silently ignored; check isFull() first.
   */
  push(item) {
    if (this.isFull()) {
      return;
    }
    const data = this.data;
    const compare = this.compare;
    let pos = this.size++;
    while (pos > 0) {
      const parent = (pos - 1) >>> 1;
      if (compare(item, data[parent]) >= 0) break;
      data[pos] = data[parent];
      pos = parent;
    }
    data[pos] = item;
  }
  /** Removes and returns the top element, or `undefined` when empty. */
  pop() {
    if (this.size === 0) return void 0;
    const data = this.data;
    const top = data[0];
    const newSize = --this.size;
    if (newSize > 0) {
      data[0] = data[newSize];
      this.downHeap(0, newSize);
    }
    // Release the vacated slot so the removed element can be collected.
    data[newSize] = void 0;
    return top;
  }
  /**
   * Replaces the top element; cheaper than pop() followed by push().
   * On an empty heap the item is simply inserted.
   * @returns {*} The previous top element, or `undefined` if the heap was empty.
   */
  replace(item) {
    if (this.size === 0) {
      // Without this guard the item would be written to data[0] while size
      // stayed 0, silently dropping it.
      this.push(item);
      return void 0;
    }
    const top = this.data[0];
    this.data[0] = item;
    this.downHeap(0, this.size);
    return top;
  }
  /** Sifts the element at `pos` down within the first `size` slots. */
  downHeap(pos, size) {
    const data = this.data;
    const compare = this.compare;
    const item = data[pos];
    while (true) {
      const leftChild = (pos << 1) + 1;
      if (leftChild >= size) break;
      let bestChild = leftChild;
      const rightChild = leftChild + 1;
      if (rightChild < size && compare(data[rightChild], data[leftChild]) < 0) {
        bestChild = rightChild;
      }
      if (compare(item, data[bestChild]) <= 0) break;
      data[pos] = data[bestChild];
      pos = bestChild;
    }
    data[pos] = item;
  }
  /** Returns the top element without removing it, or `undefined` when empty. */
  peek() {
    return this.size > 0 ? this.data[0] : void 0;
  }
  /** Restores the heap property over the live elements. */
  heapify() {
    const size = this.size;
    for (let i = (size >>> 1) - 1; i >= 0; i--) {
      this.downHeap(i, size);
    }
  }
  /** Removes all elements, releasing their references. */
  clear() {
    this.data.fill(void 0, 0, this.size);
    this.size = 0;
  }
  /** Live elements in heap (not sorted) order, as a new array. */
  toArray() {
    return this.data.slice(0, this.size);
  }
  /** Live elements sorted by `compare` (top-most first), as a new array. */
  toSortedArray() {
    const sorted = this.data.slice(0, this.size);
    sorted.sort(this.compare);
    return sorted;
  }
};
// src/generic.ts
/**
 * Creates an HNSW (Hierarchical Navigable Small World) approximate nearest
 * neighbor index over arbitrary items, using caller-supplied distance functions.
 *
 * Nodes are addressed by insertion order (0, 1, 2, ...). Deleting a node marks
 * it deleted and unlinks it from other nodes; its index is never reused.
 *
 * @param {object} config
 * @param {number} config.M - Max links per node on levels >= 1 (level 0 allows 2 * M).
 * @param {number} config.efConstruction - Candidate-list size used while inserting,
 *   and the default search breadth when `search` receives no `efSearch`.
 * @param {(a: *, b: *) => number} config.distanceFunction - Distance between two stored items.
 * @param {(query: *, target: *) => number} [config.distanceToQuery] - Optional
 *   query-to-item distance; `distanceFunction` is used when omitted.
 * @returns {{insertNode: Function, search: Function, deleteNode: Function, getStats: Function}}
 */
function createHNSWIndex(config) {
  const { M, efConstruction, distanceFunction, distanceToQuery } = config;
  // Node levels are assigned in [0, MAX_LEVEL - 1].
  const MAX_LEVEL = 16;
  const vectors = [];
  // neighbors[nodeIdx][level] is nodeIdx's adjacency list on that level.
  const neighbors = [];
  const entryPoint = { idx: -1, level: -1 };
  const deletedNodes = /* @__PURE__ */ new Set();
  const nodeToLevels = /* @__PURE__ */ new Map();
  // Visited bookkeeping for layer searches. Each search takes a fresh epoch
  // number, and a node counts as visited when its mark equals that epoch, so
  // no per-search clearing pass is needed. The array grows with the index: a
  // fixed-size typed array silently ignores out-of-range writes, which would
  // make every node past its end unreachable.
  let visitMarks = new Uint32Array(1e4);
  let visitEpoch = 0;

  /** True when idx refers to an inserted, non-deleted node. */
  function isValidNode(idx) {
    return idx >= 0 && idx < vectors.length && !deletedNodes.has(idx);
  }

  /** Distance between two stored nodes; Infinity if either is missing or deleted. */
  function distance(idxA, idxB) {
    if (!isValidNode(idxA) || !isValidNode(idxB)) {
      return Infinity;
    }
    return distanceFunction(vectors[idxA], vectors[idxB]);
  }

  /** Distance from a query item to a stored node; Infinity if the node is missing or deleted. */
  function distanceToTarget(queryVector, targetIdx) {
    if (!isValidNode(targetIdx)) {
      return Infinity;
    }
    const targetVector = vectors[targetIdx];
    if (distanceToQuery) {
      return distanceToQuery(queryVector, targetVector);
    }
    return distanceFunction(queryVector, targetVector);
  }

  /** Starts a new visited epoch, growing the mark array to cover every node index. */
  function beginVisit() {
    if (visitMarks.length < vectors.length) {
      let capacity = visitMarks.length;
      while (capacity < vectors.length) capacity *= 2;
      visitMarks = new Uint32Array(capacity);
      visitEpoch = 0;
    }
    visitEpoch += 1;
    if (visitEpoch > 0xffffffff) {
      // Epoch would overflow the Uint32 marks: reset everything once.
      visitMarks.fill(0);
      visitEpoch = 1;
    }
    return visitEpoch;
  }

  /**
   * Best-first search on a single level starting from startNodeIdx.
   * @returns {Array<{idx: number, distance: number}>} Up to `ef` nodes sorted by
   *   ascending distance to the query; `[]` when the start node is invalid.
   */
  function searchLayerWithQuery(queryVector, startNodeIdx, level, ef) {
    if (!isValidNode(startNodeIdx)) {
      return [];
    }
    const mark = beginVisit();
    // Min-heap: the closest not-yet-expanded candidate is on top.
    const candidates = new BinaryHeapGeneric([], (a, b) => a.distance - b.distance);
    // Bounded max-heap: the farthest of the current best `ef` results is on top.
    const results = new MidiHeapGeneric(ef, (a, b) => b.distance - a.distance);
    const startNodeDist = distanceToTarget(queryVector, startNodeIdx);
    visitMarks[startNodeIdx] = mark;
    candidates.push({ idx: startNodeIdx, distance: startNodeDist });
    results.push({ idx: startNodeIdx, distance: startNodeDist });
    while (candidates.length > 0) {
      const bestCandidate = candidates.peek();
      const farthestResult = results.peek();
      // With a full result set, a candidate farther than the worst result
      // cannot improve it, and neither can any remaining candidate.
      if (farthestResult && bestCandidate.distance > farthestResult.distance && results.isFull()) {
        break;
      }
      const cand = candidates.pop();
      const candLevels = neighbors[cand.idx];
      const nodeNeighbors = (candLevels && candLevels[level]) || [];
      for (const neighborIdx of nodeNeighbors) {
        if (visitMarks[neighborIdx] === mark || !isValidNode(neighborIdx)) continue;
        visitMarks[neighborIdx] = mark;
        const dist = distanceToTarget(queryVector, neighborIdx);
        const currentFarthest = results.peek();
        if (!currentFarthest || !results.isFull() || dist < currentFarthest.distance) {
          candidates.push({ idx: neighborIdx, distance: dist });
          if (!results.isFull()) {
            results.push({ idx: neighborIdx, distance: dist });
          } else {
            results.replace({ idx: neighborIdx, distance: dist });
          }
        }
      }
    }
    return results.toSortedArray().reverse();
  }

  /**
   * Neighbor-selection heuristic. Walks `candidates` in order (assumed sorted
   * by ascending distance, as searchLayerWithQuery returns them) and keeps a
   * candidate only if it is not closer to an already kept node than its own
   * `distance`. Returns at most M2 entries; input of length <= M2 is returned as-is.
   */
  function getNeighborsByHeuristic(candidates, M2) {
    if (candidates.length <= M2) {
      return candidates;
    }
    const result = [];
    const seen = /* @__PURE__ */ new Set();
    for (const cand of candidates) {
      if (result.length >= M2) break;
      if (seen.has(cand.idx)) continue;
      let good = true;
      for (const res of result) {
        if (distance(cand.idx, res.idx) < cand.distance) {
          good = false;
          break;
        }
      }
      if (good) {
        result.push(cand);
        seen.add(cand.idx);
      }
    }
    return result;
  }

  /** Removes nodeIdx from the adjacency lists of every live node, on every level. */
  function removeNodeFromNeighbors(nodeIdx) {
    for (let i = 0; i < neighbors.length; i++) {
      if (!isValidNode(i)) continue;
      const nodeNeighbors = neighbors[i];
      if (!nodeNeighbors) continue;
      for (let level = 0; level < nodeNeighbors.length; level++) {
        const levelNeighbors = nodeNeighbors[level];
        if (!levelNeighbors) continue;
        nodeNeighbors[level] = levelNeighbors.filter((neighborIdx) => neighborIdx !== nodeIdx);
      }
    }
  }

  /**
   * Sets the entry point to the lowest-indexed live node with the highest
   * level, or to { idx: -1, level: -1 } when no live node remains.
   */
  function reselectEntryPoint() {
    let bestIdx = -1;
    let bestLevel = -1;
    for (let i = 0; i < vectors.length; i++) {
      if (!isValidNode(i)) continue;
      const nodeLevel = nodeToLevels.get(i) || 0;
      if (nodeLevel > bestLevel) {
        bestIdx = i;
        bestLevel = nodeLevel;
      }
    }
    entryPoint.idx = bestIdx;
    entryPoint.level = bestLevel;
  }

  /**
   * Links nodeIdx to the heuristically selected subset of nearestNeighbors on
   * `level`, adds the reverse links, and trims any neighbor list that
   * overflows to its closest maxConnections entries.
   */
  function connectNeighbors(nodeIdx, level, nearestNeighbors) {
    const maxConnections = level === 0 ? M * 2 : M;
    const selectedNeighbors = getNeighborsByHeuristic(nearestNeighbors, maxConnections);
    neighbors[nodeIdx][level] = selectedNeighbors.map((n) => n.idx);
    for (const neighbor of selectedNeighbors) {
      if (!Array.isArray(neighbors[neighbor.idx])) {
        neighbors[neighbor.idx] = [];
      }
      if (!Array.isArray(neighbors[neighbor.idx][level])) {
        neighbors[neighbor.idx][level] = [];
      }
      neighbors[neighbor.idx][level].push(nodeIdx);
      if (neighbors[neighbor.idx][level].length > maxConnections) {
        const connections = neighbors[neighbor.idx][level];
        // Bounded max-heap keeps the maxConnections closest connections.
        const heap = new MidiHeapGeneric(maxConnections, (a, b) => b.distance - a.distance);
        for (const connIdx of connections) {
          const dist = distance(neighbor.idx, connIdx);
          const connNode = { idx: connIdx, distance: dist };
          if (!heap.isFull()) {
            heap.push(connNode);
          } else if (dist < heap.peek().distance) {
            heap.replace(connNode);
          }
        }
        neighbors[neighbor.idx][level] = heap.toArray().map((c) => c.idx);
      }
    }
  }

  /**
   * Deterministic level assignment: the number of trailing zero bits of
   * (arrayIndex + 1), capped at maxLevels - 1. Half of all indices get level 0,
   * a quarter level 1, and so on.
   */
  function assignLevelDeterministic(arrayIndex, maxLevels) {
    const internalId = arrayIndex + 1;
    const lowestBit = internalId & -internalId;
    const level = 31 - Math.clz32(lowestBit);
    return Math.min(level, maxLevels - 1);
  }

  /** Inserts an item; its node index is the number of items inserted before it. */
  function insertNode(vector) {
    const newNodeIdx = vectors.length;
    vectors.push(vector);
    neighbors.push([]);
    if (entryPoint.idx === -1) {
      entryPoint.idx = newNodeIdx;
      entryPoint.level = 0;
      neighbors[newNodeIdx][0] = [];
      nodeToLevels.set(newNodeIdx, 0);
      return;
    }
    const nodeLevel = assignLevelDeterministic(newNodeIdx, MAX_LEVEL);
    const topLevel = entryPoint.level;
    nodeToLevels.set(newNodeIdx, nodeLevel);
    let currentNodeIdx = entryPoint.idx;
    // Greedy descent through the levels above the new node's level.
    for (let level = topLevel; level > nodeLevel; level--) {
      const results = searchLayerWithQuery(vector, currentNodeIdx, level, 1);
      if (results.length > 0) {
        currentNodeIdx = results[0].idx;
      }
    }
    for (let level = Math.min(nodeLevel, topLevel); level >= 0; level--) {
      const nearestCandidates = searchLayerWithQuery(vector, currentNodeIdx, level, efConstruction);
      // connectNeighbors applies the selection heuristic itself.
      connectNeighbors(newNodeIdx, level, nearestCandidates);
      // The closest candidate seeds the search on the next level down; it is
      // always the first entry the heuristic would keep.
      if (nearestCandidates.length > 0) {
        currentNodeIdx = nearestCandidates[0].idx;
      }
    }
    if (nodeLevel > topLevel) {
      entryPoint.idx = newNodeIdx;
      entryPoint.level = nodeLevel;
    }
  }

  /**
   * Marks a node deleted and unlinks it from every adjacency list.
   * @returns {boolean} false when nodeIdx is out of range or already deleted.
   */
  function deleteNode(nodeIdx) {
    if (nodeIdx < 0 || nodeIdx >= vectors.length || deletedNodes.has(nodeIdx)) {
      return false;
    }
    deletedNodes.add(nodeIdx);
    removeNodeFromNeighbors(nodeIdx);
    if (nodeIdx === entryPoint.idx) {
      reselectEntryPoint();
    }
    return true;
  }

  /**
   * Approximate k-nearest-neighbor search.
   * @param {*} queryVector - Query item.
   * @param {number} k - Number of results wanted.
   * @param {number} [efSearch] - Search breadth; falls back to efConstruction
   *   when falsy and is raised to at least k.
   * @returns {Array<{idx: number, distance: number}>} Up to k results, closest first.
   */
  function search(queryVector, k, efSearch) {
    if (entryPoint.idx === -1 || !isValidNode(entryPoint.idx)) return [];
    let currentNodeIdx = entryPoint.idx;
    const topLevel = entryPoint.level;
    for (let level = topLevel; level > 0; level--) {
      const results = searchLayerWithQuery(queryVector, currentNodeIdx, level, 1);
      if (results.length > 0) {
        currentNodeIdx = results[0].idx;
      } else {
        reselectEntryPoint();
        if (entryPoint.idx === -1) return [];
        currentNodeIdx = entryPoint.idx;
        break;
      }
    }
    const finalEf = Math.max(k, efSearch || efConstruction);
    const finalResults = searchLayerWithQuery(queryVector, currentNodeIdx, 0, finalEf);
    return finalResults.slice(0, k);
  }

  return {
    insertNode,
    search,
    deleteNode,
    getStats: () => ({
      nodeCount: vectors.length,
      activeNodeCount: vectors.length - deletedNodes.size,
      deletedNodeCount: deletedNodes.size,
      entryPoint: { ...entryPoint }
    })
  };
}
// src/vector.ts
/**
 * Creates an HNSW (Hierarchical Navigable Small World) approximate nearest
 * neighbor index over numeric vectors (arrays or typed arrays).
 *
 * Nodes are addressed by insertion order (0, 1, 2, ...). Deleting a node marks
 * it deleted and unlinks it from other nodes; its index is never reused.
 *
 * @param {object} config
 * @param {number} config.M - Max links per node on levels >= 1 (level 0 allows 2 * M).
 * @param {number} config.efConstruction - Candidate-list size used while inserting,
 *   and the default search breadth when `search` receives no `efSearch`.
 * @param {string} config.metricType - "l2" selects squared Euclidean distance;
 *   any other value selects cosine distance (1 - cosine similarity, or 1 when
 *   either vector has zero norm).
 * @returns {{insertNode: Function, search: Function, deleteNode: Function, getStats: Function}}
 */
function createHNSWIndex2(config) {
  const { M, efConstruction, metricType } = config;
  // Node levels are assigned in [0, MAX_LEVEL - 1].
  const MAX_LEVEL = 16;
  const vectors = [];
  // norms[i] is the Euclidean norm of vectors[i], cached for cosine distance.
  const norms = [];
  // neighbors[nodeIdx][level] is nodeIdx's adjacency list on that level.
  const neighbors = [];
  const entryPoint = { idx: -1, level: -1 };
  const deletedNodes = /* @__PURE__ */ new Set();
  const nodeToLevels = /* @__PURE__ */ new Map();
  // Visited bookkeeping for layer searches. Each search takes a fresh epoch
  // number, and a node counts as visited when its mark equals that epoch, so
  // no per-search clearing pass is needed. The array grows with the index: a
  // fixed-size typed array silently ignores out-of-range writes, which would
  // make every node past its end unreachable.
  let visitMarks = new Uint32Array(1e4);
  let visitEpoch = 0;

  /** True when idx refers to an inserted, non-deleted node. */
  function isValidNode(idx) {
    return idx >= 0 && idx < vectors.length && !deletedNodes.has(idx);
  }

  /** Euclidean norm of a vector. */
  function computeNorm(vector) {
    let sumSquares = 0;
    for (let i = 0; i < vector.length; i++) {
      sumSquares += vector[i] * vector[i];
    }
    return Math.sqrt(sumSquares);
  }

  /**
   * Distance between vecA and vecB under the configured metric. Iterates over
   * vecA.length components; the vectors are assumed to share a dimension.
   */
  function metricDistance(vecA, normA, vecB, normB) {
    const len = vecA.length;
    if (metricType === "l2") {
      let d = 0;
      for (let j = 0; j < len; j++) {
        const diff = vecA[j] - vecB[j];
        d += diff * diff;
      }
      return d;
    }
    let dotProduct = 0;
    let i = 0;
    // Unrolled by 4 for throughput; the tail loop handles the remainder.
    for (; i < len - 3; i += 4) {
      dotProduct += vecA[i] * vecB[i] + vecA[i + 1] * vecB[i + 1] + vecA[i + 2] * vecB[i + 2] + vecA[i + 3] * vecB[i + 3];
    }
    for (; i < len; i++) {
      dotProduct += vecA[i] * vecB[i];
    }
    return normA === 0 || normB === 0 ? 1 : 1 - dotProduct / (normA * normB);
  }

  /** Distance between two stored nodes; Infinity if either is missing or deleted. */
  function distance(idxA, idxB) {
    if (!isValidNode(idxA) || !isValidNode(idxB)) {
      return Infinity;
    }
    return metricDistance(vectors[idxA], norms[idxA], vectors[idxB], norms[idxB]);
  }

  /** Distance from a query vector to a stored node; Infinity if the node is missing or deleted. */
  function distanceToQuery(queryVector, queryNorm, targetIdx) {
    if (!isValidNode(targetIdx)) {
      return Infinity;
    }
    return metricDistance(queryVector, queryNorm, vectors[targetIdx], norms[targetIdx]);
  }

  /** Starts a new visited epoch, growing the mark array to cover every node index. */
  function beginVisit() {
    if (visitMarks.length < vectors.length) {
      let capacity = visitMarks.length;
      while (capacity < vectors.length) capacity *= 2;
      visitMarks = new Uint32Array(capacity);
      visitEpoch = 0;
    }
    visitEpoch += 1;
    if (visitEpoch > 0xffffffff) {
      // Epoch would overflow the Uint32 marks: reset everything once.
      visitMarks.fill(0);
      visitEpoch = 1;
    }
    return visitEpoch;
  }

  /**
   * Best-first search on a single level starting from startNodeIdx.
   * @returns {Array<{idx: number, distance: number}>} Up to `ef` nodes sorted by
   *   ascending distance to the query; `[]` when the start node is invalid.
   */
  function searchLayerWithQuery(queryVector, queryNorm, startNodeIdx, level, ef) {
    if (!isValidNode(startNodeIdx)) {
      return [];
    }
    const mark = beginVisit();
    // Min-heap: the closest not-yet-expanded candidate is on top.
    const candidates = new BinaryHeapGeneric([], (a, b) => a.distance - b.distance);
    // Bounded max-heap: the farthest of the current best `ef` results is on top.
    const results = new MidiHeapGeneric(ef, (a, b) => b.distance - a.distance);
    const startNodeDist = distanceToQuery(queryVector, queryNorm, startNodeIdx);
    visitMarks[startNodeIdx] = mark;
    candidates.push({ idx: startNodeIdx, distance: startNodeDist });
    results.push({ idx: startNodeIdx, distance: startNodeDist });
    while (candidates.length > 0) {
      const bestCandidate = candidates.peek();
      const farthestResult = results.peek();
      // With a full result set, a candidate farther than the worst result
      // cannot improve it, and neither can any remaining candidate.
      if (farthestResult && bestCandidate.distance > farthestResult.distance && results.isFull()) {
        break;
      }
      const cand = candidates.pop();
      const candLevels = neighbors[cand.idx];
      const nodeNeighbors = (candLevels && candLevels[level]) || [];
      for (const neighborIdx of nodeNeighbors) {
        if (visitMarks[neighborIdx] === mark || !isValidNode(neighborIdx)) continue;
        visitMarks[neighborIdx] = mark;
        const dist = distanceToQuery(queryVector, queryNorm, neighborIdx);
        const currentFarthest = results.peek();
        if (!currentFarthest || !results.isFull() || dist < currentFarthest.distance) {
          candidates.push({ idx: neighborIdx, distance: dist });
          if (!results.isFull()) {
            results.push({ idx: neighborIdx, distance: dist });
          } else {
            results.replace({ idx: neighborIdx, distance: dist });
          }
        }
      }
    }
    return results.toSortedArray().reverse();
  }

  /**
   * Neighbor-selection heuristic. Walks `candidates` in order (assumed sorted
   * by ascending distance, as searchLayerWithQuery returns them) and keeps a
   * candidate only if it is not closer to an already kept node than its own
   * `distance`. Returns at most M2 entries; input of length <= M2 is returned as-is.
   */
  function getNeighborsByHeuristic(candidates, M2) {
    if (candidates.length <= M2) {
      return candidates;
    }
    const result = [];
    const seen = /* @__PURE__ */ new Set();
    for (const cand of candidates) {
      if (result.length >= M2) break;
      if (seen.has(cand.idx)) continue;
      let good = true;
      for (const res of result) {
        if (distance(cand.idx, res.idx) < cand.distance) {
          good = false;
          break;
        }
      }
      if (good) {
        result.push(cand);
        seen.add(cand.idx);
      }
    }
    return result;
  }

  /** Removes nodeIdx from the adjacency lists of every live node, on every level. */
  function removeNodeFromNeighbors(nodeIdx) {
    for (let i = 0; i < neighbors.length; i++) {
      if (!isValidNode(i)) continue;
      const nodeNeighbors = neighbors[i];
      if (!nodeNeighbors) continue;
      for (let level = 0; level < nodeNeighbors.length; level++) {
        const levelNeighbors = nodeNeighbors[level];
        if (!levelNeighbors) continue;
        nodeNeighbors[level] = levelNeighbors.filter((neighborIdx) => neighborIdx !== nodeIdx);
      }
    }
  }

  /**
   * Sets the entry point to the lowest-indexed live node with the highest
   * level, or to { idx: -1, level: -1 } when no live node remains.
   */
  function reselectEntryPoint() {
    let bestIdx = -1;
    let bestLevel = -1;
    for (let i = 0; i < vectors.length; i++) {
      if (!isValidNode(i)) continue;
      const nodeLevel = nodeToLevels.get(i) || 0;
      if (nodeLevel > bestLevel) {
        bestIdx = i;
        bestLevel = nodeLevel;
      }
    }
    entryPoint.idx = bestIdx;
    entryPoint.level = bestLevel;
  }

  /**
   * Links nodeIdx to the heuristically selected subset of nearestNeighbors on
   * `level`, adds the reverse links, and trims any neighbor list that
   * overflows to its closest maxConnections entries.
   */
  function connectNeighbors(nodeIdx, level, nearestNeighbors) {
    const maxConnections = level === 0 ? M * 2 : M;
    const selectedNeighbors = getNeighborsByHeuristic(nearestNeighbors, maxConnections);
    neighbors[nodeIdx][level] = selectedNeighbors.map((n) => n.idx);
    for (const neighbor of selectedNeighbors) {
      if (!Array.isArray(neighbors[neighbor.idx])) {
        neighbors[neighbor.idx] = [];
      }
      if (!Array.isArray(neighbors[neighbor.idx][level])) {
        neighbors[neighbor.idx][level] = [];
      }
      neighbors[neighbor.idx][level].push(nodeIdx);
      if (neighbors[neighbor.idx][level].length > maxConnections) {
        const connections = neighbors[neighbor.idx][level];
        // Bounded max-heap keeps the maxConnections closest connections.
        const heap = new MidiHeapGeneric(maxConnections, (a, b) => b.distance - a.distance);
        for (const connIdx of connections) {
          const dist = distance(neighbor.idx, connIdx);
          const connNode = { idx: connIdx, distance: dist };
          if (!heap.isFull()) {
            heap.push(connNode);
          } else if (dist < heap.peek().distance) {
            heap.replace(connNode);
          }
        }
        neighbors[neighbor.idx][level] = heap.toArray().map((c) => c.idx);
      }
    }
  }

  /**
   * Deterministic level assignment: the number of trailing zero bits of
   * (arrayIndex + 1), capped at maxLevels - 1. Half of all indices get level 0,
   * a quarter level 1, and so on.
   */
  function assignLevelDeterministic(arrayIndex, maxLevels) {
    const internalId = arrayIndex + 1;
    const lowestBit = internalId & -internalId;
    const level = 31 - Math.clz32(lowestBit);
    return Math.min(level, maxLevels - 1);
  }

  /** Inserts a vector; its node index is the number of vectors inserted before it. */
  function insertNode(vector) {
    const newNodeIdx = vectors.length;
    const norm = computeNorm(vector);
    vectors.push(vector);
    norms.push(norm);
    neighbors.push([]);
    if (entryPoint.idx === -1) {
      entryPoint.idx = newNodeIdx;
      entryPoint.level = 0;
      neighbors[newNodeIdx][0] = [];
      nodeToLevels.set(newNodeIdx, 0);
      return;
    }
    const nodeLevel = assignLevelDeterministic(newNodeIdx, MAX_LEVEL);
    const topLevel = entryPoint.level;
    nodeToLevels.set(newNodeIdx, nodeLevel);
    let currentNodeIdx = entryPoint.idx;
    // Greedy descent through the levels above the new node's level.
    for (let level = topLevel; level > nodeLevel; level--) {
      const results = searchLayerWithQuery(vector, norm, currentNodeIdx, level, 1);
      if (results.length > 0) {
        currentNodeIdx = results[0].idx;
      }
    }
    for (let level = Math.min(nodeLevel, topLevel); level >= 0; level--) {
      const nearestCandidates = searchLayerWithQuery(vector, norm, currentNodeIdx, level, efConstruction);
      // connectNeighbors applies the selection heuristic itself.
      connectNeighbors(newNodeIdx, level, nearestCandidates);
      // The closest candidate seeds the search on the next level down; it is
      // always the first entry the heuristic would keep.
      if (nearestCandidates.length > 0) {
        currentNodeIdx = nearestCandidates[0].idx;
      }
    }
    if (nodeLevel > topLevel) {
      entryPoint.idx = newNodeIdx;
      entryPoint.level = nodeLevel;
    }
  }

  /**
   * Marks a node deleted and unlinks it from every adjacency list.
   * @returns {boolean} false when nodeIdx is out of range or already deleted.
   */
  function deleteNode(nodeIdx) {
    if (nodeIdx < 0 || nodeIdx >= vectors.length || deletedNodes.has(nodeIdx)) {
      return false;
    }
    deletedNodes.add(nodeIdx);
    removeNodeFromNeighbors(nodeIdx);
    if (nodeIdx === entryPoint.idx) {
      reselectEntryPoint();
    }
    return true;
  }

  /**
   * Approximate k-nearest-neighbor search.
   * @param {ArrayLike<number>} queryVector - Query vector.
   * @param {number} k - Number of results wanted.
   * @param {number} [efSearch] - Search breadth; falls back to efConstruction
   *   when falsy and is raised to at least k.
   * @returns {Array<{idx: number, distance: number}>} Up to k results, closest first.
   */
  function search(queryVector, k, efSearch) {
    const queryNorm = computeNorm(queryVector);
    if (entryPoint.idx === -1 || !isValidNode(entryPoint.idx)) return [];
    let currentNodeIdx = entryPoint.idx;
    const topLevel = entryPoint.level;
    for (let level = topLevel; level > 0; level--) {
      const results = searchLayerWithQuery(queryVector, queryNorm, currentNodeIdx, level, 1);
      if (results.length > 0) {
        currentNodeIdx = results[0].idx;
      } else {
        reselectEntryPoint();
        if (entryPoint.idx === -1) return [];
        currentNodeIdx = entryPoint.idx;
        break;
      }
    }
    const finalEf = Math.max(k, efSearch || efConstruction);
    const finalResults = searchLayerWithQuery(queryVector, queryNorm, currentNodeIdx, 0, finalEf);
    return finalResults.slice(0, k);
  }

  return {
    insertNode,
    search,
    deleteNode,
    getStats: () => ({
      nodeCount: vectors.length,
      activeNodeCount: vectors.length - deletedNodes.size,
      deletedNodeCount: deletedNodes.size,
      entryPoint: { ...entryPoint }
    })
  };
}
// src/index.ts
/**
 * Public API of the package.
 * - createIndex: index over numeric vectors with the built-in "l2" or cosine metric.
 * - createIndexGeneric: index over arbitrary items with caller-supplied distance functions.
 */
export const hnsw = {
  createIndex: createHNSWIndex2,
  createIndexGeneric: createHNSWIndex
};
//# sourceMappingURL=index.mjs.map
//# sourceMappingURL=index.mjs.map