@llumiverse/drivers

LLM driver implementations. Currently supported: openai, huggingface, bedrock, replicate.

import {
    ContentBlock,
    ContentBlockParam,
    DocumentBlockParam,
    ImageBlockParam,
    Message,
    MessageParam,
    TextBlockParam,
    ToolResultBlockParam,
} from "@anthropic-ai/sdk/resources/index.js";
import {
    AIModel,
    Completion,
    CompletionChunkObject,
    ExecutionOptions,
    getMaxTokensLimitVertexAi,
    JSONObject,
    ModelType,
    PromptRole,
    PromptSegment,
    readStreamAsBase64,
    readStreamAsString,
    StatelessExecutionOptions,
    ToolUse,
    VertexAIClaudeOptions,
} from "@llumiverse/core";
import { asyncMap } from "@llumiverse/core/async";
import { VertexAIDriver } from "../index.js";
import { ModelDefinition } from "../models.js";
import { MessageCreateParamsBase, MessageCreateParamsNonStreaming, RawMessageStreamEvent } from "@anthropic-ai/sdk/resources/messages.js";
import { MessageStreamParams } from "@anthropic-ai/sdk/resources/index.mjs";

export const ANTHROPIC_REGIONS: Record<string, string> = {
    us: "us-east5",
    europe: "europe-west1",
    global: "global",
}

export const NON_GLOBAL_ANTHROPIC_MODELS = [
    "claude-3-5",
    "claude-3",
];

interface ClaudePrompt {
    messages: MessageParam[];
    system?: TextBlockParam[];
}

function claudeFinishReason(reason: string | undefined) {
    if (!reason) return undefined;
    switch (reason) {
        case 'end_turn': return "stop";
        case 'max_tokens': return "length";
        default: return reason; // stop_sequence
    }
}

export function collectTools(content: ContentBlock[]): ToolUse[] | undefined {
    const out: ToolUse[] = [];
    for (const block of content) {
        if (block.type === "tool_use") {
            out.push({
                id: block.id,
                tool_name: block.name,
                tool_input: block.input as JSONObject,
            });
        }
    }
    return out.length > 0 ? out : undefined;
}

function collectAllTextContent(content: ContentBlock[], includeThoughts: boolean = false) {
    const textParts: string[] = [];

    // First pass: collect thinking blocks
    if (includeThoughts) {
        for (const block of content) {
            if (block.type === 'thinking' && block.thinking) {
                textParts.push(block.thinking);
            } else if (block.type === 'redacted_thinking' && block.data) {
                textParts.push(`[Redacted thinking: ${block.data}]`);
            }
        }
        if (textParts.length > 0) {
            textParts.push(''); // Create a new line after thinking blocks
        }
    }

    // Second pass: collect text blocks
    for (const block of content) {
        if (block.type === 'text' && block.text) {
            textParts.push(block.text);
        }
    }

    return textParts.join('\n');
}

// Used to get a max_tokens value when not specified in the model options. Claude requires it to be set.
function maxToken(option: StatelessExecutionOptions): number {
    const modelOptions = option.model_options as VertexAIClaudeOptions | undefined;
    if (modelOptions && typeof modelOptions.max_tokens === "number") {
        return modelOptions.max_tokens;
    } else {
        const thinking_budget = modelOptions?.thinking_budget_tokens ?? 0;
        let maxSupportedTokens = getMaxTokensLimitVertexAi(option.model); // Fallback to the default max tokens limit for the model
        if (option.model.includes('claude-3-7-sonnet') && (modelOptions?.thinking_budget_tokens ?? 0) < 48000) {
            maxSupportedTokens = 64000; // Claude 3.7 can go up to 128k with a beta header, but when no max tokens is specified, we default to 64k.
        }
        return Math.min(16000 + thinking_budget, maxSupportedTokens); // Default to 16k plus the thinking budget, to avoid taking up too much context window and quota.
    }
}

// Type-safe overloads for collectFileBlocks
async function collectFileBlocks(segment: PromptSegment, restrictedTypes: true): Promise<Array<TextBlockParam | ImageBlockParam>>;
async function collectFileBlocks(segment: PromptSegment, restrictedTypes?: false): Promise<ContentBlockParam[]>;
async function collectFileBlocks(segment: PromptSegment, restrictedTypes: boolean = false): Promise<ContentBlockParam[]> {
    const contentBlocks: ContentBlockParam[] = [];
    for (const file of segment.files || []) {
        if (file.mime_type?.startsWith("image/")) {
            const allowedTypes = ["image/png", "image/jpeg", "image/gif", "image/webp"];
            if (!allowedTypes.includes(file.mime_type)) {
                throw new Error(`Unsupported image type: ${file.mime_type}`);
            }
            const mimeType = String(file.mime_type) as "image/png" | "image/jpeg" | "image/gif" | "image/webp";
            contentBlocks.push({
                type: 'image',
                source: {
                    type: 'base64',
                    data: await readStreamAsBase64(await file.getStream()),
                    media_type: mimeType
                }
            } satisfies ImageBlockParam);
        } else if (!restrictedTypes) {
            if (file.mime_type === "application/pdf") {
                contentBlocks.push({
                    title: file.name,
                    type: 'document',
                    source: {
                        type: 'base64',
                        data: await readStreamAsBase64(await file.getStream()),
                        media_type: 'application/pdf'
                    }
                } satisfies DocumentBlockParam);
            } else if (file.mime_type?.startsWith("text/")) {
                contentBlocks.push({
                    title: file.name,
                    type: 'document',
                    source: {
                        type: 'text',
                        data: await readStreamAsString(await file.getStream()),
                        media_type: 'text/plain'
                    }
                } satisfies DocumentBlockParam);
            }
        }
    }
    return contentBlocks;
}

export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
    model: AIModel

    constructor(modelId: string) {
        this.model = {
            id: modelId,
            name: modelId,
            provider: 'vertexai',
            type: ModelType.Text,
            can_stream: true,
        } satisfies AIModel;
    }

    async createPrompt(_driver: VertexAIDriver, segments: PromptSegment[], options: ExecutionOptions): Promise<ClaudePrompt> {
        // Convert the prompt to the format expected by the Claude API
        let system: TextBlockParam[] | undefined = segments
            .filter(segment => segment.role === PromptRole.system)
            .map(segment => ({ text: segment.content, type: 'text' }));

        if (options.result_schema) {
            let schemaText: string = '';
            if (options.tools && options.tools.length > 0) {
                schemaText = "When not calling tools, the answer must be a JSON object using the following JSON Schema:\n" + JSON.stringify(options.result_schema);
            } else {
                schemaText = "The answer must be a JSON object using the following JSON Schema:\n" + JSON.stringify(options.result_schema);
            }
            const schemaSegments: TextBlockParam = {
                text: schemaText,
                type: 'text'
            }
            system.push(schemaSegments);
        }

        let messages: MessageParam[] = [];
        const safetyMessages: MessageParam[] = [];

        for (const segment of segments) {
            if (segment.role === PromptRole.system) {
                continue;
            }
            if (segment.role === PromptRole.tool) {
                if (!segment.tool_use_id) {
                    throw new Error("Tool prompt segment must have a tool use ID");
                }
                // Build content blocks for tool results (restricted types)
                const contentBlocks: Array<TextBlockParam | ImageBlockParam> = [];
                if (segment.content) {
                    contentBlocks.push({ type: 'text', text: segment.content } satisfies TextBlockParam);
                }
                // Collect file blocks with type safety
                const fileBlocks = await collectFileBlocks(segment, true);
                contentBlocks.push(...fileBlocks);
                messages.push({
                    role: 'user',
                    content: [{
                        type: 'tool_result',
                        tool_use_id: segment.tool_use_id,
                        content: contentBlocks,
                    } satisfies ToolResultBlockParam]
                });
            } else {
                // Build content blocks for regular messages (all types allowed)
                const contentBlocks: ContentBlockParam[] = [];
                if (segment.content) {
                    contentBlocks.push({ type: 'text', text: segment.content } satisfies TextBlockParam);
                }
                // Collect file blocks without restrictions
                const fileBlocks = await collectFileBlocks(segment, false);
                contentBlocks.push(...fileBlocks);
                if (contentBlocks.length === 0) {
                    continue; // skip empty segments
                }
                const messageParam: MessageParam = {
                    role: segment.role === PromptRole.assistant ? 'assistant' : 'user',
                    content: contentBlocks
                };
                if (segment.role === PromptRole.safety) {
                    safetyMessages.push(messageParam);
                } else {
                    messages.push(messageParam);
                }
            }
        }

        messages = messages.concat(safetyMessages);

        if (system && system.length === 0) {
            system = undefined; // If system is empty, set to undefined
        }

        return { messages: messages, system: system }
    }

    async requestTextCompletion(driver: VertexAIDriver, prompt: ClaudePrompt, options: ExecutionOptions): Promise<Completion> {
        const splits = options.model.split("/");
        let region: string | undefined = undefined;
        if (splits[0] === "locations" && splits.length >= 2) {
            region = splits[1];
        }
        const modelName = splits[splits.length - 1];
        options = { ...options, model: modelName };
        const client = driver.getAnthropicClient(region);

        options.model_options = options.model_options as VertexAIClaudeOptions;
        if (options.model_options?._option_id !== "vertexai-claude") {
            driver.logger.warn("Invalid model options", { options: options.model_options });
        }

        let conversation = updateConversation(options.conversation as ClaudePrompt, prompt);
        const { payload, requestOptions } = getClaudePayload(options, conversation);

        // Disable streaming; the create function is overloaded, so the payload type matters.
        const nonStreamingPayload: MessageCreateParamsNonStreaming = { ...payload, stream: false };
        const result = await client.messages.create(nonStreamingPayload, requestOptions) satisfies Message;

        // Use the new function to collect text content, including thinking if enabled
        const includeThoughts = options.model_options?.include_thoughts ?? false;
        const text = collectAllTextContent(result.content, includeThoughts);
        const tool_use = collectTools(result.content);

        conversation = updateConversation(conversation, createPromptFromResponse(result));

        return {
            result: text ? [{ type: "text", value: text }] : [{ type: "text", value: '' }],
            tool_use,
            token_usage: {
                prompt: result.usage.input_tokens,
                result: result.usage.output_tokens,
                total: result.usage.input_tokens + result.usage.output_tokens
            },
            // Make sure finish_reason is set to the correct value (Claude normally sets this itself)
            finish_reason: tool_use ? "tool_use" : claudeFinishReason(result?.stop_reason ?? ''),
            conversation
        } satisfies Completion;
    }

    async requestTextCompletionStream(driver: VertexAIDriver, prompt: ClaudePrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        const splits = options.model.split("/");
        let region: string | undefined = undefined;
        if (splits[0] === "locations" && splits.length >= 2) {
            region = splits[1];
        }
        const modelName = splits[splits.length - 1];
        options = { ...options, model: modelName };
        const client = driver.getAnthropicClient(region);

        const model_options = options.model_options as VertexAIClaudeOptions | undefined;
        if (model_options?._option_id !== "vertexai-claude") {
            driver.logger.warn("Invalid model options", { options: options.model_options });
        }

        const { payload, requestOptions } = getClaudePayload(options, prompt);
        const streamingPayload: MessageStreamParams = { ...payload, stream: true };
        const response_stream = await client.messages.stream(streamingPayload, requestOptions);

        const stream = asyncMap(response_stream, async (streamEvent: RawMessageStreamEvent) => {
            switch (streamEvent.type) {
                case "message_start":
                    return {
                        result: [{ type: "text", value: '' }],
                        token_usage: {
                            prompt: streamEvent.message.usage.input_tokens,
                            result: streamEvent.message.usage.output_tokens
                        }
                    } satisfies CompletionChunkObject;
                case "message_delta":
                    return {
                        result: [{ type: "text", value: '' }],
                        token_usage: { result: streamEvent.usage.output_tokens },
                        finish_reason: claudeFinishReason(streamEvent.delta.stop_reason ?? undefined),
                    } satisfies CompletionChunkObject;
                case "content_block_start":
                    // Handle redacted thinking blocks
                    if (streamEvent.content_block.type === "redacted_thinking" && model_options?.include_thoughts) {
                        return {
                            result: [{ type: "text", value: `[Redacted thinking: ${streamEvent.content_block.data}]` }]
                        } satisfies CompletionChunkObject;
                    }
                    break;
                case "content_block_delta":
                    // Handle different delta types
                    switch (streamEvent.delta.type) {
                        case "text_delta":
                            return {
                                result: streamEvent.delta.text ? [{ type: "text", value: streamEvent.delta.text }] : []
                            } satisfies CompletionChunkObject;
                        case "thinking_delta":
                            if (model_options?.include_thoughts) {
                                return {
                                    result: streamEvent.delta.thinking ? [{ type: "text", value: streamEvent.delta.thinking }] : [],
                                } satisfies CompletionChunkObject;
                            }
                            break;
                        case "signature_delta":
                            // Signature deltas signify the end of the thoughts.
                            if (model_options?.include_thoughts) {
                                return {
                                    result: [{ type: "text", value: '\n\n' }], // Double newline for more spacing
                                } satisfies CompletionChunkObject;
                            }
                            break;
                    }
                    break;
                case "content_block_stop":
                    // Handle the end of content blocks, for redacted thinking blocks
                    if (model_options?.include_thoughts) {
                        return {
                            result: [{ type: "text", value: '\n\n' }] // Add double newline for spacing
                        } satisfies CompletionChunkObject;
                    }
                    break;
            }
            // Default case for all other event types
            return { result: [] } satisfies CompletionChunkObject;
        });
        return stream;
    }
}

function createPromptFromResponse(response: Message): ClaudePrompt {
    return {
        messages: [{
            role: response.role,
            content: response.content,
        }],
        system: undefined
    }
}

/**
 * Update the conversation messages
 * @param conversation the conversation accumulated so far, if any
 * @param prompt the new prompt messages to append
 * @returns the merged conversation
 */
function updateConversation(conversation: ClaudePrompt | undefined | null, prompt: ClaudePrompt): ClaudePrompt {
    const baseSystemMessages = conversation?.system || [];
    const baseMessages = conversation?.messages || [];
    const system = baseSystemMessages.concat(prompt.system || []);
    return {
        messages: baseMessages.concat(prompt.messages || []),
        system: system.length > 0 ? system : undefined // If system is empty, set to undefined
    };
}

interface RequestOptions {
    headers?: Record<string, string>;
}

function getClaudePayload(options: ExecutionOptions, prompt: ClaudePrompt): { payload: MessageCreateParamsBase, requestOptions: RequestOptions | undefined } {
    const modelName = options.model; // Model name is already extracted in the calling methods
    const model_options = options.model_options as VertexAIClaudeOptions;

    // Add beta header for Claude 3.7 models to enable 128k output tokens
    let requestOptions: RequestOptions | undefined = undefined;
    if (modelName.includes('claude-3-7-sonnet') &&
        ((model_options?.max_tokens ?? 0) > 64000 || (model_options?.thinking_budget_tokens ?? 0) > 64000)) {
        requestOptions = { headers: { 'anthropic-beta': 'output-128k-2025-02-19' } };
    }

    const payload = {
        messages: prompt.messages,
        system: prompt.system,
        tools: options.tools, // we are using the same shape as claude for tools
        temperature: model_options?.temperature,
        model: modelName,
        max_tokens: maxToken(options),
        top_p: model_options?.top_p,
        top_k: model_options?.top_k,
        stop_sequences: model_options?.stop_sequence,
        thinking: model_options?.thinking_mode ? {
            budget_tokens: model_options?.thinking_budget_tokens ?? 1024,
            type: "enabled" as const
        } : { type: "disabled" as const }
    };
    return { payload, requestOptions };
}
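
For reference, a minimal usage sketch of ClaudeModelDefinition follows. It is not taken from the package docs: it assumes a VertexAIDriver has already been constructed elsewhere with valid Google Cloud credentials, that PromptRole.user exists in @llumiverse/core, and that the minimal PromptSegment and ExecutionOptions shapes shown (forced with casts) are sufficient. The "locations/<region>/<model>" path convention is the one parsed by requestTextCompletion above, and "us-east5" matches the ANTHROPIC_REGIONS table.

import { ExecutionOptions, PromptRole, PromptSegment } from "@llumiverse/core";
import { VertexAIDriver } from "../index.js";
import { ClaudeModelDefinition } from "./claude.js"; // hypothetical path to this file

async function runClaudeOnVertex(driver: VertexAIDriver) {
    // "locations/<region>/<model>" lets requestTextCompletion pick the region.
    const modelPath = "locations/us-east5/claude-3-7-sonnet"; // hypothetical model id
    const model = new ClaudeModelDefinition(modelPath);

    const segments = [
        { role: PromptRole.system, content: "You are a concise assistant." },
        { role: PromptRole.user, content: "Say hello." },
    ] as PromptSegment[]; // minimal shape; real segments may also carry files, tool_use_id, etc.

    const options = {
        model: modelPath,
        model_options: {
            _option_id: "vertexai-claude", // expected by the driver's option check
            max_tokens: 1024,
            temperature: 0.2,
        },
    } as ExecutionOptions; // minimal shape; other ExecutionOptions fields omitted

    const prompt = await model.createPrompt(driver, segments, options);

    // Non-streaming: one Completion with text, token usage and finish reason.
    const completion = await model.requestTextCompletion(driver, prompt, options);
    console.log(completion.result, completion.token_usage, completion.finish_reason);

    // Streaming: chunks carry incremental text parts and token usage.
    const stream = await model.requestTextCompletionStream(driver, prompt, options);
    for await (const chunk of stream) {
        for (const part of chunk.result) {
            if (part.type === "text") process.stdout.write(part.value);
        }
    }
}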