/**
 * @jackhua/mini-langchain
 * A lightweight TypeScript implementation of LangChain with cost optimization features.
 * Compiled CommonJS output (LLM router module).
 */
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.LLMRouter = exports.TaskType = void 0;
const base_1 = require("../llms/base");
/**
* Task types for routing
*/
var TaskType;
(function (T) {
    // Task categories the router can classify prompts into.
    // All members are string-valued, so no reverse mapping is generated.
    Object.assign(T, {
        CODE_GENERATION: "code_generation",
        CODE_REVIEW: "code_review",
        CREATIVE_WRITING: "creative_writing",
        TRANSLATION: "translation",
        SUMMARIZATION: "summarization",
        QA: "question_answering",
        CONVERSATION: "conversation",
        ANALYSIS: "analysis",
        MATH: "math",
        REASONING: "reasoning"
    });
})(TaskType || (exports.TaskType = TaskType = {}));
/**
* Auto-Adaptive LLM Router
* Automatically selects the best LLM based on task analysis
*/
class LLMRouter {
    /**
     * @param {object} config - Router configuration.
     * @param {object} config.llms - Map of name -> { llm, capabilities, qualityScore, speedScore, costPerToken }.
     * @param {string} [config.defaultLLM] - Name used when no candidate survives filtering.
     * @param {boolean} [config.enableLoadBalancing] - Weighted random pick among the top two candidates.
     * @param {number} [config.costThreshold] - Forwarded to routed calls as context.maxCost.
     */
    constructor(config) {
        this.usageStats = new Map();
        this.config = config;
        this.initializeStats();
    }
    /**
     * Initialize zeroed usage statistics for every configured LLM.
     */
    initializeStats() {
        for (const llmName of Object.keys(this.config.llms)) {
            this.usageStats.set(llmName, {
                calls: 0,
                errors: 0,
                totalCost: 0,
                avgResponseTime: 0
            });
        }
    }
    /**
     * Route to the best LLM based on prompt and context.
     * @param {string} prompt - Prompt to route.
     * @param {object} [context] - Optional hints: taskType, complexity,
     *   requireSpeed, requireQuality, previousErrors, maxCost.
     * @returns {Promise<object>} The selected LLM instance.
     */
    async route(prompt, context) {
        // Analyze the prompt unless the caller already supplied hints.
        const detectedTaskType = context?.taskType || await this.detectTaskType(prompt);
        const complexity = context?.complexity || await this.assessComplexity(prompt);
        // Keep only LLMs capable of this task at this complexity.
        const suitableLLMs = this.filterSuitableLLMs(detectedTaskType, complexity);
        // Score, rank, and pick.
        const scoredLLMs = this.scoreLLMs(suitableLLMs, context || {});
        return this.selectBestLLM(scoredLLMs, context);
    }
    /**
     * Detect task type from the prompt via whole-word keyword heuristics.
     * Categories are checked in order; the first match wins, defaulting to QA.
     * @param {string} prompt
     * @returns {Promise<string>} A TaskType value.
     */
    async detectTaskType(prompt) {
        const lowerPrompt = prompt.toLowerCase();
        // Code-related keywords
        if (this.containsKeywords(lowerPrompt, [
            'code', 'function', 'implement', 'debug', 'program',
            'class', 'method', 'variable', 'syntax', 'algorithm'
        ])) {
            return TaskType.CODE_GENERATION;
        }
        // Creative writing keywords
        if (this.containsKeywords(lowerPrompt, [
            'story', 'poem', 'creative', 'imagine', 'narrative',
            'character', 'plot', 'describe', 'scene'
        ])) {
            return TaskType.CREATIVE_WRITING;
        }
        // Translation keywords.
        // BUG FIX: the former keywords 'from' and 'to' matched (as substrings)
        // nearly every prompt — e.g. "history" contains "to" — mislabeling most
        // prompts as TRANSLATION. They are removed; matching is whole-word now.
        if (this.containsKeywords(lowerPrompt, [
            'translate', 'translation', 'language',
            'english', 'spanish', 'french', 'chinese'
        ])) {
            return TaskType.TRANSLATION;
        }
        // Summarization keywords
        if (this.containsKeywords(lowerPrompt, [
            'summarize', 'summary', 'brief', 'key points',
            'main ideas', 'tldr', 'overview'
        ])) {
            return TaskType.SUMMARIZATION;
        }
        // Math keywords
        if (this.containsKeywords(lowerPrompt, [
            'calculate', 'solve', 'equation', 'math', 'algebra',
            'calculus', 'statistics', 'probability'
        ])) {
            return TaskType.MATH;
        }
        // Analysis keywords
        if (this.containsKeywords(lowerPrompt, [
            'analyze', 'analysis', 'evaluate', 'assess',
            'compare', 'contrast', 'examine'
        ])) {
            return TaskType.ANALYSIS;
        }
        // Default to QA
        return TaskType.QA;
    }
    /**
     * Check whether text contains any keyword as a whole word/phrase.
     * BUG FIX: the original used substring matching, so "command" matched
     * 'and' and "history" matched 'to'. Word boundaries prevent that; regex
     * metacharacters in keywords are escaped before compiling.
     * @param {string} text - Lowercased text to search.
     * @param {string[]} keywords - Words or phrases to look for.
     * @returns {boolean}
     */
    containsKeywords(text, keywords) {
        return keywords.some(keyword => {
            const escaped = keyword.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
            return new RegExp(`\\b${escaped}\\b`).test(text);
        });
    }
    /**
     * Assess prompt complexity.
     * @param {string} prompt
     * @returns {Promise<'low'|'medium'|'high'>}
     */
    async assessComplexity(prompt) {
        // Split on whitespace runs; filter drops empty tokens so "" counts as 0.
        const wordCount = prompt.split(/\s+/).filter(Boolean).length;
        const lowerPrompt = prompt.toLowerCase();
        // Whole-word match so "understand" no longer triggers on 'and'.
        const hasMultipleParts = this.containsKeywords(lowerPrompt, ['and', 'then']);
        const hasComplexRequirements = this.containsKeywords(lowerPrompt, [
            'detailed', 'comprehensive', 'in-depth', 'advanced',
            'complex', 'sophisticated', 'elaborate'
        ]);
        if (hasComplexRequirements || wordCount > 100 || hasMultipleParts) {
            return 'high';
        }
        if (wordCount > 50) {
            return 'medium';
        }
        return 'low';
    }
    /**
     * Filter LLMs suitable for the task type and complexity.
     * @returns {Array<[string, object]>} [name, config] pairs that qualify.
     */
    filterSuitableLLMs(taskType, complexity) {
        const suitable = [];
        for (const [name, config] of Object.entries(this.config.llms)) {
            // Must advertise a capability matching the task...
            const hasCapability = this.checkCapability(config.capabilities, taskType);
            // ...and have a quality score high enough for the complexity.
            const canHandleComplexity = this.checkComplexityMatch(config.qualityScore, complexity);
            if (hasCapability && canHandleComplexity) {
                suitable.push([name, config]);
            }
        }
        return suitable;
    }
    /**
     * Check if a capability list covers the given task type.
     * Matching is case-insensitive substring against the required tags.
     * @param {string[]} capabilities
     * @param {string} taskType - A TaskType value.
     * @returns {boolean}
     */
    checkCapability(capabilities, taskType) {
        const requiredCapabilities = {
            [TaskType.CODE_GENERATION]: ['code', 'programming'],
            [TaskType.CODE_REVIEW]: ['code', 'analysis'],
            [TaskType.CREATIVE_WRITING]: ['creative', 'writing'],
            [TaskType.TRANSLATION]: ['translation', 'multilingual'],
            [TaskType.SUMMARIZATION]: ['summarization', 'text'],
            [TaskType.QA]: ['qa', 'general'],
            [TaskType.CONVERSATION]: ['conversation', 'chat'],
            [TaskType.ANALYSIS]: ['analysis', 'reasoning'],
            [TaskType.MATH]: ['math', 'calculation'],
            [TaskType.REASONING]: ['reasoning', 'logic']
        };
        const required = requiredCapabilities[taskType] || ['general'];
        return required.some(req => capabilities.some(cap => cap.toLowerCase().includes(req)));
    }
    /**
     * Check if an LLM's quality score is sufficient for the complexity.
     * @param {number} qualityScore
     * @param {string} complexity - 'low' | 'medium' | 'high'.
     * @returns {boolean}
     */
    checkComplexityMatch(qualityScore, complexity) {
        switch (complexity) {
            case 'medium':
                return qualityScore >= 5;
            case 'high':
                return qualityScore >= 7;
            default:
                // 'low' — any LLM qualifies.
                // BUG FIX: an unrecognized value previously fell off the switch
                // and returned undefined, silently filtering out every LLM;
                // treat unknown complexity as low instead.
                return true;
        }
    }
    /**
     * Score LLMs on quality, speed, cost, and historical reliability, then
     * sort best-first.
     * @param {Array<[string, object]>} llms - [name, config] pairs.
     * @param {object} context - Routing hints (requireSpeed, requireQuality, previousErrors).
     * @returns {Array<{name: string, score: number, llm: object}>}
     */
    scoreLLMs(llms, context) {
        return llms.map(([name, config]) => {
            let score = 0;
            // Quality score (weight: 40%)
            score += config.qualityScore * 4;
            // Speed score (weight: 30%; doubled when the caller needs speed)
            if (context.requireSpeed) {
                score += config.speedScore * 6;
            }
            else {
                score += config.speedScore * 3;
            }
            // Cost efficiency (weight: 30%; halved when quality matters most)
            if (context.requireQuality) {
                score += (10 - config.costPerToken * 10) * 1.5;
            }
            else {
                score += (10 - config.costPerToken * 10) * 3;
            }
            // Bonus for a low historical error rate.
            const stats = this.usageStats.get(name);
            if (stats && stats.calls > 0) {
                const errorRate = stats.errors / stats.calls;
                score += (1 - errorRate) * 10;
            }
            // Penalty if this provider already failed on this prompt.
            if (context.previousErrors?.includes(name)) {
                score -= 20;
            }
            return { name, score, llm: config.llm };
        }).sort((a, b) => b.score - a.score);
    }
    /**
     * Select the best LLM from scored options, optionally load-balancing
     * between the top two by score-weighted random choice.
     * @param {Array<{name: string, score: number, llm: object}>} scoredLLMs - Sorted best-first.
     * @param {object} [context]
     * @returns {object} The chosen LLM instance.
     * @throws {Error} If no LLM survives filtering and none is configured to fall back to.
     */
    selectBestLLM(scoredLLMs, context) {
        if (scoredLLMs.length === 0) {
            // Fall back to the configured default (or the first registered LLM).
            const defaultName = this.config.defaultLLM || Object.keys(this.config.llms)[0];
            const fallback = this.config.llms[defaultName];
            if (!fallback) {
                // BUG FIX: previously this dereferenced undefined and threw an
                // opaque TypeError when config.llms was empty.
                throw new Error('LLMRouter: no LLMs configured to route to');
            }
            return fallback.llm;
        }
        // Apply load balancing if enabled.
        if (this.config.enableLoadBalancing && scoredLLMs.length > 1) {
            // Score-weighted random choice between the top two candidates.
            const topTwo = scoredLLMs.slice(0, 2);
            const totalScore = topTwo.reduce((sum, llm) => sum + llm.score, 0);
            const random = Math.random() * totalScore;
            let accumulated = 0;
            for (const llm of topTwo) {
                accumulated += llm.score;
                if (random <= accumulated) {
                    return llm.llm;
                }
            }
            // Falls through to the top choice on floating-point edge cases.
        }
        // Return the highest-scoring LLM.
        return scoredLLMs[0].llm;
    }
    /**
     * Update usage statistics after a call.
     * @param {string} llmName - Registered LLM name; unknown names are ignored.
     * @param {boolean} success - Whether the call succeeded.
     * @param {number} cost - Estimated cost of the call.
     * @param {number} responseTime - Elapsed time in milliseconds.
     */
    updateStats(llmName, success, cost, responseTime) {
        const stats = this.usageStats.get(llmName);
        if (!stats)
            return;
        stats.calls++;
        if (!success)
            stats.errors++;
        stats.totalCost += cost;
        // Incremental running mean of response time.
        stats.avgResponseTime =
            (stats.avgResponseTime * (stats.calls - 1) + responseTime) / stats.calls;
    }
    /**
     * Get per-LLM routing statistics with derived errorRate and avgCostPerCall.
     * @returns {object} Map of name -> stats snapshot.
     */
    getStats() {
        const stats = {};
        for (const [name, data] of this.usageStats.entries()) {
            stats[name] = {
                ...data,
                errorRate: data.calls > 0 ? data.errors / data.calls : 0,
                avgCostPerCall: data.calls > 0 ? data.totalCost / data.calls : 0
            };
        }
        return stats;
    }
    /**
     * Reverse-lookup the registered name of an LLM instance.
     * @param {object} llm - An LLM instance from config.llms.
     * @returns {string|undefined} The name, or undefined if unregistered.
     */
    findLLMName(llm) {
        return Object.entries(this.config.llms)
            .find(([, config]) => config.llm === llm)?.[0];
    }
    /**
     * Create a routed LLM that automatically selects the best provider per call
     * and records usage statistics for generate().
     * @returns {object} A BaseLLM subclass instance.
     */
    createRoutedLLM() {
        const router = this;
        // Anonymous subclass delegating to the router on every call.
        class RoutedLLM extends base_1.BaseLLM {
            async generate(messages, options) {
                const prompt = messages.map(m => m.content).join('\n');
                const context = {
                    // A timeout under 5s is treated as a speed requirement.
                    requireSpeed: options?.timeout ? options.timeout < 5000 : false,
                    maxCost: router.config.costThreshold
                };
                const llm = await router.route(prompt, context);
                const startTime = Date.now();
                try {
                    const result = await llm.generate(messages, options);
                    const responseTime = Date.now() - startTime;
                    // Record success stats for the provider that served the call.
                    const llmName = router.findLLMName(llm);
                    if (llmName) {
                        const cost = router.estimateCost(result.text, llmName);
                        router.updateStats(llmName, true, cost, responseTime);
                    }
                    return result;
                }
                catch (error) {
                    const llmName = router.findLLMName(llm);
                    if (llmName) {
                        router.updateStats(llmName, false, 0, Date.now() - startTime);
                    }
                    throw error;
                }
            }
            async *stream(messages, options) {
                const prompt = messages.map(m => m.content).join('\n');
                // NOTE(review): streaming calls are routed but not recorded in
                // usageStats — only generate() updates them. Confirm intended.
                const llm = await router.route(prompt);
                yield* llm.stream(messages, options);
            }
            get identifyingParams() {
                return { type: 'routed', providers: Object.keys(router.config.llms) };
            }
            get llmType() {
                return 'routed';
            }
        }
        return new RoutedLLM();
    }
    /**
     * Estimate the cost of a generated response for a named LLM.
     * Uses a rough ~4 characters-per-token heuristic.
     * @param {string} text - Generated text.
     * @param {string} llmName - Registered LLM name; unknown names cost 0.
     * @returns {number}
     */
    estimateCost(text, llmName) {
        const config = this.config.llms[llmName];
        if (!config)
            return 0;
        const tokens = Math.ceil(text.length / 4);
        return tokens * config.costPerToken;
    }
}
exports.LLMRouter = LLMRouter;
//# sourceMappingURL=router.js.map