@unified-llm/core
Unified LLM interface (in-memory).
JavaScript
import Anthropic from '@anthropic-ai/sdk';
import BaseProvider from '../base-provider';
import { validateChatRequest } from '../../utils/validation';
import { ResponseFormat } from '../../response-format';
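// --- Usage sketch (illustrative, not taken from the package docs) ---
// The shapes below are inferred from this file: the constructor takes
// { apiKey, model, tools }, chat() takes a request with `messages` (plus optional
// model/generationConfig/tools), and each tool is
// { function: { name, description, parameters }, args, handler }.
// The root export of '@unified-llm/core' is an assumption.
//
// import { AnthropicProvider } from '@unified-llm/core';
//
// const provider = new AnthropicProvider({
//   apiKey: process.env.ANTHROPIC_API_KEY,
//   model: 'claude-3-5-haiku-latest',
//   tools: [{
//     function: {
//       name: 'get_time',
//       description: 'Return the current time as an ISO string',
//       parameters: { properties: {}, required: [] },
//     },
//     args: {},
//     handler: async () => new Date().toISOString(),
//   }],
// });
//
// const response = await provider.chat({
//   messages: [{ role: 'user', content: [{ type: 'text', text: 'What time is it?' }] }],
// });
// console.log(response.text);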
// Anthropic implementation
export class AnthropicProvider extends BaseProvider {
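// Create the Anthropic SDK client; model and tools are passed through to BaseProvider.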
constructor({ apiKey, model, tools }) {
super({ model, tools });
this.client = new Anthropic({ apiKey });
}
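// Non-streaming chat. If the response stops with 'tool_use', matching local tool
// handlers are executed and their results are sent back until a final answer arrives.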
async chat(request) {
validateChatRequest(request);
try {
const anthropicRequest = await this.convertToAnthropicFormat(request);
let response = await this.client.messages.create(anthropicRequest);
let messages = [...anthropicRequest.messages];
// When stop_reason is 'tool_use', execute the tools and send the results back
while (response.stop_reason === 'tool_use' && this.tools) {
const toolUseBlocks = response.content.filter(block => block.type === 'tool_use');
const toolResults = [];
for (const toolBlock of toolUseBlocks) {
const customFunction = this.tools.find(func => func.function.name === toolBlock.name);
if (customFunction) {
try {
// Merge the CustomFunction's args with the tool_use input
const mergedArgs = {
...(customFunction.args || {}),
...toolBlock.input
};
const result = await customFunction.handler(mergedArgs);
toolResults.push({
type: 'tool_result',
tool_use_id: toolBlock.id,
content: typeof result === 'string' ? result : JSON.stringify(result),
});
}
catch (error) {
toolResults.push({
type: 'tool_result',
tool_use_id: toolBlock.id,
is_error: true,
content: error instanceof Error ? error.message : 'Unknown error',
});
}
}
}
// Re-request with the tool execution results included
if (toolResults.length > 0) {
messages = [
...messages,
{
role: 'assistant',
content: response.content,
},
{
role: 'user',
content: toolResults,
},
];
const followUpRequest = {
...anthropicRequest,
messages: messages,
};
response = await this.client.messages.create(followUpRequest);
}
else {
// Exit the loop if there are no tool results
break;
}
}
return this.convertFromAnthropicFormat(response);
}
catch (error) {
throw this.handleError(error);
}
}
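// Streaming chat. Each upstream response is buffered so tool calls can be detected
// and executed; plain text responses are replayed as the original text deltas, and
// text generated after tool execution is re-chunked locally (about 20 chars at a time).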
async *stream(request) {
validateChatRequest(request);
const anthropicRequest = await this.convertToAnthropicFormat(request);
let messages = [...anthropicRequest.messages];
// Keep trying to get a response until we don't get tool calls
while (true) {
const stream = await this.client.messages.create({
...anthropicRequest,
messages,
stream: true,
});
// Accumulate content blocks
const contentBlocks = [];
let stopReason = null;
let hasToolUse = false;
// First pass: detect if there are tool calls
const allChunks = [];
for await (const chunk of stream) {
allChunks.push(chunk);
if (chunk.type === 'content_block_start') {
contentBlocks.push({ ...chunk.content_block });
if (chunk.content_block.type === 'tool_use') {
hasToolUse = true;
}
}
else if (chunk.type === 'content_block_delta') {
const blockIndex = chunk.index;
if (blockIndex < contentBlocks.length) {
const block = contentBlocks[blockIndex];
if (block.type === 'text' && chunk.delta.type === 'text_delta') {
block.text = (block.text || '') + chunk.delta.text;
}
else if (block.type === 'tool_use' && chunk.delta.type === 'input_json_delta') {
// Accumulate tool input JSON
if (!block._rawInput)
block._rawInput = '';
block._rawInput += chunk.delta.partial_json;
}
}
}
else if (chunk.type === 'content_block_stop') {
const blockIndex = chunk.index;
if (blockIndex < contentBlocks.length) {
const block = contentBlocks[blockIndex];
if (block.type === 'tool_use' && block._rawInput) {
// Parse the complete tool input
try {
block.input = JSON.parse(block._rawInput);
delete block._rawInput;
}
catch (_e) {
block.input = {};
}
}
}
}
else if (chunk.type === 'message_delta') {
if (chunk.delta.stop_reason) {
stopReason = chunk.delta.stop_reason;
}
}
}
// If we have tool use and tools are available, execute them
if (stopReason === 'tool_use' && this.tools && hasToolUse) {
const toolUseBlocks = contentBlocks.filter(block => block.type === 'tool_use');
const toolResults = [];
for (const toolBlock of toolUseBlocks) {
const customFunction = this.tools.find(func => func.function.name === toolBlock.name);
if (customFunction) {
try {
// Merge default args with tool input
const mergedArgs = {
...(customFunction.args || {}),
...toolBlock.input
};
const result = await customFunction.handler(mergedArgs);
toolResults.push({
type: 'tool_result',
tool_use_id: toolBlock.id,
content: typeof result === 'string' ? result : JSON.stringify(result),
});
}
catch (error) {
toolResults.push({
type: 'tool_result',
tool_use_id: toolBlock.id,
is_error: true,
content: error instanceof Error ? error.message : 'Unknown error',
});
}
}
}
// Continue with tool results if we have any
if (toolResults.length > 0) {
// Clean up contentBlocks before sending to API
const cleanContentBlocks = contentBlocks.map(block => {
const cleanBlock = { ...block };
delete cleanBlock._rawInput;
return cleanBlock;
});
messages = [
...messages,
{
role: 'assistant',
content: cleanContentBlocks,
},
{
role: 'user',
content: toolResults,
},
];
// Continue the loop to get the next response
continue;
}
}
// Second pass: yield chunks
if (!hasToolUse) {
// No tool use, stream text deltas immediately
for (const chunk of allChunks) {
if (chunk.type === 'content_block_delta' && chunk.delta.type === 'text_delta') {
yield this.convertStreamChunk(chunk);
}
}
}
else {
// Tool use was executed, now stream the final response
// Convert accumulated content blocks to streaming format
for (const block of contentBlocks) {
if (block.type === 'text' && block.text) {
// Simulate text streaming
const text = block.text;
const chunkSize = 20; // Approximate chunk size
for (let i = 0; i < text.length; i += chunkSize) {
const chunkText = text.slice(i, Math.min(i + chunkSize, text.length));
yield {
id: this.generateMessageId(),
model: this.model || 'claude-3-5-haiku-latest',
provider: 'anthropic',
message: {
id: this.generateMessageId(),
role: 'assistant',
content: [{ type: 'text', text: chunkText }],
createdAt: new Date(),
},
text: chunkText,
createdAt: new Date(),
rawResponse: null,
};
}
}
}
}
break;
}
}
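// Convert a unified request into the Anthropic Messages API format: the system message
// becomes the `system` field, content parts become Anthropic blocks (tool_use parts with
// a matching local handler are executed and turned into tool_result blocks), and
// request.tools are merged with the provider-level tools.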
async convertToAnthropicFormat(request) {
if (!request.model && !this.model) {
throw new Error('Model is required for Anthropic requests');
}
const systemMessage = request.messages.find(m => m.role === 'system');
const otherMessages = request.messages.filter(m => m.role !== 'system');
let messages = await Promise.all(otherMessages.map(async (msg) => {
const content = this.normalizeContent(msg.content);
const anthropicContent = await Promise.all(content.map(async (c) => {
switch (c.type) {
case 'text':
return { type: 'text', text: c.text };
case 'image':
// Anthropic expects either a url source or a base64 source with a snake_case media_type
return c.source.url
? { type: 'image', source: { type: 'url', url: c.source.url } }
: {
type: 'image',
source: {
type: 'base64',
media_type: c.source.mediaType || 'image/jpeg',
data: c.source.data,
},
};
case 'tool_use': {
// Look up the tool in customFunctions and execute it
const customFunction = this.tools?.find(func => func.function.name === c.name);
if (customFunction) {
try {
// Merge the CustomFunction's args with the tool_use input
const mergedArgs = {
...(customFunction.args || {}),
...c.input
};
const result = await customFunction.handler(mergedArgs);
return {
type: 'tool_result',
tool_use_id: c.id,
is_error: false,
content: [{
type: 'text',
text: typeof result === 'string' ? result : JSON.stringify(result),
}],
};
}
catch (error) {
return {
type: 'tool_result',
tool_use_id: c.id,
is_error: true,
content: [{
type: 'text',
text: error instanceof Error ? error.message : 'Unknown error',
}],
};
}
}
return {
type: 'tool_use',
id: c.id,
name: c.name,
input: c.input,
};
}
case 'tool_result':
return {
type: 'tool_result',
tool_use_id: c.toolUseId,
is_error: c.isError,
content: c.content?.map(tc => ({
type: 'text',
text: tc.type === 'text' ? tc.text : '[Unsupported content]',
})) || [],
};
default:
return { type: 'text', text: '[Unsupported content type]' };
}
}));
return {
role: msg.role === 'assistant' ? 'assistant' : 'user',
content: anthropicContent,
};
}));
// Handle response_format for Anthropic
if (request.generationConfig?.responseFormat instanceof ResponseFormat) {
messages = request.generationConfig.responseFormat.addRequestSuffix(messages);
}
// Merge tools: combine request.tools with the provider-level customFunctions
const tools = [
...(request.tools?.map(tool => ({
name: tool.function.name,
description: tool.function.description || '',
input_schema: {
type: 'object',
...(tool.function.parameters || {}),
},
})) || []),
...(this.tools?.map((func) => ({
name: func.function.name,
description: func.function.description || '',
input_schema: {
type: 'object',
...(func.function.parameters || {}),
},
})) || []),
];
return {
model: request.model || this.model,
messages: messages,
system: systemMessage ? this.extractTextFromContent(systemMessage.content) : undefined,
max_tokens: request.generationConfig?.max_tokens || 4096,
temperature: request.generationConfig?.temperature,
top_p: request.generationConfig?.top_p,
top_k: request.generationConfig?.top_k,
stop_sequences: request.generationConfig?.stopSequences,
tools: tools.length > 0 ? tools : undefined,
};
}
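// Convert an Anthropic response into the unified response format, including token usage
// and a convenience `text` field taken from the first text block.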
convertFromAnthropicFormat(response) {
const content = response.content.map(block => {
switch (block.type) {
case 'text':
return { type: 'text', text: block.text };
case 'tool_use':
return {
type: 'tool_use',
id: block.id,
name: block.name,
input: block.input,
};
default:
return { type: 'text', text: '[Unknown content type]' };
}
});
const unifiedMessage = {
id: response.id,
role: response.role,
content,
createdAt: new Date(),
};
const usage = {
inputTokens: response.usage.input_tokens,
outputTokens: response.usage.output_tokens,
totalTokens: response.usage.input_tokens + response.usage.output_tokens,
};
// Extract text for convenience field
const contentArray = Array.isArray(unifiedMessage.content) ? unifiedMessage.content : [{ type: 'text', text: unifiedMessage.content }];
const textContent = contentArray.find((c) => c.type === 'text');
return {
id: response.id,
model: response.model,
provider: 'anthropic',
message: unifiedMessage,
text: textContent?.text || '',
usage,
finish_reason: response.stop_reason,
createdAt: new Date(),
rawResponse: response,
};
}
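// Wrap a single text_delta chunk in the unified streaming response format.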
convertStreamChunk(chunk) {
if (!this.model) {
throw new Error('Model is required for streaming responses');
}
const content = [{
type: 'text',
text: chunk.delta.text,
}];
const unifiedMessage = {
id: this.generateMessageId(),
role: 'assistant',
content,
createdAt: new Date(),
};
// Extract text for convenience field
const contentArray = Array.isArray(unifiedMessage.content) ? unifiedMessage.content : [{ type: 'text', text: unifiedMessage.content }];
const textContent = contentArray.find((c) => c.type === 'text');
return {
id: this.generateMessageId(),
model: this.model,
provider: 'anthropic',
message: unifiedMessage,
text: textContent?.text || '',
createdAt: new Date(),
rawResponse: chunk,
};
}
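// Return the plain text of a message body, whether it is a string or a content array.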
extractTextFromContent(content) {
if (typeof content === 'string')
return content;
const textContent = content.find(c => c.type === 'text');
return textContent?.text || '';
}
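// Normalize Anthropic SDK errors and unexpected errors into the unified error format.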
handleError(error) {
if (error instanceof Anthropic.APIError) {
const errorBody = error.error?.error;
return {
code: errorBody?.type || 'anthropic_error',
message: errorBody?.message || error.message,
type: this.mapErrorType(error.status),
statusCode: error.status,
provider: 'anthropic',
details: error,
};
}
return {
code: 'unknown_error',
message: error.message || 'Unknown error occurred',
type: 'api_error',
provider: 'anthropic',
details: error,
};
}
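// Map an HTTP status code to a unified error type.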
mapErrorType(status) {
if (!status)
return 'api_error';
if (status === 429)
return 'rate_limit';
if (status === 401)
return 'authentication';
if (status >= 400 && status < 500)
return 'invalid_request';
if (status >= 500)
return 'server_error';
return 'api_error';
}
}
//# sourceMappingURL=provider.js.map