@elpassion/semantic-chunking

Semantically create chunks from large texts. Useful for workflows involving large language models (LLMs).

import { LRUCache } from "lru-cache";
import pLimit from "p-limit";

// --------------------------
// -- LocalEmbeddingModel class --
// --------------------------
export class LocalEmbeddingModel {
    constructor(transformers) {
        if (!transformers) {
            throw new Error("Transformers object is required in constructor");
        }
        if (!transformers.env || !transformers.pipeline || !transformers.AutoTokenizer) {
            throw new Error("Transformers object must contain env, pipeline, and AutoTokenizer");
        }

        this.transformers = transformers;
        this.tokenizer = null;
        this.generateEmbedding = null;
        this.modelName = null;
        this.dtype = null;
        this.embeddingCache = new LRUCache({
            max: 500,
            maxSize: 50_000_000,
            sizeCalculation: (value, key) => {
                return value.length * 4 + key.length;
            },
            ttl: 1000 * 60 * 60,
        });
    }

    async initialize(onnxEmbeddingModel, dtype = "fp32", localModelPath = null, modelCacheDir = null) {
        // Configure environment
        this.transformers.env.allowRemoteModels = true;
        if (localModelPath) this.transformers.env.localModelPath = localModelPath;
        if (modelCacheDir) this.transformers.env.cacheDir = modelCacheDir;

        this.tokenizer = await this.transformers.AutoTokenizer.from_pretrained(onnxEmbeddingModel);
        this.generateEmbedding = await this.transformers.pipeline("feature-extraction", onnxEmbeddingModel, {
            dtype: dtype,
        });

        this.modelName = onnxEmbeddingModel;
        this.dtype = dtype;
        this.embeddingCache.clear();

        return {
            modelName: onnxEmbeddingModel,
            dtype: dtype,
        };
    }

    async createEmbedding(text) {
        if (!this.generateEmbedding) {
            throw new Error("Model not initialized. Call initialize() first.");
        }

        const cached = this.embeddingCache.get(text);
        if (cached) {
            return cached;
        }

        const embeddings = await this.generateEmbedding(text, {
            pooling: "mean",
            normalize: true,
        });

        this.embeddingCache.set(text, embeddings.data);
        return embeddings.data;
    }

    async tokenize(text, options = {}) {
        if (!this.tokenizer) {
            throw new Error("Model not initialized. Call initialize() first.");
        }
        const tokenized = await this.tokenizer(text, options);
        return {
            size: tokenized.input_ids.size,
        };
    }

    getModelInfo() {
        return {
            modelName: this.modelName,
            dtype: this.dtype,
        };
    }
}

// --------------------------
// -- OpenAIEmbedding class --
// --------------------------
export class OpenAIEmbedding {
    constructor(openaiClient, options = {}) {
        if (!openaiClient) {
            throw new Error("OpenAI client is required in constructor");
        }

        this.openaiClient = openaiClient;
        this.modelName = null;
        this.embeddingCache = new LRUCache({
            max: 500,
            maxSize: 50_000_000,
            sizeCalculation: (value, key) => {
                return value.length * 4 + key.length;
            },
            ttl: 1000 * 60 * 60,
        });

        // Rate limiting and retry configuration
        this.maxRetries = options.maxRetries || 5;
        this.retryDelays = options.retryDelays || [10000, 30000, 60000, 60000, 60000]; // 10s, 30s, 60s
        this.concurrency = options.concurrency || 100; // Number of concurrent requests

        // Create p-limit instance for concurrency control
        this.limit = pLimit(this.concurrency);
    }

    async initialize(modelName = "text-embedding-3-small") {
        this.modelName = modelName;
        this.embeddingCache.clear();

        return {
            modelName: modelName,
            dtype: "api", // API-based, no dtype
        };
    }

    // Retry logic with exponential backoff
    async _retryOperation(operation) {
        let lastError;

        for (let attempt = 0; attempt <= this.maxRetries; attempt++) {
            try {
                return await operation();
            } catch (error) {
                lastError = error;

                // Check if it's a rate limit error
                const isRateLimit =
                    error.message &&
                    (error.message.includes("429") ||
                        error.message.includes("rate limit") ||
                        error.message.includes("quota"));

                if (isRateLimit && attempt < this.maxRetries) {
                    // Extract retry-after time from error message if available
                    const retryAfterMatch = error.message.match(/retry after (\d+) seconds/);
                    const retryAfter = retryAfterMatch ? parseInt(retryAfterMatch[1]) * 1000 : null;

                    // Use retry-after time or predefined delays
                    const delay =
                        retryAfter || this.retryDelays[attempt] || this.retryDelays[this.retryDelays.length - 1];

                    console.log(`Rate limit hit, retrying in ${delay / 1000}s (attempt ${attempt + 1}/${this.maxRetries + 1})`);
                    await new Promise((resolve) => setTimeout(resolve, delay));
                    continue;
                }

                // For other errors, use shorter exponential backoff
                if (!isRateLimit && attempt < this.maxRetries) {
                    const delay = 1000 * Math.pow(2, attempt);
                    console.log(`Request failed, retrying in ${delay}ms (attempt ${attempt + 1}/${this.maxRetries + 1})`);
                    await new Promise((resolve) => setTimeout(resolve, delay));
                    continue;
                }

                break;
            }
        }

        throw lastError;
    }

    // Simple embedding execution with retry
    async _executeEmbedding(text) {
        return this._retryOperation(async () => {
            const response = await this.openaiClient.embeddings.create({
                model: this.modelName,
                input: text,
            });
            return response.data[0].embedding;
        });
    }

    async createEmbedding(text) {
        if (!this.modelName) {
            throw new Error("Model not initialized. Call initialize() first.");
        }

        const cached = this.embeddingCache.get(text);
        if (cached) {
            return cached;
        }

        try {
            // Use p-limit to control concurrency with retry logic
            const embedding = await this.limit(() => this._executeEmbedding(text));
            this.embeddingCache.set(text, embedding);
            return embedding;
        } catch (error) {
            throw new Error(`OpenAI API error: ${error.message}`);
        }
    }

    async tokenize(text, options = {}) {
        if (!this.modelName) {
            throw new Error("Model not initialized. Call initialize() first.");
        }
        // For tokenization, we don't need to queue since it's just a local calculation
        const approximateTokenCount = Math.ceil(text.length / 4);
        return { size: approximateTokenCount };
    }

    getModelInfo() {
        return {
            modelName: this.modelName,
            dtype: "api",
        };
    }
}
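
For orientation, here is a minimal usage sketch of the two classes above. It is not part of the package file, and it makes a few assumptions: the classes are already in scope (imported from this module), the local backend is driven by the @huggingface/transformers package (which exposes env, pipeline, and AutoTokenizer as the constructor expects), the official openai SDK is installed, and the model names are only examples.

// Minimal usage sketch (assumptions noted above; model names are illustrative).
import * as transformers from "@huggingface/transformers";
import OpenAI from "openai";

// Local ONNX backend: pass the transformers namespace into the constructor,
// then load a feature-extraction model before embedding.
const local = new LocalEmbeddingModel(transformers);
await local.initialize("Xenova/all-MiniLM-L6-v2", "fp32");
const localVector = await local.createEmbedding("Semantic chunking splits text by meaning.");
console.log(local.getModelInfo(), localVector.length);

// OpenAI backend: pass an instantiated client, optionally tuning the
// concurrency limit, then pick an embedding model.
const api = new OpenAIEmbedding(new OpenAI({ apiKey: process.env.OPENAI_API_KEY }), {
    concurrency: 10, // lower than the default of 100 concurrent requests
});
await api.initialize("text-embedding-3-small");
const apiVector = await api.createEmbedding("Semantic chunking splits text by meaning.");
console.log(api.getModelInfo(), apiVector.length);

Both classes expose the same createEmbedding / tokenize / getModelInfo surface, so calling code can swap the local ONNX backend for the OpenAI API backend without changing the surrounding chunking logic.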