// fastembed — NodeJS implementation of @qdrant/fastembed
// (512 lines, 23.1 kB, JavaScript)
import { AddedToken, Tokenizer } from "@anush008/tokenizers";
import fs from "fs";
import https from "https";
import * as ort from "onnxruntime-node";
import path from "path";
import Progress from "progress";
import tar from "tar";
import { downloadFileToCacheDir } from "@huggingface/hub";
// ONNX Runtime execution providers that can be requested at session creation.
export var ExecutionProvider;
(function (ExecutionProvider) {
    // Map of enum key -> onnxruntime provider identifier string.
    const providers = {
        CPU: "cpu",
        CUDA: "cuda",
        WebGL: "webgl",
        WASM: "wasm",
        XNNPACK: "xnnpack",
    };
    for (const [key, value] of Object.entries(providers)) {
        ExecutionProvider[key] = value;
    }
})(ExecutionProvider || (ExecutionProvider = {}));
// Dense embedding models bundled with fastembed; values are the model
// directory names used in the GCS download URL and the local cache.
export var EmbeddingModel;
(function (EmbeddingModel) {
    const models = {
        AllMiniLML6V2: "fast-all-MiniLM-L6-v2",
        BGEBaseEN: "fast-bge-base-en",
        BGEBaseENV15: "fast-bge-base-en-v1.5",
        BGESmallEN: "fast-bge-small-en",
        BGESmallENV15: "fast-bge-small-en-v1.5",
        BGESmallZH: "fast-bge-small-zh-v1.5",
        MLE5Large: "fast-multilingual-e5-large",
        CUSTOM: "custom",
    };
    for (const [key, value] of Object.entries(models)) {
        EmbeddingModel[key] = value;
    }
})(EmbeddingModel || (EmbeddingModel = {}));
// Sparse (SPLADE-style) embedding models; values are HuggingFace repo ids.
export var SparseEmbeddingModel;
(function (SparseEmbeddingModel) {
    const models = {
        SpladePPEnV1: "prithivida/Splade_PP_en_v1",
        CUSTOM: "custom",
    };
    for (const [key, value] of Object.entries(models)) {
        SparseEmbeddingModel[key] = value;
    }
})(SparseEmbeddingModel || (SparseEmbeddingModel = {}));
/**
 * L2-normalizes a vector in a fresh array.
 * The epsilon floor on the norm prevents division by zero for an
 * all-zero vector (which is returned unchanged as zeros).
 *
 * @param {number[]} v - Input vector (not mutated).
 * @returns {number[]} Unit-length copy of v.
 */
function normalize(v) {
    let sumOfSquares = 0;
    for (const component of v) {
        sumOfSquares += component * component;
    }
    const denominator = Math.max(Math.sqrt(sumOfSquares), 1e-12);
    return v.map((component) => component / denominator);
}
/**
 * Extracts one vector per sequence from a flat model-output buffer.
 * For dimensions [batch, seqLen, hidden], returns for each sequence the
 * slice covering only its FIRST token's hidden vector (CLS pooling):
 * start = i * seqLen * hidden, length = hidden.
 *
 * @param {number[]|Float32Array} data - Flattened last_hidden_state buffer.
 * @param {number[]} dimensions - [batch, seqLen, hidden] tensor dims.
 * @returns {Array} One hidden-size slice per sequence.
 */
function getEmbeddings(data, dimensions) {
    const [batch, seqLen, hidden] = dimensions;
    const result = [];
    for (let i = 0; i < batch; i++) {
        const offset = i * seqLen * hidden;
        result.push(data.slice(offset, offset + hidden));
    }
    return result;
}
// Empty marker base class for dense embedding implementations
// (FlagEmbedding extends it); carries no behavior of its own.
class Embedding {
}
// Empty marker base class for sparse embedding implementations
// (SparseTextEmbedding extends it); carries no behavior of its own.
class SparseEmbedding {
}
/**
 * Dense text-embedding model backed by ONNX Runtime.
 * Construct via the async factory {@link FlagEmbedding.init}; the
 * constructor itself only stores already-prepared collaborators.
 */
export class FlagEmbedding extends Embedding {
    /**
     * @param {Tokenizer} tokenizer - Tokenizer with truncation/padding configured.
     * @param {ort.InferenceSession} session - Ready ONNX inference session.
     * @param {string} model - EmbeddingModel value this instance serves.
     */
    constructor(tokenizer, session, model) {
        super();
        this.tokenizer = tokenizer;
        this.session = session;
        this.model = model;
    }
    /**
     * Creates a ready-to-use FlagEmbedding: resolves the model directory
     * (downloading and extracting the archive if needed), loads the
     * tokenizer, and opens the ONNX session.
     *
     * @param {object} [options]
     * @param {string} [options.model] - One of EmbeddingModel; CUSTOM requires
     *   modelAbsoluteDirPath and modelName.
     * @param {string[]} [options.executionProviders] - ORT execution providers.
     * @param {number} [options.maxLength] - Max token length (capped by the
     *   tokenizer config's model_max_length when present).
     * @param {string} [options.cacheDir] - Local cache directory for models.
     * @param {boolean} [options.showDownloadProgress] - Show a progress bar.
     * @param {string} [options.modelAbsoluteDirPath] - Dir for CUSTOM models.
     * @param {string} [options.modelName] - .onnx file name override.
     * @returns {Promise<FlagEmbedding>}
     * @throws {Error} When required custom-model options or files are missing.
     */
    static async init({ model = EmbeddingModel.BGESmallENV15, executionProviders = [ExecutionProvider.CPU], maxLength = 512, cacheDir = "local_cache", showDownloadProgress = true, modelAbsoluteDirPath = "", modelName = "", } = {}) {
        if (model === EmbeddingModel.CUSTOM) {
            if (!modelAbsoluteDirPath) {
                throw new Error("For custom model, modelAbsoluteDirPath is required in FlagEmbedding.init");
            }
            if (!modelName) {
                throw new Error("For custom model, modelName is required in FlagEmbedding.init");
            }
        }
        const modelDir = model === EmbeddingModel.CUSTOM
            ? modelAbsoluteDirPath
            : await FlagEmbedding.retrieveModel(model, cacheDir, showDownloadProgress);
        const tokenizer = this.loadTokenizer(modelDir, maxLength);
        // MLE5Large and AllMiniLML6V2 archives ship "model.onnx"; the other
        // bundled models ship a pre-optimized "model_optimized.onnx".
        const defaultModelName = model === EmbeddingModel.MLE5Large ||
            model === EmbeddingModel.AllMiniLML6V2
            ? "model.onnx"
            : "model_optimized.onnx";
        const modelPath = path.join(modelDir.toString(), modelName || defaultModelName);
        if (!fs.existsSync(modelPath)) {
            throw new Error(`Model file not found at ${modelPath}`);
        }
        const session = await ort.InferenceSession.create(modelPath, {
            executionProviders,
            graphOptimizationLevel: "all",
        });
        return new FlagEmbedding(tokenizer, session, model);
    }
    /**
     * Loads and configures the tokenizer from a model directory.
     * Reads tokenizer.json, config.json, tokenizer_config.json and
     * special_tokens_map.json; sets truncation/padding and registers
     * special/added tokens.
     *
     * @param {string} modelDir - Directory containing the tokenizer files.
     * @param {number} maxLength - Requested max token length.
     * @returns {Tokenizer}
     * @throws {Error} When any required file is missing.
     */
    static loadTokenizer(modelDir, maxLength) {
        const tokenizerPath = path.join(modelDir.toString(), "tokenizer.json");
        if (!fs.existsSync(tokenizerPath)) {
            throw new Error(`Tokenizer file not found at ${tokenizerPath}`);
        }
        const configPath = path.join(modelDir.toString(), "config.json");
        if (!fs.existsSync(configPath)) {
            throw new Error(`Config file not found at ${configPath}`);
        }
        const config = JSON.parse(fs.readFileSync(configPath, "utf-8"));
        const tokenizerFilePath = path.join(modelDir.toString(), "tokenizer_config.json");
        if (!fs.existsSync(tokenizerFilePath)) {
            throw new Error(`Tokenizer file not found at ${tokenizerFilePath}`);
        }
        const tokenizerConfig = JSON.parse(fs.readFileSync(tokenizerFilePath, "utf-8"));
        // Guard with ?? so a missing model_max_length does not turn maxLength
        // into NaN via Math.min(maxLength, undefined).
        maxLength = Math.min(maxLength, tokenizerConfig["model_max_length"] ?? maxLength);
        const tokensMapPath = path.join(modelDir.toString(), "special_tokens_map.json");
        if (!fs.existsSync(tokensMapPath)) {
            throw new Error(`Tokens map file not found at ${tokensMapPath}`);
        }
        const tokensMap = JSON.parse(fs.readFileSync(tokensMapPath, "utf-8"));
        const tokenizer = Tokenizer.fromFile(tokenizerPath);
        tokenizer.setTruncation(maxLength);
        tokenizer.setPadding({
            maxLength,
            padId: config["pad_token_id"],
            padToken: tokenizerConfig["pad_token"],
        });
        // Entries are either bare strings or serialized AddedToken objects.
        for (const token of Object.values(tokensMap)) {
            if (typeof token === "string") {
                tokenizer.addSpecialTokens([token]);
            }
            else if (isAddedTokenMap(token)) {
                const addedToken = new AddedToken(token["content"], true, {
                    singleWord: token["single_word"],
                    leftStrip: token["lstrip"],
                    rightStrip: token["rstrip"],
                    normalized: token["normalized"],
                });
                tokenizer.addAddedTokens([addedToken]);
            }
        }
        return tokenizer;
    }
    /**
     * Downloads a model tarball from Qdrant's GCS bucket to outputFilePath.
     * Resolves immediately if the file already exists (NOTE(review): a
     * partial file from an interrupted run would also be accepted here).
     *
     * Fix: rejects on non-2xx HTTP responses instead of silently saving the
     * error body as the tarball (which previously surfaced later as a
     * confusing extraction failure).
     *
     * @param {string} outputFilePath - Destination path for the .tar.gz file.
     * @param {string} model - EmbeddingModel value (used to build the URL).
     * @param {boolean} [showDownloadProgress] - Show a progress bar.
     * @returns {Promise<string>} The outputFilePath on success.
     */
    static async downloadFileFromGCS(outputFilePath, model, showDownloadProgress = true) {
        if (fs.existsSync(outputFilePath)) {
            return outputFilePath;
        }
        // The AllMiniLML6V2 model URL doesn't follow the same naming convention as the other models
        // So, we transform "fast-all-MiniLM-L6-v2" -> "sentence-transformers-all-MiniLM-L6-v2" in the download URL
        // The model directory name in the GCS storage remains "fast-all-MiniLM-L6-v2"
        if (model === EmbeddingModel.AllMiniLML6V2) {
            model = "sentence-transformers" + model.substring(model.indexOf("-"));
        }
        const url = `https://storage.googleapis.com/qdrant-fastembed/${model}.tar.gz`;
        const fileStream = fs.createWriteStream(outputFilePath);
        return new Promise((resolve, reject) => {
            https
                .get(url, { headers: { "User-Agent": "Mozilla/5.0" } }, (response) => {
                // Reject non-2xx responses; otherwise an HTML error page would
                // be written to disk and treated as a valid archive.
                const status = response.statusCode ?? 0;
                if (status < 200 || status >= 300) {
                    response.resume(); // drain so the socket can be reused
                    fileStream.close();
                    fs.unlink(outputFilePath, () => {
                        reject(new Error(`Download of ${url} failed with HTTP status ${status}`));
                    });
                    return;
                }
                const totalSizeInBytes = parseInt(response.headers["content-length"] || "0", 10);
                if (totalSizeInBytes === 0) {
                    console.warn(`Warning: Content-length header is missing or zero in the response from ${url}.`);
                }
                if (showDownloadProgress) {
                    const progressBar = new Progress(`Downloading ${model} [:bar] :percent :etas`, {
                        complete: "=",
                        width: 20,
                        total: totalSizeInBytes,
                    });
                    response.on("data", (chunk) => {
                        progressBar.tick(chunk.length, { speed: "N/A" });
                    });
                }
                response.on("error", (error) => {
                    reject(error);
                });
                response.pipe(fileStream);
                fileStream.on("finish", () => {
                    fileStream.close();
                    resolve(outputFilePath);
                });
                fileStream.on("error", (error) => {
                    reject(error);
                });
            })
                .on("error", (error) => {
                // Remove the (empty/partial) file before propagating the error.
                fs.unlink(outputFilePath, () => {
                    reject(error);
                });
            });
        });
    }
    /**
     * Extracts a .tar.gz archive into cacheDir.
     *
     * @param {string} targzPath - Path to the downloaded archive.
     * @param {string} cacheDir - Directory to extract into.
     * @throws {Error} When the file does not have a .gz extension.
     */
    static async decompressToCache(targzPath, cacheDir) {
        // Implementation for decompressing a .tar.gz file to a cache directory
        if (path.extname(targzPath.toString()) === ".gz") {
            await tar.x({
                file: targzPath,
                // @ts-ignore
                cwd: cacheDir,
            });
        }
        else {
            throw new Error(`Unsupported file extension: ${targzPath}`);
        }
    }
    /**
     * Ensures the model is present in cacheDir, downloading and extracting
     * it on first use, and returns the model directory path.
     *
     * @param {string} model - EmbeddingModel value.
     * @param {string} cacheDir - Local cache directory.
     * @param {boolean} [showDownloadProgress] - Show a progress bar.
     * @returns {Promise<string>} Path to the extracted model directory.
     */
    static async retrieveModel(model, cacheDir, showDownloadProgress = true) {
        if (!fs.existsSync(cacheDir)) {
            // recursive: true so nested cache paths (e.g. "a/b/cache") work.
            fs.mkdirSync(cacheDir, {
                recursive: true,
                mode: 0o777,
            });
        }
        const modelDir = path.join(cacheDir.toString(), model);
        if (fs.existsSync(modelDir)) {
            return modelDir;
        }
        const modelTarGz = path.join(cacheDir.toString(), `${model}.tar.gz`);
        await this.downloadFileFromGCS(modelTarGz, model, showDownloadProgress);
        await this.decompressToCache(modelTarGz, cacheDir);
        fs.unlinkSync(modelTarGz);
        return modelDir;
    }
    /**
     * Async generator yielding batches of L2-normalized embeddings.
     * Uses CLS pooling (first token of last_hidden_state per sequence);
     * attention-mean pooling was removed upstream, see
     * https://github.com/qdrant/fastembed/commit/a335c8898f11042fdb311fce2dab3acf50c23011
     *
     * @param {string[]} textStrings - Texts to embed.
     * @param {number} [batchSize] - Texts per ONNX run.
     * @yields {number[][]} One normalized vector per text in the batch.
     */
    async *embed(textStrings, batchSize = 256) {
        for (let i = 0; i < textStrings.length; i += batchSize) {
            const batchTexts = textStrings.slice(i, i + batchSize);
            const encodedTexts = await Promise.all(batchTexts.map((textString) => this.tokenizer.encode(textString)));
            const idsArray = [];
            const maskArray = [];
            const typeIdsArray = [];
            // int64 tensors require BigInt element values.
            encodedTexts.forEach((text) => {
                idsArray.push(text.getIds().map(BigInt));
                maskArray.push(text.getAttentionMask().map(BigInt));
                typeIdsArray.push(text.getTypeIds().map(BigInt));
            });
            // All rows share one length because padding is set to maxLength.
            const maxLength = idsArray[0].length;
            const batchInputIds = new ort.Tensor("int64", idsArray.flat(), [batchTexts.length, maxLength]);
            const batchAttentionMask = new ort.Tensor("int64", maskArray.flat(), [batchTexts.length, maxLength]);
            const batchTokenTypeId = new ort.Tensor("int64", typeIdsArray.flat(), [batchTexts.length, maxLength]);
            const inputs = {
                input_ids: batchInputIds,
                attention_mask: batchAttentionMask,
                token_type_ids: batchTokenTypeId,
            };
            // Exclude token_type_ids for MLE5Large
            if (this.model === EmbeddingModel.MLE5Large) {
                delete inputs.token_type_ids;
            }
            const output = await this.session.run(inputs);
            const embeddings = getEmbeddings(output.last_hidden_state.data, output.last_hidden_state.dims);
            yield embeddings.map(normalize);
        }
    }
    /**
     * Embeds passages, prefixing each text with "passage: " as expected by
     * retrieval-tuned models.
     *
     * @param {string[]} texts
     * @param {number} [batchSize]
     * @returns {AsyncGenerator<number[][]>}
     */
    passageEmbed(texts, batchSize = 256) {
        texts = texts.map((text) => `passage: ${text}`);
        return this.embed(texts, batchSize);
    }
    /**
     * Embeds a single query string (prefixed with "query: ").
     *
     * @param {string} query
     * @returns {Promise<number[]>} The normalized query vector.
     */
    async queryEmbed(query) {
        return (await this.embed([`query: ${query}`]).next()).value[0];
    }
    /**
     * @returns {Array<{model: string, dim: number, description: string}>}
     *   Metadata for the bundled dense models.
     */
    listSupportedModels() {
        return [
            {
                model: EmbeddingModel.BGESmallEN,
                dim: 384,
                description: "Fast English model",
            },
            {
                model: EmbeddingModel.BGESmallENV15,
                dim: 384,
                description: "v1.5 release of the fast, default English model",
            },
            {
                model: EmbeddingModel.BGEBaseEN,
                dim: 768,
                description: "Base English model",
            },
            {
                model: EmbeddingModel.BGEBaseENV15,
                dim: 768,
                description: "v1.5 release of Base English model",
            },
            {
                model: EmbeddingModel.BGESmallZH,
                dim: 512,
                description: "v1.5 release of the fast, Chinese model",
            },
            {
                model: EmbeddingModel.AllMiniLML6V2,
                dim: 384,
                description: "Sentence Transformer model, MiniLM-L6-v2",
            },
            {
                model: EmbeddingModel.MLE5Large,
                dim: 1024,
                description: "Multilingual model, e5-large. Recommend using this model for non-English languages",
            },
        ];
    }
}
// Sparse embedding implementation class
/**
 * Sparse (SPLADE-style) text-embedding model backed by ONNX Runtime.
 * Construct via the async factory {@link SparseTextEmbedding.init}.
 */
export class SparseTextEmbedding extends SparseEmbedding {
    /**
     * @param {Tokenizer} tokenizer - Tokenizer with truncation/padding configured.
     * @param {ort.InferenceSession} session - Ready ONNX inference session.
     * @param {string} model - SparseEmbeddingModel value this instance serves.
     * @param {number} vocabSize - Tokenizer vocabulary size from config.json.
     */
    constructor(tokenizer, session, model, vocabSize) {
        super();
        this.tokenizer = tokenizer;
        this.session = session;
        this.model = model;
        this.vocabSize = vocabSize;
    }
    /**
     * Creates a ready-to-use SparseTextEmbedding: resolves the model
     * directory (downloading from HuggingFace Hub if needed), loads the
     * tokenizer, and opens the ONNX session.
     *
     * @param {object} [options]
     * @param {string} [options.model] - One of SparseEmbeddingModel; CUSTOM
     *   requires modelAbsoluteDirPath and modelName.
     * @param {string[]} [options.executionProviders] - ORT execution providers.
     * @param {number} [options.maxLength] - Max token length (capped by the
     *   tokenizer config's model_max_length when present).
     * @param {string} [options.cacheDir] - Local cache directory for models.
     * @param {boolean} [options.showDownloadProgress] - Currently unused by
     *   the HF download path; kept for interface parity with FlagEmbedding.
     * @param {string} [options.modelAbsoluteDirPath] - Dir for CUSTOM models.
     * @param {string} [options.modelName] - .onnx file name override.
     * @returns {Promise<SparseTextEmbedding>}
     * @throws {Error} When required custom-model options or files are missing.
     */
    static async init({ model = SparseEmbeddingModel.SpladePPEnV1, executionProviders = [ExecutionProvider.CPU], maxLength = 512, cacheDir = "local_cache", showDownloadProgress = true, modelAbsoluteDirPath = "", modelName = "", } = {}) {
        if (model === SparseEmbeddingModel.CUSTOM) {
            if (!modelAbsoluteDirPath) {
                throw new Error("For custom model, modelAbsoluteDirPath is required in SparseTextEmbedding.init");
            }
            if (!modelName) {
                throw new Error("For custom model, modelName is required in SparseTextEmbedding.init");
            }
        }
        const modelDir = model === SparseEmbeddingModel.CUSTOM
            ? modelAbsoluteDirPath
            : await SparseTextEmbedding.retrieveModel(model, cacheDir, showDownloadProgress);
        const { tokenizer, vocabSize } = this.loadTokenizer(modelDir, maxLength);
        const defaultModelName = "model.onnx";
        // HF repos keep the ONNX weights under an "onnx/" subdirectory.
        const modelPath = path.join(modelDir.toString(), "onnx", modelName || defaultModelName);
        if (!fs.existsSync(modelPath)) {
            throw new Error(`Model file not found at ${modelPath}`);
        }
        const session = await ort.InferenceSession.create(modelPath, {
            executionProviders,
            graphOptimizationLevel: "all",
        });
        return new SparseTextEmbedding(tokenizer, session, model, vocabSize);
    }
    /**
     * Loads and configures the tokenizer from a model directory and reads
     * the vocabulary size from config.json.
     *
     * @param {string} modelDir - Directory containing the tokenizer files.
     * @param {number} maxLength - Requested max token length.
     * @returns {{tokenizer: Tokenizer, vocabSize: number}}
     * @throws {Error} When any required file is missing.
     */
    static loadTokenizer(modelDir, maxLength) {
        const tokenizerPath = path.join(modelDir.toString(), "tokenizer.json");
        if (!fs.existsSync(tokenizerPath)) {
            throw new Error(`Tokenizer file not found at ${tokenizerPath}`);
        }
        const configPath = path.join(modelDir.toString(), "config.json");
        if (!fs.existsSync(configPath)) {
            throw new Error(`Config file not found at ${configPath}`);
        }
        const config = JSON.parse(fs.readFileSync(configPath, "utf-8"));
        const tokenizerFilePath = path.join(modelDir.toString(), "tokenizer_config.json");
        if (!fs.existsSync(tokenizerFilePath)) {
            throw new Error(`Tokenizer file not found at ${tokenizerFilePath}`);
        }
        const tokenizerConfig = JSON.parse(fs.readFileSync(tokenizerFilePath, "utf-8"));
        // Guard with ?? so a missing model_max_length does not turn maxLength
        // into NaN via Math.min(maxLength, undefined).
        maxLength = Math.min(maxLength, tokenizerConfig["model_max_length"] ?? maxLength);
        const tokensMapPath = path.join(modelDir.toString(), "special_tokens_map.json");
        if (!fs.existsSync(tokensMapPath)) {
            throw new Error(`Tokens map file not found at ${tokensMapPath}`);
        }
        const tokensMap = JSON.parse(fs.readFileSync(tokensMapPath, "utf-8"));
        const tokenizer = Tokenizer.fromFile(tokenizerPath);
        tokenizer.setTruncation(maxLength);
        tokenizer.setPadding({
            maxLength,
            padId: config["pad_token_id"],
            padToken: tokenizerConfig["pad_token"],
        });
        // Entries are either bare strings or serialized AddedToken objects.
        for (const token of Object.values(tokensMap)) {
            if (typeof token === "string") {
                tokenizer.addSpecialTokens([token]);
            }
            else if (isAddedTokenMap(token)) {
                const addedToken = new AddedToken(token["content"], true, {
                    singleWord: token["single_word"],
                    leftStrip: token["lstrip"],
                    rightStrip: token["rstrip"],
                    normalized: token["normalized"],
                });
                tokenizer.addAddedTokens([addedToken]);
            }
        }
        // 30522 is the BERT-base vocabulary size used by SPLADE++ en v1.
        const vocabSize = config["vocab_size"] || 30522;
        return { tokenizer, vocabSize };
    }
    /**
     * Ensures the model files are present in cacheDir, downloading them
     * from HuggingFace Hub on first use, and returns the model directory.
     *
     * @param {string} model - SparseEmbeddingModel value (a HF repo id).
     * @param {string} cacheDir - Local cache directory.
     * @param {boolean} [showDownloadProgress] - Unused here; kept for parity
     *   with FlagEmbedding.retrieveModel.
     * @returns {Promise<string>} Path to the local model directory.
     */
    static async retrieveModel(model, cacheDir, showDownloadProgress = true) {
        if (!fs.existsSync(cacheDir)) {
            // recursive: true so nested cache paths (e.g. "a/b/cache") work.
            fs.mkdirSync(cacheDir, {
                recursive: true,
                mode: 0o777,
            });
        }
        // "/" in the repo id is not a valid path segment on disk.
        const modelDir = path.join(cacheDir.toString(), model.replace("/", "_"));
        if (fs.existsSync(modelDir)) {
            return modelDir;
        }
        fs.mkdirSync(modelDir, { mode: 0o777 });
        // Download required files from hf
        const filesToDownload = [
            "onnx/model.onnx",
            "tokenizer.json",
            "tokenizer_config.json",
            "config.json",
            "special_tokens_map.json",
        ];
        for (const fileName of filesToDownload) {
            const outputPath = path.join(modelDir, fileName);
            const outputDir = path.dirname(outputPath);
            if (!fs.existsSync(outputDir)) {
                fs.mkdirSync(outputDir, { recursive: true, mode: 0o777 });
            }
            // Use HuggingFace Hub library to download
            const downloaded = await downloadFileToCacheDir({
                repo: model,
                path: fileName,
            });
            // Copy from HF cache to our cache directory
            // In Node.js, downloadFile returns a string path
            if (downloaded && typeof downloaded === "string") {
                fs.copyFileSync(downloaded, outputPath);
            }
        }
        return modelDir;
    }
    /**
     * Async generator yielding batches of sparse vectors.
     * SPLADE postprocessing: max-pool log(1 + ReLU(logit)) over non-padded
     * sequence positions, then keep only non-zero vocabulary entries.
     *
     * @param {string[]} textStrings - Texts to embed.
     * @param {number} [batchSize] - Texts per ONNX run.
     * @yields {Array<{values: number[], indices: number[]}>}
     */
    async *embed(textStrings, batchSize = 256) {
        for (let i = 0; i < textStrings.length; i += batchSize) {
            const batchTexts = textStrings.slice(i, i + batchSize);
            const encodedTexts = await Promise.all(batchTexts.map((textString) => this.tokenizer.encode(textString)));
            const idsArray = [];
            const maskArray = [];
            const typeIdsArray = [];
            encodedTexts.forEach((text) => {
                // Keep the mask as plain numbers: it is re-used below for the
                // attention check; it is converted to BigInt only for the tensor.
                idsArray.push(text.getIds().map(BigInt));
                maskArray.push(text.getAttentionMask());
                typeIdsArray.push(text.getTypeIds().map(BigInt));
            });
            // All rows share one length because padding is set to maxLength.
            const maxLength = idsArray[0].length;
            const batchInputIds = new ort.Tensor("int64", idsArray.flat(), [batchTexts.length, maxLength]);
            const batchAttentionMask = new ort.Tensor("int64", maskArray.flat().map(BigInt), [batchTexts.length, maxLength]);
            const batchTokenTypeId = new ort.Tensor("int64", typeIdsArray.flat(), [batchTexts.length, maxLength]);
            // NOTE: this model's input names differ from the dense models'
            // (input_mask/segment_ids instead of attention_mask/token_type_ids).
            const inputs = {
                input_ids: batchInputIds,
                input_mask: batchAttentionMask,
                segment_ids: batchTokenTypeId,
            };
            const output = await this.session.run(inputs);
            // SPLADE postprocessing: log(1 + ReLU(logits))
            // Fix: use the public Tensor.data accessor instead of the internal
            // cpuData field (which required a @ts-expect-error suppression).
            const logits = output.output.data;
            const dims = output.output.dims;
            const [currentBatchSize, seqLen, vocabSize] = dims;
            const sparseVectors = [];
            for (let batchIdx = 0; batchIdx < currentBatchSize; batchIdx++) {
                // Float32Array is zero-initialized by construction.
                const values = new Float32Array(vocabSize);
                // Apply log(1 + ReLU(logits)) and max pooling over the sequence,
                // skipping padded positions (attention mask 0).
                for (let seqIdx = 0; seqIdx < seqLen; seqIdx++) {
                    const attentionValue = maskArray[batchIdx][seqIdx];
                    if (attentionValue > 0) {
                        // Hoist the row base offset out of the vocab loop.
                        const rowOffset = batchIdx * seqLen * vocabSize + seqIdx * vocabSize;
                        for (let vocabIdx = 0; vocabIdx < vocabSize; vocabIdx++) {
                            const logitValue = logits[rowOffset + vocabIdx];
                            // log(1 + ReLU(logit)), max-pooled over the sequence.
                            const logValue = Math.log(1 + Math.max(0, logitValue));
                            values[vocabIdx] = Math.max(values[vocabIdx], logValue);
                        }
                    }
                }
                // Convert to sparse representation (only non-zero values)
                const sparseVector = {
                    values: [],
                    indices: []
                };
                for (let tokenId = 0; tokenId < vocabSize; tokenId++) {
                    if (values[tokenId] > 0) {
                        sparseVector.indices.push(tokenId);
                        sparseVector.values.push(values[tokenId]);
                    }
                }
                sparseVectors.push(sparseVector);
            }
            yield sparseVectors;
        }
    }
    /**
     * Embeds passages.
     *
     * @param {string[]} texts
     * @param {number} [batchSize]
     * @returns {AsyncGenerator<Array<{values: number[], indices: number[]}>>}
     */
    passageEmbed(texts, batchSize = 256) {
        // SPLADE doesn't use passage/query prefixes like dense models
        return this.embed(texts, batchSize);
    }
    /**
     * Embeds a single query string.
     *
     * @param {string} query
     * @returns {Promise<{values: number[], indices: number[]}>}
     */
    async queryEmbed(query) {
        return (await this.embed([query]).next()).value[0];
    }
    /**
     * @returns {Array<{model: string, vocabSize: number, description: string}>}
     *   Metadata for the bundled sparse models.
     */
    listSupportedModels() {
        return [
            {
                model: SparseEmbeddingModel.SpladePPEnV1,
                vocabSize: 30522,
                description: "SPLADE++ English model for sparse retrieval",
            },
        ];
    }
}
/**
 * Type guard for special_tokens_map.json entries that are serialized
 * AddedToken objects rather than plain strings.
 *
 * Bug fix: the guard previously required a "token" key, but serialized
 * AddedToken entries carry their text under "content" — which is exactly
 * what both call sites read via token["content"]. The old check therefore
 * never matched, and added special tokens were silently skipped.
 *
 * @param {unknown} token - Candidate value from the tokens map.
 * @returns {boolean} True when the value has the AddedToken shape.
 */
function isAddedTokenMap(token) {
    return (typeof token === "object" &&
        token !== null &&
        "content" in token &&
        "single_word" in token &&
        "rstrip" in token &&
        "lstrip" in token &&
        "normalized" in token);
}
//# sourceMappingURL=fastembed.js.map