UNPKG

@unified-llm/core

Version:

Unified LLM interface (in-memory).

460 lines 21.2 kB
import Anthropic from '@anthropic-ai/sdk';
import BaseProvider from '../base-provider';
import { validateChatRequest } from '../../utils/validation';
import { ResponseFormat } from '../../response-format';

/**
 * Anthropic implementation of the unified LLM provider.
 *
 * Converts unified chat requests to Anthropic Messages API parameters and
 * Anthropic responses back to the unified shape. Custom functions passed to the
 * constructor (forwarded to BaseProvider, read back here as `this.tools`) are
 * executed automatically whenever the model stops with `stop_reason === 'tool_use'`.
 */
export class AnthropicProvider extends BaseProvider {
  /**
   * @param {object} options
   * @param {string} options.apiKey - API key handed to the Anthropic SDK client.
   * @param {string} [options.model] - Default model; `request.model` takes precedence when set.
   * @param {Array<object>} [options.tools] - Custom functions shaped like
   *   `{ function: { name, description, parameters }, handler, args }`.
   */
  constructor({ apiKey, model, tools }) {
    super({ model, tools });
    this.client = new Anthropic({ apiKey });
  }

  /**
   * Sends a non-streaming chat request.
   *
   * While the model asks for tools registered in `this.tools`, their handlers are
   * run and the results are sent back in a follow-up request. The loop ends when
   * the model stops for another reason, or when it requests only tools this
   * provider has no handler for (that response is returned so the caller can
   * handle it). NOTE(review): there is no cap on the number of tool round trips.
   *
   * @param {object} request - Unified chat request; validated by validateChatRequest.
   * @returns {Promise<object>} Unified response from convertFromAnthropicFormat.
   * @throws {object} The plain object built by handleError for failures after
   *   validation; validation errors propagate unchanged.
   */
  async chat(request) {
    validateChatRequest(request);
    try {
      const anthropicRequest = await this.convertToAnthropicFormat(request);
      let messages = [...anthropicRequest.messages];
      let response = await this.client.messages.create(anthropicRequest);

      while (response.stop_reason === 'tool_use' && this.tools) {
        const toolResults = await this.executeAnthropicToolCalls(
          response.content.filter((block) => block.type === 'tool_use'),
        );
        if (toolResults === null) {
          // None of the requested tools is registered here: return the tool_use response as-is.
          break;
        }
        // Re-send the conversation with the assistant's tool calls and our results appended.
        messages = [
          ...messages,
          { role: 'assistant', content: response.content },
          { role: 'user', content: toolResults },
        ];
        response = await this.client.messages.create({ ...anthropicRequest, messages });
      }

      return this.convertFromAnthropicFormat(response);
    } catch (error) {
      throw this.handleError(error);
    }
  }

  /**
   * Streams a chat response as unified chunks.
   *
   * Each API response is fully buffered before anything is yielded, so tool
   * calls can be detected and executed first; only the final response (the one
   * that is not followed by a tool round trip) is emitted. Unlike chat(), errors
   * are not passed through handleError.
   *
   * @param {object} request - Unified chat request; validated by validateChatRequest.
   * @yields {object} Unified stream chunks carrying text.
   */
  async *stream(request) {
    validateChatRequest(request);
    const anthropicRequest = await this.convertToAnthropicFormat(request);
    // The model the request is actually sent with (request.model, else this.model).
    const model = anthropicRequest.model;
    let messages = [...anthropicRequest.messages];

    for (;;) {
      const stream = await this.client.messages.create({
        ...anthropicRequest,
        messages,
        stream: true,
      });
      const { chunks, contentBlocks, stopReason, hasToolUse } = await this.collectAnthropicStream(stream);

      if (stopReason === 'tool_use' && this.tools && hasToolUse) {
        const toolResults = await this.executeAnthropicToolCalls(
          contentBlocks.filter((block) => block.type === 'tool_use'),
        );
        if (toolResults !== null) {
          messages = [
            ...messages,
            { role: 'assistant', content: contentBlocks },
            { role: 'user', content: toolResults },
          ];
          // Ask again with the tool results; nothing from this response is yielded.
          continue;
        }
      }

      if (!hasToolUse) {
        // Plain text response: replay the buffered text deltas.
        for (const chunk of chunks) {
          if (chunk.type === 'content_block_delta' && chunk.delta.type === 'text_delta') {
            yield this.convertStreamChunk(chunk, model);
          }
        }
      } else {
        // Tool use was requested but not executed here (no registered handler, or the
        // stop reason was not 'tool_use'): emit the accumulated text in small synthetic pieces.
        const chunkSize = 20;
        for (const block of contentBlocks) {
          if (block.type !== 'text' || !block.text) continue;
          for (let start = 0; start < block.text.length; start += chunkSize) {
            const chunkText = block.text.slice(start, start + chunkSize);
            yield {
              id: this.generateMessageId(),
              model,
              provider: 'anthropic',
              message: {
                id: this.generateMessageId(),
                role: 'assistant',
                content: [{ type: 'text', text: chunkText }],
                createdAt: new Date(),
              },
              text: chunkText,
              createdAt: new Date(),
              rawResponse: null,
            };
          }
        }
      }
      return;
    }
  }

  /**
   * Drains an Anthropic streaming response and rebuilds its content blocks.
   *
   * Text deltas are appended to their text block. `input_json_delta` fragments are
   * buffered per block index and parsed into the block's `input` when the block
   * stops; JSON that fails to parse becomes `{}`.
   *
   * @param {AsyncIterable<object>} stream - Events from `messages.create({ stream: true })`.
   * @returns {Promise<{chunks: object[], contentBlocks: object[], stopReason: (string|null), hasToolUse: boolean}>}
   */
  async collectAnthropicStream(stream) {
    const chunks = [];
    const contentBlocks = [];
    const rawToolInputs = new Map();
    let stopReason = null;
    let hasToolUse = false;

    for await (const chunk of stream) {
      chunks.push(chunk);
      // Block events are looked up by `index`, assuming blocks start in index order.
      switch (chunk.type) {
        case 'content_block_start':
          contentBlocks.push({ ...chunk.content_block });
          if (chunk.content_block.type === 'tool_use') {
            hasToolUse = true;
          }
          break;
        case 'content_block_delta':
          if (chunk.index < contentBlocks.length) {
            const block = contentBlocks[chunk.index];
            if (block.type === 'text' && chunk.delta.type === 'text_delta') {
              block.text = (block.text || '') + chunk.delta.text;
            } else if (block.type === 'tool_use' && chunk.delta.type === 'input_json_delta') {
              rawToolInputs.set(chunk.index, (rawToolInputs.get(chunk.index) || '') + chunk.delta.partial_json);
            }
          }
          break;
        case 'content_block_stop':
          if (chunk.index < contentBlocks.length) {
            const block = contentBlocks[chunk.index];
            const rawInput = rawToolInputs.get(chunk.index);
            if (block.type === 'tool_use' && rawInput) {
              try {
                block.input = JSON.parse(rawInput);
                rawToolInputs.delete(chunk.index);
              } catch (_e) {
                // Best effort: malformed tool input is passed to the handler as no arguments.
                block.input = {};
              }
            }
          }
          break;
        case 'message_delta':
          if (chunk.delta.stop_reason) {
            stopReason = chunk.delta.stop_reason;
          }
          break;
        default:
          break;
      }
    }

    return { chunks, contentBlocks, stopReason, hasToolUse };
  }

  /**
   * Runs the registered handler for each `tool_use` block.
   *
   * Every block receives exactly one `tool_result`, because a follow-up request
   * must answer every tool_use id of the preceding assistant turn. A block naming
   * an unregistered tool gets an `is_error` result. If no block matches a
   * registered tool at all, returns null so the caller can hand the response
   * back unchanged (those tools presumably belong to the caller).
   *
   * @param {Array<object>} toolUseBlocks - Anthropic `tool_use` content blocks.
   * @returns {Promise<Array<object>|null>} `tool_result` blocks, or null when nothing was runnable.
   */
  async executeAnthropicToolCalls(toolUseBlocks) {
    const registered = this.tools || [];
    const results = [];
    let handledAny = false;

    for (const toolBlock of toolUseBlocks) {
      const customFunction = registered.find((func) => func.function.name === toolBlock.name);
      if (!customFunction) {
        results.push({
          type: 'tool_result',
          tool_use_id: toolBlock.id,
          is_error: true,
          content: `Tool not found: ${toolBlock.name}`,
        });
        continue;
      }
      handledAny = true;
      try {
        // The function's default args are overridden by the model-supplied input.
        const mergedArgs = { ...(customFunction.args || {}), ...toolBlock.input };
        const result = await customFunction.handler(mergedArgs);
        results.push({
          type: 'tool_result',
          tool_use_id: toolBlock.id,
          content: typeof result === 'string' ? result : JSON.stringify(result),
        });
      } catch (error) {
        results.push({
          type: 'tool_result',
          tool_use_id: toolBlock.id,
          is_error: true,
          content: error instanceof Error ? error.message : 'Unknown error',
        });
      }
    }

    return handledAny ? results : null;
  }

  /**
   * Converts a unified chat request into Anthropic Messages API parameters.
   *
   * The first system message's text becomes `system`; every other message
   * becomes an `assistant` or `user` turn. Tool definitions from `request.tools`
   * and from `this.tools` are both sent.
   *
   * @param {object} request - Unified chat request.
   * @returns {Promise<object>} Parameters for `client.messages.create`.
   * @throws {Error} When neither `request.model` nor `this.model` is set.
   */
  async convertToAnthropicFormat(request) {
    if (!request.model && !this.model) {
      throw new Error('Model is required for Anthropic requests');
    }
    const config = request.generationConfig || {};
    const systemMessage = request.messages.find((m) => m.role === 'system');
    const otherMessages = request.messages.filter((m) => m.role !== 'system');

    let messages = await Promise.all(otherMessages.map(async (msg) => {
      const parts = this.normalizeContent(msg.content);
      const content = await Promise.all(parts.map((part) => this.convertContentPartToAnthropic(part)));
      return { role: msg.role === 'assistant' ? 'assistant' : 'user', content };
    }));

    // A ResponseFormat may append instructions to the messages (Anthropic has no native JSON mode here).
    if (config.responseFormat instanceof ResponseFormat) {
      messages = config.responseFormat.addRequestSuffix(messages);
    }

    const toAnthropicTool = (tool) => ({
      name: tool.function.name,
      description: tool.function.description || '',
      input_schema: { type: 'object', ...(tool.function.parameters || {}) },
    });
    // Merge request-level tools with this provider's custom functions.
    const tools = [
      ...(request.tools ? request.tools.map(toAnthropicTool) : []),
      ...(this.tools ? this.tools.map(toAnthropicTool) : []),
    ];

    // NOTE(review): generationConfig mixes snake_case (max_tokens, top_p, top_k) and
    // camelCase (stopSequences) keys — confirm against the GenerationConfig type.
    return {
      model: request.model || this.model,
      messages,
      system: systemMessage ? this.extractTextFromContent(systemMessage.content) : undefined,
      max_tokens: config.max_tokens || 4096,
      temperature: config.temperature,
      top_p: config.top_p,
      top_k: config.top_k,
      stop_sequences: config.stopSequences,
      tools: tools.length > 0 ? tools : undefined,
    };
  }

  /**
   * Converts one unified content part into an Anthropic content block.
   *
   * @param {object} part - Unified content part (`text`, `image`, `tool_use`, `tool_result`, ...).
   * @returns {Promise<object>} Anthropic content block; unknown types become a placeholder text block.
   */
  async convertContentPartToAnthropic(part) {
    switch (part.type) {
      case 'text':
        return { type: 'text', text: part.text };
      case 'image': {
        const { url, data, mediaType } = part.source;
        // Anthropic image sources use snake_case `media_type` and carry only the
        // fields of their own source type.
        const source = url
          ? { type: 'url', url }
          : { type: 'base64', media_type: mediaType || 'image/jpeg', data };
        return { type: 'image', source };
      }
      case 'tool_use': {
        // NOTE(review): a tool_use part naming a registered tool is executed right here and
        // replaced by its tool_result, even inside an assistant turn (the Messages API docs
        // allow tool_result only in user turns), and it re-runs on every conversion of the
        // history. Kept unchanged; confirm this is intended.
        const customFunction = this.tools
          ? this.tools.find((func) => func.function.name === part.name)
          : undefined;
        if (!customFunction) {
          return { type: 'tool_use', id: part.id, name: part.name, input: part.input };
        }
        try {
          const mergedArgs = { ...(customFunction.args || {}), ...part.input };
          const result = await customFunction.handler(mergedArgs);
          return {
            type: 'tool_result',
            tool_use_id: part.id,
            is_error: false,
            content: [{ type: 'text', text: typeof result === 'string' ? result : JSON.stringify(result) }],
          };
        } catch (error) {
          return {
            type: 'tool_result',
            tool_use_id: part.id,
            is_error: true,
            content: [{ type: 'text', text: error instanceof Error ? error.message : 'Unknown error' }],
          };
        }
      }
      case 'tool_result':
        return {
          type: 'tool_result',
          tool_use_id: part.toolUseId,
          is_error: part.isError,
          content: part.content == null
            ? []
            : part.content.map((inner) => ({
              type: 'text',
              text: inner.type === 'text' ? inner.text : '[Unsupported content]',
            })),
        };
      default:
        return { type: 'text', text: '[Unsupported content type]' };
    }
  }

  /**
   * Converts an Anthropic message response into the unified response shape.
   *
   * @param {object} response - Non-streaming Messages API response.
   * @returns {object} Unified response; `text` is the first text block's text (or '').
   */
  convertFromAnthropicFormat(response) {
    const content = response.content.map((block) => {
      switch (block.type) {
        case 'text':
          return { type: 'text', text: block.text };
        case 'tool_use':
          return { type: 'tool_use', id: block.id, name: block.name, input: block.input };
        default:
          return { type: 'text', text: '[Unknown content type]' };
      }
    });
    const { input_tokens: inputTokens, output_tokens: outputTokens } = response.usage;
    const firstText = content.find((part) => part.type === 'text');

    return {
      id: response.id,
      model: response.model,
      provider: 'anthropic',
      message: { id: response.id, role: response.role, content, createdAt: new Date() },
      text: (firstText && firstText.text) || '',
      usage: { inputTokens, outputTokens, totalTokens: inputTokens + outputTokens },
      finish_reason: response.stop_reason,
      createdAt: new Date(),
      rawResponse: response,
    };
  }

  /**
   * Wraps one `text_delta` stream event as a unified response chunk.
   *
   * @param {object} chunk - Anthropic `content_block_delta` event carrying a text delta.
   * @param {string} [model=this.model] - Model name to report; stream() passes the model the request used.
   * @returns {object} Unified stream chunk.
   * @throws {Error} When no model name is available.
   */
  convertStreamChunk(chunk, model = this.model) {
    if (!model) {
      throw new Error('Model is required for streaming responses');
    }
    const message = {
      id: this.generateMessageId(),
      role: 'assistant',
      content: [{ type: 'text', text: chunk.delta.text }],
      createdAt: new Date(),
    };
    return {
      id: this.generateMessageId(),
      model,
      provider: 'anthropic',
      message,
      text: chunk.delta.text || '',
      createdAt: new Date(),
      rawResponse: chunk,
    };
  }

  /**
   * Returns `content` itself when it is a string, otherwise the text of its first
   * text part (or '' when there is none).
   */
  extractTextFromContent(content) {
    if (typeof content === 'string') return content;
    const textPart = content.find((c) => c.type === 'text');
    return (textPart && textPart.text) || '';
  }

  /**
   * Normalizes any thrown value into the provider's plain error object.
   *
   * For SDK `APIError`s, the nested `error.error.error` body (type/message) is
   * preferred when present, and the HTTP status is mapped via mapErrorType.
   *
   * @param {*} error - Whatever was thrown.
   * @returns {object} `{ code, message, type, statusCode?, provider, details }`.
   */
  handleError(error) {
    if (error instanceof Anthropic.APIError) {
      const errorBody = error.error ? error.error.error : undefined;
      return {
        code: (errorBody && errorBody.type) || 'anthropic_error',
        message: (errorBody && errorBody.message) || error.message,
        type: this.mapErrorType(error.status),
        statusCode: error.status,
        provider: 'anthropic',
        details: error,
      };
    }
    return {
      code: 'unknown_error',
      message: (error && error.message) || 'Unknown error occurred',
      type: 'api_error',
      provider: 'anthropic',
      details: error,
    };
  }

  /**
   * Maps an HTTP status to a unified error type.
   * @param {number} [status]
   * @returns {string}
   */
  mapErrorType(status) {
    if (!status) return 'api_error';
    if (status === 429) return 'rate_limit';
    if (status === 401) return 'authentication';
    if (status >= 400 && status < 500) return 'invalid_request';
    if (status >= 500) return 'server_error';
    return 'api_error';
  }
}
//# sourceMappingURL=provider.js.map