UNPKG

chromadb-default-embed

Version:

Chroma's fork of @xenova/transformers serving as our default embedding function

416 lines (380 loc) 12.5 kB
/**
 * @file Custom data structures.
 *
 * These are only used internally, meaning an end-user shouldn't
 * need to access anything here.
 *
 * @module utils/data-structures
 */

/**
 * Efficient Heap-based Implementation of a Priority Queue.
 * It uses an array-based binary heap, where the root is at index `0`, and the
 * children of node `i` are located at indices `2i + 1` and `2i + 2`, respectively.
 *
 * Adapted from the following sources:
 * - https://stackoverflow.com/a/42919752/13989043 (original)
 * - https://github.com/belladoreai/llama-tokenizer-js (minor improvements)
 */
export class PriorityQueue {

    /**
     * Create a new PriorityQueue.
     * @param {Function} comparator Comparator function to determine priority.
     * `comparator(a, b)` must return a truthy value when `a` has higher priority than `b`.
     * Defaults to a MaxHeap.
     */
    constructor(comparator = (a, b) => a > b) {
        this._heap = [];
        this._comparator = comparator;
    }

    /**
     * The size of the queue
     */
    get size() {
        return this._heap.length;
    }

    /**
     * Check if the queue is empty.
     * @returns {boolean} `true` if the queue is empty, `false` otherwise.
     */
    isEmpty() {
        return this.size === 0;
    }

    /**
     * Return the element with the highest priority in the queue, without removing it.
     * @returns {any} The highest priority element in the queue (`undefined` if the queue is empty).
     */
    peek() {
        return this._heap[0];
    }

    /**
     * Add one or more elements to the queue.
     * @param {...any} values The values to push into the queue.
     * @returns {number} The new size of the queue.
     */
    push(...values) {
        return this.extend(values);
    }

    /**
     * Add multiple elements to the queue.
     * @param {any[]} values The values to push into the queue.
     * @returns {number} The new size of the queue.
     */
    extend(values) {
        for (const value of values) {
            // Append at the bottom of the heap, then restore the heap property.
            this._heap.push(value);
            this._siftUp();
        }
        return this.size;
    }

    /**
     * Remove and return the element with the highest priority in the queue.
     * @returns {any} The element with the highest priority in the queue (`undefined` if the queue is empty).
     */
    pop() {
        const poppedValue = this.peek();
        const bottom = this.size - 1;
        if (bottom > 0) {
            // Move the root to the end so it can be removed with a cheap `Array.pop`.
            this._swap(0, bottom);
        }
        this._heap.pop();
        this._siftDown();
        return poppedValue;
    }

    /**
     * Replace the element with the highest priority in the queue with a new value.
     * @param {*} value The new value.
     * @returns {*} The replaced value.
     */
    replace(value) {
        const replacedValue = this.peek();
        this._heap[0] = value;
        this._siftDown();
        return replacedValue;
    }

    /**
     * Compute the index for the parent of the node at index `i`.
     * @param {number} i The index of the node to get the parent of.
     * @returns {number} The index of the parent node.
     * @private
     */
    _parent(i) {
        return ((i + 1) >>> 1) - 1;
    }

    /**
     * Compute the index for the left child of the node at index `i`.
     * @param {number} i The index of the node to get the left child of.
     * @returns {number} The index of the left child.
     * @private
     */
    _left(i) {
        return (i << 1) + 1;
    }

    /**
     * Compute the index for the right child of the node at index `i`.
     * @param {number} i The index of the node to get the right child of.
     * @returns {number} The index of the right child.
     * @private
     */
    _right(i) {
        return (i + 1) << 1;
    }

    /**
     * Check if the element at index `i` has higher priority than the element at index `j`,
     * according to the queue's comparator.
     * @param {number} i The index of the first element to compare.
     * @param {number} j The index of the second element to compare.
     * @returns {boolean} The result of `comparator(heap[i], heap[j])`.
     * @private
     */
    _greater(i, j) {
        return this._comparator(this._heap[i], this._heap[j]);
    }

    /**
     * Swap the elements at indices `i` and `j`.
     * @param {number} i The index of the first element to swap.
     * @param {number} j The index of the second element to swap.
     * @private
     */
    _swap(i, j) {
        const temp = this._heap[i];
        this._heap[i] = this._heap[j];
        this._heap[j] = temp;
    }

    /**
     * Maintain the heap property by updating positions in the heap,
     * starting at the last element and moving up the heap.
     * @private
     */
    _siftUp() {
        let node = this.size - 1;
        while (node > 0 && this._greater(node, this._parent(node))) {
            this._swap(node, this._parent(node));
            node = this._parent(node);
        }
    }

    /**
     * Maintain the heap property by updating positions in the heap,
     * starting at the first element and moving down the heap.
     * @private
     */
    _siftDown() {
        let node = 0;
        while (
            (this._left(node) < this.size && this._greater(this._left(node), node)) ||
            (this._right(node) < this.size && this._greater(this._right(node), node))
        ) {
            // Swap with whichever child has the higher priority.
            const maxChild = (this._right(node) < this.size && this._greater(this._right(node), this._left(node)))
                ? this._right(node)
                : this._left(node);
            this._swap(node, maxChild);
            node = maxChild;
        }
    }
}

/**
 * A trie structure to efficiently store and search for strings.
 *
 * Strings are stored one Unicode code point per trie level (a surrogate pair
 * occupies a single edge), and lookups walk the trie the same way.
 */
export class CharTrie {
    constructor() {
        this.root = CharTrieNode.default();
    }

    /**
     * Adds one or more `texts` to the trie.
     * @param {string[]} texts The strings to add to the trie.
     */
    extend(texts) {
        for (const text of texts) {
            this.push(text);
        }
    }

    /**
     * Adds text to the trie.
     * @param {string} text The string to add to the trie.
     */
    push(text) {
        let node = this.root;
        // `for...of` on a string iterates by code point, not by UTF-16 code unit.
        for (const ch of text) {
            let child = node.children.get(ch);
            if (child === undefined) {
                child = CharTrieNode.default();
                node.children.set(ch, child);
            }
            node = child;
        }
        node.isLeaf = true;
    }

    /**
     * Searches the trie for all stored strings that are a prefix of `text`.
     * @param {string} text The text whose prefixes are looked up.
     * @yields {string} Each string in the trie that is a prefix of `text`, shortest first.
     */
    *commonPrefixSearch(text) {
        let node = this.root;
        let prefix = "";
        // Iterate by code point so the keys match those created by `push`;
        // indexing with `text[i]` would split surrogate pairs and never match them.
        for (const ch of text) {
            prefix += ch;
            node = node.children.get(ch);
            if (node === undefined) {
                return;
            }
            if (node.isLeaf) {
                yield prefix;
            }
        }
    }
}

/**
 * Represents a node in a character trie.
 */
class CharTrieNode {
    /**
     * Create a new CharTrieNode.
     * @param {boolean} isLeaf Whether the node is a leaf node or not.
     * @param {Map<string, CharTrieNode>} children A map containing the node's children, where the key is a character and the value is a `CharTrieNode`.
     */
    constructor(isLeaf, children) {
        this.isLeaf = isLeaf;
        this.children = children;
    }

    /**
     * Returns a new `CharTrieNode` instance with default values.
     * @returns {CharTrieNode} A new `CharTrieNode` instance with `isLeaf` set to `false` and an empty `children` map.
     */
    static default() {
        return new CharTrieNode(false, new Map());
    }
}

/**
 * A lattice data structure to be used for tokenization.
 *
 * Positions and lengths are offsets into `sentence` as used by
 * `String.prototype.length` / `String.prototype.slice`.
 */
export class TokenLattice {
    /**
     * Creates a new TokenLattice instance.
     *
     * @param {string} sentence The input sentence to be tokenized.
     * @param {number} bosTokenId The beginning-of-sequence token ID.
     * @param {number} eosTokenId The end-of-sequence token ID.
     */
    constructor(sentence, bosTokenId, eosTokenId) {
        this.sentence = sentence;
        this.len = sentence.length;
        this.bosTokenId = bosTokenId;
        this.eosTokenId = eosTokenId;
        this.nodes = [];
        // beginNodes[p] / endNodes[p]: nodes that start / end at position `p`.
        this.beginNodes = Array.from({ length: this.len + 1 }, () => []);
        this.endNodes = Array.from({ length: this.len + 1 }, () => []);

        const bos = new TokenLatticeNode(this.bosTokenId, 0, 0, 0, 0.0);
        const eos = new TokenLatticeNode(this.eosTokenId, 1, this.len, 0, 0.0);
        this.nodes.push(bos.clone());
        this.nodes.push(eos.clone());
        this.beginNodes[this.len].push(eos);
        this.endNodes[0].push(bos);
    }

    /**
     * Inserts a new token node into the token lattice.
     *
     * @param {number} pos The starting position of the token.
     * @param {number} length The length of the token.
     * @param {number} score The score of the token.
     * @param {number} tokenId The token ID of the token.
     */
    insert(pos, length, score, tokenId) {
        const nodeId = this.nodes.length;
        const node = new TokenLatticeNode(tokenId, nodeId, pos, length, score);
        this.beginNodes[pos].push(node);
        this.endNodes[pos + length].push(node);
        this.nodes.push(node);
    }

    /**
     * Implements the Viterbi algorithm to compute the most likely sequence of tokens.
     *
     * @returns {TokenLatticeNode[]} The array of nodes representing the most likely sequence of tokens
     * (excluding the BOS/EOS nodes), or an empty array if no path covers the whole sentence.
     */
    viterbi() {
        const len = this.len;
        let pos = 0;
        while (pos <= len) {
            if (this.beginNodes[pos].length === 0) {
                // No token starts here, so the sentence cannot be fully covered.
                return [];
            }
            for (const rnode of this.beginNodes[pos]) {
                rnode.prev = null;
                let bestScore = 0.0;
                let bestNode = null;
                // Pick the best-scoring node that ends where `rnode` begins.
                for (const lnode of this.endNodes[pos]) {
                    const score = lnode.backtraceScore + rnode.score;
                    if (bestNode === null || score > bestScore) {
                        bestNode = lnode.clone();
                        bestScore = score;
                    }
                }

                if (bestNode !== null) {
                    rnode.prev = bestNode;
                    rnode.backtraceScore = bestScore;
                } else {
                    return [];
                }
            }
            ++pos;
        }

        const results = [];
        // The first node beginning at `len` is the EOS node added by the constructor.
        const root = this.beginNodes[len][0];
        const prev = root.prev;
        if (prev === null) {
            return [];
        }

        // Walk the back-pointers until the BOS node (whose `prev` is null) is reached.
        let node = prev.clone();
        while (node.prev !== null) {
            results.push(node.clone());
            node = node.prev.clone();
        }

        results.reverse();
        return results;
    }

    /**
     * Get the text covered by a lattice node.
     * @param {TokenLatticeNode} node The node to get the text of.
     * @returns {string} The substring of the sentence spanned by `node`.
     */
    piece(node) {
        return this.sentence.slice(node.pos, node.pos + node.length);
    }

    /**
     * Run Viterbi and return the text of each token on the best path.
     * @returns {string[]} The pieces of the sentence for the most likely sequence of tokens.
     */
    tokens() {
        const nodes = this.viterbi();
        return nodes.map(x => this.piece(x));
    }

    /**
     * Run Viterbi and return the ID of each token on the best path.
     * @returns {number[]} The token IDs for the most likely sequence of tokens.
     */
    tokenIds() {
        const nodes = this.viterbi();
        return nodes.map(x => x.tokenId);
    }
}

class TokenLatticeNode {
    /**
     * Represents a node in a token lattice for a given sentence.
     * @param {number} tokenId The ID of the token associated with this node.
     * @param {number} nodeId The ID of this node.
     * @param {number} pos The starting position of the token in the sentence.
     * @param {number} length The length of the token.
     * @param {number} score The score associated with the token.
     */
    constructor(tokenId, nodeId, pos, length, score) {
        this.tokenId = tokenId;
        this.nodeId = nodeId;
        this.pos = pos;
        this.length = length;
        this.score = score;
        this.prev = null;
        this.backtraceScore = 0.0;
    }

    /**
     * Returns a clone of this node.
     * The clone is shallow: its `prev` refers to the same node object as the original's.
     * @returns {TokenLatticeNode} A clone of this node.
     */
    clone() {
        const n = new TokenLatticeNode(this.tokenId, this.nodeId, this.pos, this.length, this.score);
        n.prev = this.prev;
        n.backtraceScore = this.backtraceScore;
        return n;
    }
}