UNPKG

@jackhua/mini-langchain

Version:

A lightweight TypeScript implementation of LangChain with cost optimization features

334 lines 12.3 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.LLMRouter = exports.TaskType = void 0;
const base_1 = require("../llms/base");
/**
 * Task types for routing.
 * Each value is the wire/string form used in routing contexts.
 */
var TaskType;
(function (TaskType) {
    TaskType["CODE_GENERATION"] = "code_generation";
    TaskType["CODE_REVIEW"] = "code_review";
    TaskType["CREATIVE_WRITING"] = "creative_writing";
    TaskType["TRANSLATION"] = "translation";
    TaskType["SUMMARIZATION"] = "summarization";
    TaskType["QA"] = "question_answering";
    TaskType["CONVERSATION"] = "conversation";
    TaskType["ANALYSIS"] = "analysis";
    TaskType["MATH"] = "math";
    TaskType["REASONING"] = "reasoning";
})(TaskType || (exports.TaskType = TaskType = {}));
/**
 * Auto-Adaptive LLM Router.
 * Automatically selects the best LLM based on task analysis (detected task
 * type + complexity), provider capability/quality matching, a weighted
 * quality/speed/cost score, and observed error rates.
 */
class LLMRouter {
    /**
     * @param {object} config - Router configuration. Expected shape (from usage
     *   below): `llms` (map of name -> { llm, capabilities, qualityScore,
     *   speedScore, costPerToken }), optional `defaultLLM`,
     *   `enableLoadBalancing`, `costThreshold`.
     */
    constructor(config) {
        this.usageStats = new Map();
        this.config = config;
        this.initializeStats();
    }
    /**
     * Initialize per-provider usage statistics to zeroed counters.
     */
    initializeStats() {
        for (const llmName of Object.keys(this.config.llms)) {
            this.usageStats.set(llmName, {
                calls: 0,
                errors: 0,
                totalCost: 0,
                avgResponseTime: 0,
            });
        }
    }
    /**
     * Route to the best LLM based on prompt and context.
     * @param {string} prompt - The user prompt to analyze.
     * @param {object} [context] - Optional hints: `taskType`, `complexity`,
     *   `requireSpeed`, `requireQuality`, `previousErrors`.
     * @returns {Promise<object>} The selected LLM instance.
     */
    async route(prompt, context) {
        // Explicit context hints win over heuristic detection.
        const detectedTaskType = context?.taskType || (await this.detectTaskType(prompt));
        const complexity = context?.complexity || (await this.assessComplexity(prompt));
        // Get suitable LLMs for this task, then score and rank them.
        const suitableLLMs = this.filterSuitableLLMs(detectedTaskType, complexity);
        const scoredLLMs = this.scoreLLMs(suitableLLMs, context || {});
        return this.selectBestLLM(scoredLLMs, context);
    }
    /**
     * Detect task type from the prompt via keyword heuristics.
     * Checks are ordered; the first matching category wins, defaulting to QA.
     * @param {string} prompt
     * @returns {Promise<string>} A TaskType value.
     */
    async detectTaskType(prompt) {
        const lowerPrompt = prompt.toLowerCase();
        // Code-related keywords
        if (this.containsKeywords(lowerPrompt, [
            'code', 'function', 'implement', 'debug', 'program',
            'class', 'method', 'variable', 'syntax', 'algorithm',
        ])) {
            return TaskType.CODE_GENERATION;
        }
        // Creative writing keywords
        if (this.containsKeywords(lowerPrompt, [
            'story', 'poem', 'creative', 'imagine', 'narrative',
            'character', 'plot', 'describe', 'scene',
        ])) {
            return TaskType.CREATIVE_WRITING;
        }
        // Translation keywords
        if (this.containsKeywords(lowerPrompt, [
            'translate', 'translation', 'language', 'from', 'to',
            'english', 'spanish', 'french', 'chinese',
        ])) {
            return TaskType.TRANSLATION;
        }
        // Summarization keywords
        if (this.containsKeywords(lowerPrompt, [
            'summarize', 'summary', 'brief', 'key points', 'main ideas',
            'tldr', 'overview',
        ])) {
            return TaskType.SUMMARIZATION;
        }
        // Math keywords
        if (this.containsKeywords(lowerPrompt, [
            'calculate', 'solve', 'equation', 'math', 'algebra',
            'calculus', 'statistics', 'probability',
        ])) {
            return TaskType.MATH;
        }
        // Analysis keywords
        if (this.containsKeywords(lowerPrompt, [
            'analyze', 'analysis', 'evaluate', 'assess', 'compare',
            'contrast', 'examine',
        ])) {
            return TaskType.ANALYSIS;
        }
        // Default to QA
        return TaskType.QA;
    }
    /**
     * Check whether text contains any of the keywords as whole words/phrases.
     *
     * FIX: the previous implementation used substring `String.includes`, so
     * short keywords such as "to", "from", or "and" matched inside unrelated
     * words ("story" contains "to", "understand" contains "and"), which
     * misrouted nearly every prompt containing such fragments. Whole-word
     * matching via `\b` boundaries avoids that; multi-word keywords like
     * "key points" still match as phrases.
     * @param {string} text - Text to search (expected lowercase).
     * @param {string[]} keywords - Keywords/phrases to look for.
     * @returns {boolean}
     */
    containsKeywords(text, keywords) {
        return keywords.some((keyword) => {
            // Escape regex metacharacters so keywords are matched literally.
            const escaped = keyword.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
            return new RegExp(`\\b${escaped}\\b`).test(text);
        });
    }
    /**
     * Assess complexity of the prompt from length, structure, and wording.
     * @param {string} prompt
     * @returns {Promise<'low'|'medium'|'high'>}
     */
    async assessComplexity(prompt) {
        // Split on any whitespace run (handles newlines and double spaces,
        // which the former split(' ') miscounted) and drop empty tokens.
        const wordCount = prompt.split(/\s+/).filter(Boolean).length;
        // Whole-word check: the former substring test flagged words like
        // "understand" (contains "and") as multi-part requests.
        const hasMultipleParts = /\b(?:and|then)\b/.test(prompt);
        const hasComplexRequirements = this.containsKeywords(prompt.toLowerCase(), [
            'detailed', 'comprehensive', 'in-depth', 'advanced', 'complex',
            'sophisticated', 'elaborate',
        ]);
        if (hasComplexRequirements || wordCount > 100 || hasMultipleParts) {
            return 'high';
        }
        if (wordCount > 50) {
            return 'medium';
        }
        return 'low';
    }
    /**
     * Filter configured LLMs down to those suitable for the task.
     * @param {string} taskType - A TaskType value.
     * @param {'low'|'medium'|'high'} complexity
     * @returns {Array<[string, object]>} `[name, config]` pairs that both have
     *   a required capability and meet the quality bar for the complexity.
     */
    filterSuitableLLMs(taskType, complexity) {
        const suitable = [];
        for (const [name, config] of Object.entries(this.config.llms)) {
            const hasCapability = this.checkCapability(config.capabilities, taskType);
            const canHandleComplexity = this.checkComplexityMatch(config.qualityScore, complexity);
            if (hasCapability && canHandleComplexity) {
                suitable.push([name, config]);
            }
        }
        return suitable;
    }
    /**
     * Check whether an LLM's declared capabilities cover the task.
     * A match is a case-insensitive substring hit on any required capability.
     * @param {string[]} capabilities - Capability tags declared for the LLM.
     * @param {string} taskType - A TaskType value.
     * @returns {boolean}
     */
    checkCapability(capabilities, taskType) {
        const requiredCapabilities = {
            [TaskType.CODE_GENERATION]: ['code', 'programming'],
            [TaskType.CODE_REVIEW]: ['code', 'analysis'],
            [TaskType.CREATIVE_WRITING]: ['creative', 'writing'],
            [TaskType.TRANSLATION]: ['translation', 'multilingual'],
            [TaskType.SUMMARIZATION]: ['summarization', 'text'],
            [TaskType.QA]: ['qa', 'general'],
            [TaskType.CONVERSATION]: ['conversation', 'chat'],
            [TaskType.ANALYSIS]: ['analysis', 'reasoning'],
            [TaskType.MATH]: ['math', 'calculation'],
            [TaskType.REASONING]: ['reasoning', 'logic'],
        };
        const required = requiredCapabilities[taskType] || ['general'];
        return required.some((req) => capabilities.some((cap) => cap.toLowerCase().includes(req)));
    }
    /**
     * Check whether an LLM's quality score can handle the given complexity.
     *
     * FIX: the switch previously had no default, so a caller-supplied,
     * unrecognized complexity value returned `undefined` (falsy) and silently
     * filtered out EVERY provider. Unknown complexity now permits the LLM,
     * matching the permissive 'low' behavior.
     * @param {number} qualityScore
     * @param {'low'|'medium'|'high'|string} complexity
     * @returns {boolean}
     */
    checkComplexityMatch(qualityScore, complexity) {
        switch (complexity) {
            case 'medium':
                return qualityScore >= 5;
            case 'high':
                return qualityScore >= 7;
            case 'low':
            default:
                // Any LLM can handle low (or unrecognized) complexity.
                return true;
        }
    }
    /**
     * Score LLMs on quality, speed, cost, and observed reliability.
     * Weights: quality 40%, speed 30% (doubled when `requireSpeed`),
     * cost-efficiency 30% (halved when `requireQuality`); plus an error-rate
     * bonus and a -20 penalty for providers listed in `previousErrors`.
     * @param {Array<[string, object]>} llms - `[name, config]` pairs.
     * @param {object} context - Routing hints.
     * @returns {Array<{name: string, score: number, llm: object}>} Sorted
     *   descending by score.
     */
    scoreLLMs(llms, context) {
        return llms
            .map(([name, config]) => {
                let score = 0;
                // Quality score (weight: 40%)
                score += config.qualityScore * 4;
                // Speed score (weight: 30%; doubled when speed is required)
                score += config.speedScore * (context.requireSpeed ? 6 : 3);
                // Cost efficiency (weight: 30%; reduced when quality is required)
                score += (10 - config.costPerToken * 10) * (context.requireQuality ? 1.5 : 3);
                // Bonus for a low observed error rate.
                const stats = this.usageStats.get(name);
                if (stats && stats.calls > 0) {
                    const errorRate = stats.errors / stats.calls;
                    score += (1 - errorRate) * 10;
                }
                // Penalty for providers that already failed on this prompt.
                if (context.previousErrors?.includes(name)) {
                    score -= 20;
                }
                return { name, score, llm: config.llm };
            })
            .sort((a, b) => b.score - a.score);
    }
    /**
     * Select the best LLM from the scored candidates.
     * Falls back to `config.defaultLLM` (or the first configured LLM) when no
     * candidate survived filtering. With load balancing enabled, picks between
     * the top two candidates with probability proportional to their scores.
     * @param {Array<{name: string, score: number, llm: object}>} scoredLLMs
     * @param {object} [context] - Unused here; kept for interface stability.
     * @returns {object} The chosen LLM instance.
     */
    selectBestLLM(scoredLLMs, context) {
        if (scoredLLMs.length === 0) {
            const defaultName = this.config.defaultLLM || Object.keys(this.config.llms)[0];
            return this.config.llms[defaultName].llm;
        }
        if (this.config.enableLoadBalancing && scoredLLMs.length > 1) {
            // Weighted random pick among the top two candidates.
            const topTwo = scoredLLMs.slice(0, 2);
            const totalScore = topTwo.reduce((sum, llm) => sum + llm.score, 0);
            // FIX: guard against a non-positive total (possible after the
            // previousErrors penalty), where the weighted draw is meaningless.
            if (totalScore > 0) {
                const random = Math.random() * totalScore;
                let accumulated = 0;
                for (const llm of topTwo) {
                    accumulated += llm.score;
                    if (random <= accumulated) {
                        return llm.llm;
                    }
                }
            }
        }
        // Return the highest-scoring LLM.
        return scoredLLMs[0].llm;
    }
    /**
     * Update usage statistics after a call.
     * @param {string} llmName - Provider name (a key of `config.llms`).
     * @param {boolean} success - Whether the call succeeded.
     * @param {number} cost - Estimated cost of the call.
     * @param {number} responseTime - Wall-clock duration in milliseconds.
     */
    updateStats(llmName, success, cost, responseTime) {
        const stats = this.usageStats.get(llmName);
        if (!stats)
            return;
        stats.calls++;
        if (!success)
            stats.errors++;
        stats.totalCost += cost;
        // Incremental running average of response time.
        stats.avgResponseTime =
            (stats.avgResponseTime * (stats.calls - 1) + responseTime) / stats.calls;
    }
    /**
     * Get routing statistics per provider, with derived `errorRate` and
     * `avgCostPerCall` fields (0 when the provider has no calls yet).
     * @returns {object} Map of provider name -> stats snapshot.
     */
    getStats() {
        const stats = {};
        for (const [name, data] of this.usageStats.entries()) {
            stats[name] = {
                ...data,
                errorRate: data.calls > 0 ? data.errors / data.calls : 0,
                avgCostPerCall: data.calls > 0 ? data.totalCost / data.calls : 0,
            };
        }
        return stats;
    }
    /**
     * Create a routed LLM that automatically selects the best provider on
     * every `generate`/`stream` call and records usage stats for `generate`.
     * NOTE(review): `stream` does not record stats or pass routing context —
     * presumably intentional (cost is unknown mid-stream); confirm.
     * @returns {object} A BaseLLM-compatible instance.
     */
    createRoutedLLM() {
        const router = this;
        // Anonymous subclass so each call delegates through the router.
        class RoutedLLM extends base_1.BaseLLM {
            async generate(messages, options) {
                const prompt = messages.map((m) => m.content).join('\n');
                const context = {
                    // A tight timeout is treated as a speed requirement.
                    requireSpeed: options?.timeout ? options.timeout < 5000 : false,
                    maxCost: router.config.costThreshold,
                };
                const llm = await router.route(prompt, context);
                const startTime = Date.now();
                try {
                    const result = await llm.generate(messages, options);
                    const responseTime = Date.now() - startTime;
                    // Reverse-lookup the provider name to record stats.
                    const llmName = Object.entries(router.config.llms)
                        .find(([_, config]) => config.llm === llm)?.[0];
                    if (llmName) {
                        const cost = router.estimateCost(result.text, llmName);
                        router.updateStats(llmName, true, cost, responseTime);
                    }
                    return result;
                }
                catch (error) {
                    const llmName = Object.entries(router.config.llms)
                        .find(([_, config]) => config.llm === llm)?.[0];
                    if (llmName) {
                        router.updateStats(llmName, false, 0, Date.now() - startTime);
                    }
                    throw error;
                }
            }
            async *stream(messages, options) {
                const prompt = messages.map((m) => m.content).join('\n');
                const llm = await router.route(prompt);
                yield* llm.stream(messages, options);
            }
            get identifyingParams() {
                return { type: 'routed', providers: Object.keys(router.config.llms) };
            }
            get llmType() {
                return 'routed';
            }
        }
        return new RoutedLLM();
    }
    /**
     * Estimate cost for a response using a rough 4-characters-per-token rule.
     * @param {string} text - Generated text.
     * @param {string} llmName - Provider name (a key of `config.llms`).
     * @returns {number} Estimated cost; 0 for an unknown provider.
     */
    estimateCost(text, llmName) {
        const config = this.config.llms[llmName];
        if (!config)
            return 0;
        const tokens = Math.ceil(text.length / 4);
        return tokens * config.costPerToken;
    }
}
exports.LLMRouter = LLMRouter;
//# sourceMappingURL=router.js.map