quese-test
Version:
Package that make easier the searching process in javascript, through Embeddings and Semantic Similarity
175 lines (162 loc) • 5.71 kB
JavaScript
import { pipeline, env } from "@xenova/transformers";
env.allowLocalModels = false;
env.backends.onnx.wasm.numThreads = 1;
// import { openDB } from 'idb';
/**
* This class uses the Singleton pattern to ensure that only one instance of the
* pipeline is loaded. This is because loading the pipeline is an expensive
* operation and we don't want to do it every time we want to translate a sentence.
*/
export class MyTranslationPipeline {
static task = "feature-extraction";
static model = "Xenova/all-MiniLM-L6-v2";
static instance = null;
static async getInstance(progress_callback = null) {
if (this.instance === null) {
this.instance = pipeline(this.task, this.model, { progress_callback, });
}
return this.instance;
}
}
function dotProduct(a, b) {
if (a.length !== b.length) {
throw new Error("Both embeddings must have the same length!");
}
let result = 0;
for (let i = 0; i < a.length; i++) {
result += a[i] * b[i];
}
return result;
}
// async function saveToIndexedDB(key, value) {
// const db = await openDB('my-database', 1);
// const tx = db.transaction('embeddingCache', 'readwrite');
// tx.store.put(value, key);
// await tx.done;
// }
// async function getFromIndexedDB(key) {
// const db = await openDB('my-database', 1);
// const tx = db.transaction('embeddingCache', 'readonly');
// return tx.store.get(key)
// }
let embeddingCache = {}
// Listen for messages from the main thread
self.addEventListener("message", async (event) => {
try {
let data = event.data.data;
let text = event.data.text;
let by = event.data.by;
let template = event.data.template;
let accuracy = event.data.accuracy;
let dataFormatted = [];
if (!data || !data.length) {
throw new Error('The "data" param is required!');
}
if (!text) {
throw new Error('The "query" param is required!');
}
if (!by && !template) {
throw new Error(
'At least one of the params: "by" or "template" is needed!'
);
} else if (by && template) {
//UNIQ FROM LODASH REPLACED!
const uniqueProps = new Set(template.match(/\{(.*?)\}/g) || []);
const formattedProps = Array.from(uniqueProps).map((prop) =>
prop.replace("{", "").replace("}", "")
);
dataFormatted = data.map((item) => {
const formatted = template.replace(/\{(.*?)\}/g, (match, prop) => {
return item[prop];
});
return formatted;
});
} else if (template) {
const uniqueProps = new Set(template.match(/\{(.*?)\}/g) || []);
const formattedProps = Array.from(uniqueProps).map((prop) =>
prop.replace("{", "").replace("}", "")
);
dataFormatted = data.map((item) => {
const formatted = template.replace(/\{(.*?)\}/g, (match, prop) => {
return item[prop];
});
return formatted;
});
} else {
dataFormatted = data.map((item) => item[by]);
}
// Retrieve the translation pipeline. When called for the first time,
// this will load the pipeline and save it for future use.
let translator = await MyTranslationPipeline.getInstance((x) => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
self.postMessage(x);
});
// Actually perform the translation
let embeddingsList = [];
for (const fragment of dataFormatted) {
let embedding;
// let embeddingCache = await getFromIndexedDB('embeddingCache')
if (embeddingCache[fragment]) {
// Usar el embedding desde el caché si está disponible
embedding = embeddingCache[fragment];
} else {
const calculatedEmbedding = await translator(fragment, {
pooling: "mean",
normalize: true,
callback_function: (x) => {
self.postMessage({
status: "update",
output: translator.tokenizer.decode(x[0].output_token_ids, {
skip_special_tokens: true,
}),
});
},
});
embedding = Array.from(calculatedEmbedding.data);
embeddingCache[fragment] = embedding;
// await saveToIndexedDB('embeddingCache', embeddingCache);
}
embeddingsList.push(embedding);
}
const qEmbedding = await translator(text, {
pooling: "mean",
normalize: true,
callback_function: (x) => {
self.postMessage({
status: "update",
output: translator.tokenizer.decode(x[0].output_token_ids, {
skip_special_tokens: true,
}),
});
},
});
const qEmbedding_formatted = Array.from(qEmbedding.data);
const similaritiesOrder = [];
for (const embedding of embeddingsList) {
const similarity = dotProduct(embedding, qEmbedding_formatted);
similaritiesOrder.push(similarity);
}
const sortedIndices = Array.from(
{ length: similaritiesOrder.length },
(_, index) => index
).sort((a, b) => similaritiesOrder[b] - similaritiesOrder[a]);
const results = [];
for (const index of sortedIndices) {
const similar = similaritiesOrder[index];
if (similar <= accuracy) {
continue;
} else {
// Add data that meets the threshold
results.push(data[index]);
}
}
// Send the output back to the main thread
self.postMessage({
status: "complete",
output: results,
});
} catch (error) {
console.log(error);
}
});