@elpassion/semantic-chunking
Semantically create chunks from large texts. Useful for workflows involving large language models (LLMs).
JavaScript
import { LRUCache } from "lru-cache";
import pLimit from "p-limit";
// --------------------------
// -- LocalEmbeddingModel class --
// --------------------------
export class LocalEmbeddingModel {
constructor(transformers) {
if (!transformers) {
throw new Error("Transformers object is required in constructor");
}
if (!transformers.env || !transformers.pipeline || !transformers.AutoTokenizer) {
throw new Error("Transformers object must contain env, pipeline, and AutoTokenizer");
}
this.transformers = transformers;
this.tokenizer = null;
this.generateEmbedding = null;
this.modelName = null;
this.dtype = null;
this.embeddingCache = new LRUCache({
max: 500,
maxSize: 50_000_000,
sizeCalculation: (value, key) => {
return value.length * 4 + key.length;
},
ttl: 1000 * 60 * 60,
});
}
async initialize(onnxEmbeddingModel, dtype = "fp32", localModelPath = null, modelCacheDir = null) {
// Configure environment
this.transformers.env.allowRemoteModels = true;
if (localModelPath) this.transformers.env.localModelPath = localModelPath;
if (modelCacheDir) this.transformers.env.cacheDir = modelCacheDir;
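// Load the tokenizer and the feature-extraction pipeline for the requested ONNX model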
this.tokenizer = await this.transformers.AutoTokenizer.from_pretrained(onnxEmbeddingModel);
this.generateEmbedding = await this.transformers.pipeline("feature-extraction", onnxEmbeddingModel, {
dtype: dtype,
});
this.modelName = onnxEmbeddingModel;
this.dtype = dtype;
this.embeddingCache.clear();
return {
modelName: onnxEmbeddingModel,
dtype: dtype,
};
}
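// Embed a single string; returns a cached vector when available, otherwise a mean-pooled, normalized embedding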
async createEmbedding(text) {
if (!this.generateEmbedding) {
throw new Error("Model not initialized. Call initialize() first.");
}
const cached = this.embeddingCache.get(text);
if (cached) {
return cached;
}
const embeddings = await this.generateEmbedding(text, {
pooling: "mean",
normalize: true,
});
this.embeddingCache.set(text, embeddings.data);
return embeddings.data;
}
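// Tokenize text with the model's tokenizer and return the token count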
async tokenize(text, options = {}) {
if (!this.tokenizer) {
throw new Error("Model not initialized. Call initialize() first.");
}
const tokenized = await this.tokenizer(text, options);
return {
size: tokenized.input_ids.size,
};
}
getModelInfo() {
return {
modelName: this.modelName,
dtype: this.dtype,
};
}
}
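// --------------------------
// Usage sketch for LocalEmbeddingModel (illustrative only; assumes @huggingface/transformers
// or @xenova/transformers as the injected transformers implementation, and an example model name):
//
//   import { env, pipeline, AutoTokenizer } from "@huggingface/transformers";
//   const local = new LocalEmbeddingModel({ env, pipeline, AutoTokenizer });
//   await local.initialize("Xenova/all-MiniLM-L6-v2", "fp32");
//   const vector = await local.createEmbedding("Semantic chunking splits text by meaning.");
//   const { size } = await local.tokenize("Semantic chunking splits text by meaning.");
// --------------------------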
// --------------------------
// -- OpenAIEmbedding class --
// --------------------------
export class OpenAIEmbedding {
constructor(openaiClient, options = {}) {
if (!openaiClient) {
throw new Error("OpenAI client is required in constructor");
}
this.openaiClient = openaiClient;
this.modelName = null;
this.embeddingCache = new LRUCache({
max: 500,
maxSize: 50_000_000,
sizeCalculation: (value, key) => {
return value.length * 4 + key.length;
},
ttl: 1000 * 60 * 60,
});
// Rate limiting and retry configuration
this.maxRetries = options.maxRetries || 5;
this.retryDelays = options.retryDelays || [10000, 30000, 60000, 60000, 60000]; // 10s, 30s, then 60s for each remaining attempt
this.concurrency = options.concurrency || 100; // Maximum number of concurrent API requests
// Create p-limit instance for concurrency control
this.limit = pLimit(this.concurrency);
}
async initialize(modelName = "text-embedding-3-small") {
this.modelName = modelName;
this.embeddingCache.clear();
return {
modelName: modelName,
dtype: "api", // API-based, no dtype
};
}
// Retry logic: honors rate-limit delays (retry-after or configured delays), exponential backoff for other errors
async _retryOperation(operation) {
let lastError;
for (let attempt = 0; attempt <= this.maxRetries; attempt++) {
try {
return await operation();
} catch (error) {
lastError = error;
// Check if it's a rate limit error
const isRateLimit =
error.message &&
(error.message.includes("429") || error.message.includes("rate limit") || error.message.includes("quota"));
if (isRateLimit && attempt < this.maxRetries) {
// Extract retry-after time from error message if available
const retryAfterMatch = error.message.match(/retry after (\d+) seconds/);
const retryAfter = retryAfterMatch ? parseInt(retryAfterMatch[1]) * 1000 : null;
// Use retry-after time or predefined delays
const delay = retryAfter || this.retryDelays[attempt] || this.retryDelays[this.retryDelays.length - 1];
console.log(`Rate limit hit, retrying in ${delay / 1000}s (attempt ${attempt + 1}/${this.maxRetries + 1})`);
await new Promise((resolve) => setTimeout(resolve, delay));
continue;
}
// For other errors, use shorter exponential backoff
if (!isRateLimit && attempt < this.maxRetries) {
const delay = 1000 * Math.pow(2, attempt);
console.log(`Request failed, retrying in ${delay}ms (attempt ${attempt + 1}/${this.maxRetries + 1})`);
await new Promise((resolve) => setTimeout(resolve, delay));
continue;
}
break;
}
}
throw lastError;
}
// Simple embedding execution with retry
async _executeEmbedding(text) {
return this._retryOperation(async () => {
const response = await this.openaiClient.embeddings.create({
model: this.modelName,
input: text,
});
return response.data[0].embedding;
});
}
async createEmbedding(text) {
if (!this.modelName) {
throw new Error("Model not initialized. Call initialize() first.");
}
const cached = this.embeddingCache.get(text);
if (cached) {
return cached;
}
try {
// Use p-limit to control concurrency with retry logic
const embedding = await this.limit(() => this._executeEmbedding(text));
this.embeddingCache.set(text, embedding);
return embedding;
} catch (error) {
throw new Error(`OpenAI API error: ${error.message}`);
}
}
async tokenize(text, options = {}) {
if (!this.modelName) {
throw new Error("Model not initialized. Call initialize() first.");
}
// For tokenization, we don't need to queue since it's just a local calculation
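// Heuristic: roughly 4 characters per token (an approximation, not a real tokenizer)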
const approximateTokenCount = Math.ceil(text.length / 4);
return { size: approximateTokenCount };
}
getModelInfo() {
return {
modelName: this.modelName,
dtype: "api",
};
}
}
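// --------------------------
// Usage sketch for OpenAIEmbedding (illustrative only; assumes the official "openai" npm client;
// the model name shown is the default used by initialize()):
//
//   import OpenAI from "openai";
//   const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
//   const remote = new OpenAIEmbedding(client, { concurrency: 10, maxRetries: 3 });
//   await remote.initialize("text-embedding-3-small");
//   const vector = await remote.createEmbedding("Semantic chunking splits text by meaning.");
// --------------------------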