@just-every/ensemble
LLM provider abstraction layer with unified streaming interface
JavaScript
import { BaseModelProvider } from './base_provider.js';
import { v4 as uuidv4 } from 'uuid';
import { costTracker } from '../index.js';
import { hasEventHandler } from '../utils/event_controller.js';
export const testProviderConfig = {
streamingDelay: 50,
shouldError: false,
errorMessage: 'Simulated error from test provider',
simulateRateLimit: false,
fixedResponse: undefined,
fixedThinking: undefined,
simulateToolCall: false,
toolName: 'web_search',
toolArguments: { query: 'test query' },
tokenUsage: {
inputTokens: 100,
outputTokens: 200,
},
chunkSize: 5,
};
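// Illustrative sketch (not part of the original file): a test can mutate this
// shared config before exercising the provider, e.g. to force a canned reply
// or a simulated tool call. The values below are example choices, not defaults.
//
//   testProviderConfig.fixedResponse = 'Canned reply for this test';
//   testProviderConfig.simulateToolCall = true;
//   testProviderConfig.toolName = 'web_search';
//   testProviderConfig.streamingDelay = 0; // keep the test fast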
export function resetTestProviderConfig() {
testProviderConfig.streamingDelay = 50;
testProviderConfig.shouldError = false;
testProviderConfig.errorMessage = 'Simulated error from test provider';
testProviderConfig.simulateRateLimit = false;
testProviderConfig.fixedResponse = undefined;
testProviderConfig.fixedThinking = undefined;
testProviderConfig.simulateToolCall = false;
testProviderConfig.toolName = 'web_search';
testProviderConfig.toolArguments = { query: 'test query' };
testProviderConfig.tokenUsage = {
inputTokens: 100,
outputTokens: 200,
};
testProviderConfig.chunkSize = 5;
}
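// Usage sketch (assumes a Vitest/Jest-style runner with a global beforeEach,
// which is an assumption, not something this file requires): reset the shared
// config between tests so overrides from one test don't leak into the next.
//
//   beforeEach(() => {
//       resetTestProviderConfig();
//   });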
const sleep = (ms) => new Promise(resolve => setTimeout(resolve, ms));
export class TestProvider extends BaseModelProvider {
config;
constructor(config = testProviderConfig) {
super('test');
this.config = config;
}
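// Event order emitted below: an 'error' event (with early return) when
// simulateRateLimit or shouldError is set; otherwise 'message_start', an
// optional thinking 'message_delta', an optional 'tool_start' (when
// simulateToolCall is set and the agent exposes a matching tool), chunked
// 'message_delta' events, 'message_complete', and a final 'cost_update'
// when no global event handler is registered.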
async *createResponseStream(messages, model, agent) {
console.log(`[TestProvider] Creating response stream for model: ${model}`);
const lastUserMessage = messages.filter(m => 'role' in m && m.role === 'user').pop();
const userMessageContent = lastUserMessage && 'content' in lastUserMessage
? typeof lastUserMessage.content === 'string'
? lastUserMessage.content
: JSON.stringify(lastUserMessage.content)
: '';
const inputTokenCount = this.config.tokenUsage?.inputTokens || Math.max(50, Math.ceil(userMessageContent.length / 4));
let response;
if (this.config.simulateRateLimit) {
const rateLimitError = '429 Too Many Requests: The server is currently processing too many requests. Please try again later.';
yield {
type: 'error',
error: rateLimitError,
};
return;
}
if (this.config.shouldError) {
yield {
type: 'error',
error: this.config.errorMessage || 'Simulated error from test provider',
};
return;
}
if (this.config.fixedResponse) {
response = this.config.fixedResponse;
}
else {
response = this.generateResponse(userMessageContent);
}
const messageId = uuidv4();
yield {
type: 'message_start',
message_id: messageId,
content: '',
};
if (this.config.fixedThinking) {
yield {
type: 'message_delta',
message_id: messageId,
content: '',
thinking_content: this.config.fixedThinking,
thinking_signature: '(Simulated thinking)',
};
await sleep(this.config.streamingDelay || 50);
}
if (this.config.simulateToolCall && agent) {
const { getToolsFromAgent } = await import('../utils/agent.js');
const currentTools = getToolsFromAgent(agent);
if (currentTools) {
const toolArray = await currentTools;
if (toolArray.length > 0) {
const availableTool = toolArray.find(tool => this.config.toolName ? tool.definition.function.name === this.config.toolName : true);
if (availableTool) {
const toolCall = {
id: uuidv4(),
type: 'function',
function: {
name: availableTool.definition.function.name,
arguments: JSON.stringify(this.config.toolArguments || {
query: userMessageContent.slice(0, 50),
}),
},
};
yield {
type: 'tool_start',
tool_call: toolCall,
};
await sleep(this.config.streamingDelay || 50);
response = `I've used the ${toolCall.function.name} tool to help answer your question.\n\n${response}`;
}
}
}
}
const chunkSize = this.config.chunkSize || 5;
let position = 0;
while (position < response.length) {
const chunk = response.slice(position, position + chunkSize);
position += chunkSize;
yield {
type: 'message_delta',
message_id: messageId,
content: chunk,
order: position / chunkSize,
};
await sleep(this.config.streamingDelay || 50);
}
yield {
type: 'message_complete',
message_id: messageId,
content: response,
};
const outputTokenCount = this.config.tokenUsage?.outputTokens || Math.ceil(response.length / 4);
const calculatedUsage = costTracker.addUsage({
model,
input_tokens: inputTokenCount,
output_tokens: outputTokenCount,
});
if (!hasEventHandler()) {
yield {
type: 'cost_update',
usage: {
...calculatedUsage,
total_tokens: inputTokenCount + outputTokenCount,
},
};
}
}
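// Maps keyword patterns in the user input (greetings, 'help', 'error'/'problem',
// 'json'/'person', 'test', 'weather', or a question mark) to canned replies;
// anything else gets an echo-style fallback response.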
generateResponse(input) {
const lowercaseInput = input.toLowerCase();
if (lowercaseInput.includes('hello') || lowercaseInput.includes('hi')) {
return "Hello! I'm a test AI model. How can I help you today?";
}
else if (lowercaseInput.includes('help')) {
return "I'm here to help! What do you need assistance with?";
}
else if (lowercaseInput.includes('error') || lowercaseInput.includes('problem')) {
return "I understand you're experiencing an issue. Let me help troubleshoot the problem.";
}
else if (lowercaseInput.includes('json') || lowercaseInput.includes('person')) {
return '{"name": "John Doe", "age": 30}';
}
else if (lowercaseInput.includes('test')) {
return 'This is a test response. The test provider is working correctly!';
}
else if (lowercaseInput.includes('weather')) {
return 'The weather is sunny and 72°F.';
}
else if (lowercaseInput.includes('?')) {
return "That's an interesting question. As a test model, I'm designed to provide simulated responses for testing purposes.";
}
else {
return `I've received your message: "${input.slice(0, 50)}${input.length > 50 ? '...' : ''}". This is a simulated response from the test provider.`;
}
}
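// Produces a deterministic pseudo-embedding derived from the input's character
// codes, so the same text always yields the same vector. The dimension defaults
// to 384 unless opts.dimension is provided.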
async createEmbedding(input, model, opts) {
const generateVector = (text) => {
const dimension = opts?.dimension || 384;
const vector = new Array(dimension);
for (let i = 0; i < dimension; i++) {
const charCode = text.charCodeAt(i % text.length) || 0;
const value = Math.sin(charCode * (i + 1) * 0.01) * 0.5 + 0.5;
vector[i] = value;
}
return vector;
};
if (Array.isArray(input)) {
return input.map(text => generateVector(text));
}
else {
return generateVector(input);
}
}
}
export const testProvider = new TestProvider();
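// Usage sketch (illustrative only; not part of the original module): consume
// the singleton's streaming interface with for-await. The message shape, the
// 'test-model' name, and the demoTestProvider helper are assumptions chosen
// for the example.
async function demoTestProvider() {
    for await (const event of testProvider.createResponseStream(
        [{ role: 'user', content: 'hello' }],
        'test-model'
    )) {
        if (event.type === 'message_delta' && event.content) {
            process.stdout.write(event.content);
        }
        if (event.type === 'message_complete') {
            console.log('\n[complete]', event.content);
        }
    }
    // createEmbedding works the same way for a single string or an array.
    const vector = await testProvider.createEmbedding('hello', 'test-model');
    console.log('embedding length:', vector.length);
}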
//# sourceMappingURL=test_provider.js.map