UNPKG

anthropic-vertex-ai

Version:

[nalaso/anthropic-vertex-ai](https://github.com/nalaso/anthropic-vertex-ai) is a community provider that uses Anthropic models through Vertex AI to provide language model support for the Vercel AI SDK.

670 lines (664 loc) 21.3 kB
// src/anthropic-vertex-provider.ts
import { loadSetting, withoutTrailingSlash } from "@ai-sdk/provider-utils";
import { GoogleAuth } from "google-auth-library";

// src/anthropic-messages-language-model.ts
import {
  UnsupportedFunctionalityError as UnsupportedFunctionalityError2
} from "@ai-sdk/provider";
import {
  combineHeaders,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  postJsonToApi
} from "@ai-sdk/provider-utils";
import { z as z2 } from "zod";

// src/anthropic-error.ts
import { createJsonErrorResponseHandler } from "@ai-sdk/provider-utils";
import { z } from "zod";

/**
 * Shape of an error body: `{ type: "error", error: { type, message } }`.
 */
var anthropicErrorDataSchema = z.object({
  type: z.literal("error"),
  error: z.object({
    type: z.string(),
    message: z.string()
  })
});

/**
 * Failed-response handler passed to `postJsonToApi`; it validates the error
 * body against `anthropicErrorDataSchema` and uses `error.message` as the
 * error message.
 */
var anthropicFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: anthropicErrorDataSchema,
  errorToMessage: (data) => data.error.message
});

// src/convert-to-anthropic-messages-prompt.ts
import { UnsupportedFunctionalityError } from "@ai-sdk/provider";
import { convertUint8ArrayToBase64 } from "@ai-sdk/provider-utils";

/**
 * Converts a prompt (array of `{ role, content }` messages) into the
 * `{ system, messages }` pair sent in the request body.
 *
 * Consecutive messages are first grouped by `groupIntoBlocks`; each block then
 * becomes either the system string or exactly one entry of `messages`.
 *
 * @param prompt Array of `{ role, content }` messages with roles
 *   "system" | "user" | "assistant" | "tool".
 * @returns `{ system, messages }` where `system` is `undefined` when the
 *   prompt has no system block.
 * @throws UnsupportedFunctionalityError when system messages appear in more
 *   than one block, when an image part is a `URL`, or when an assistant block
 *   holds more than one message.
 * @throws Error for an unknown block type or an unknown role.
 */
function convertToAnthropicMessagesPrompt(prompt) {
  var _a;
  const blocks = groupIntoBlocks(prompt);
  let system = void 0;
  const messages = [];
  for (let i = 0; i < blocks.length; i++) {
    const block = blocks[i];
    const type = block.type;
    switch (type) {
      case "system": {
        // Only one system block is representable: a second one means system
        // messages were interleaved with user/assistant messages.
        if (system != null) {
          throw new UnsupportedFunctionalityError({
            functionality: "Multiple system messages that are separated by user/assistant messages"
          });
        }
        system = block.messages.map(({ content }) => content).join("\n");
        break;
      }
      case "user": {
        // A user block may mix "user" and "tool" messages; all of their parts
        // are flattened into the content of a single user message.
        const anthropicContent = [];
        for (const { role, content } of block.messages) {
          switch (role) {
            case "user": {
              for (const part of content) {
                switch (part.type) {
                  case "text": {
                    anthropicContent.push({ type: "text", text: part.text });
                    break;
                  }
                  case "image": {
                    if (part.image instanceof URL) {
                      throw new UnsupportedFunctionalityError({
                        functionality: "Image URLs in user messages"
                      });
                    }
                    anthropicContent.push({
                      type: "image",
                      source: {
                        type: "base64",
                        // Falls back to JPEG when the part carries no MIME type.
                        media_type: (_a = part.mimeType) != null ? _a : "image/jpeg",
                        data: convertUint8ArrayToBase64(part.image)
                      }
                    });
                    break;
                  }
                }
              }
              break;
            }
            case "tool": {
              for (const part of content) {
                anthropicContent.push({
                  type: "tool_result",
                  tool_use_id: part.toolCallId,
                  content: JSON.stringify(part.result),
                  is_error: part.isError
                });
              }
              break;
            }
            default: {
              const _exhaustiveCheck = role;
              throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
            }
          }
        }
        messages.push({ role: "user", content: anthropicContent });
        break;
      }
      case "assistant": {
        if (block.messages.length > 1) {
          throw new UnsupportedFunctionalityError({
            functionality: "Multiple assistant messages in block"
          });
        }
        const { content } = block.messages[0];
        const isLastBlock = i === blocks.length - 1;
        messages.push({
          role: "assistant",
          content: content.map((part, j) => {
            switch (part.type) {
              case "text": {
                // Trim only the LAST content part of a trailing assistant
                // message. `j` indexes `content`, so it must be compared with
                // `content.length`; comparing it with `block.messages.length`
                // (always 1 here) trimmed the first part instead.
                // NOTE(review): presumably needed because a prefilled
                // assistant turn must not end in whitespace - confirm against
                // the Anthropic Messages API documentation.
                if (isLastBlock && j === content.length - 1) {
                  return { type: "text", text: part.text.trim() };
                }
                return { type: "text", text: part.text };
              }
              case "tool-call": {
                return {
                  type: "tool_use",
                  id: part.toolCallId,
                  name: part.toolName,
                  input: part.args
                };
              }
            }
          })
        });
        break;
      }
      default: {
        const _exhaustiveCheck = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }
  return {
    system,
    messages
  };
}

/**
 * Groups consecutive prompt messages into typed blocks.
 *
 * A new block is opened whenever the block type changes. "tool" messages are
 * placed in "user" blocks, so a tool message directly after a user message
 * (or vice versa) shares that block.
 *
 * @param prompt Array of `{ role, content }` messages.
 * @returns Array of `{ type, messages }` with type
 *   "system" | "assistant" | "user".
 * @throws Error for an unknown role.
 */
function groupIntoBlocks(prompt) {
  const blocks = [];
  let currentBlock = void 0;
  for (const { role, content } of prompt) {
    switch (role) {
      case "system": {
        if ((currentBlock == null ? void 0 : currentBlock.type) !== "system") {
          currentBlock = { type: "system", messages: [] };
          blocks.push(currentBlock);
        }
        currentBlock.messages.push({ role, content });
        break;
      }
      case "assistant": {
        if ((currentBlock == null ? void 0 : currentBlock.type) !== "assistant") {
          currentBlock = { type: "assistant", messages: [] };
          blocks.push(currentBlock);
        }
        currentBlock.messages.push({ role, content });
        break;
      }
      case "user": {
        if ((currentBlock == null ? void 0 : currentBlock.type) !== "user") {
          currentBlock = { type: "user", messages: [] };
          blocks.push(currentBlock);
        }
        currentBlock.messages.push({ role, content });
        break;
      }
      case "tool": {
        // Tool results are carried inside user blocks.
        if ((currentBlock == null ? void 0 : currentBlock.type) !== "user") {
          currentBlock = { type: "user", messages: [] };
          blocks.push(currentBlock);
        }
        currentBlock.messages.push({ role, content });
        break;
      }
      default: {
        const _exhaustiveCheck = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  }
  return blocks;
}

// src/map-anthropic-stop-reason.ts
/**
 * Maps a response `stop_reason` onto a finish reason string.
 *
 * @param finishReason The `stop_reason` value (may be null/undefined).
 * @returns "stop" for "end_turn"/"stop_sequence", "tool-calls" for
 *   "tool_use", "length" for "max_tokens", otherwise "other".
 */
function mapAnthropicStopReason(finishReason) {
  switch (finishReason) {
    case "end_turn":
    case "stop_sequence":
      return "stop";
    case "tool_use":
      return "tool-calls";
    case "max_tokens":
      return "length";
    default:
      return "other";
  }
}

// src/anthropic-messages-language-model.ts
/**
 * Language model that posts Anthropic Messages requests to the
 * `:rawPredict` / `:streamRawPredict` endpoints under
 * `/projects/{projectId}/locations/{region}/publishers/anthropic/models/{modelId}`.
 */
var AnthropicMessagesLanguageModel = class {
  /**
   * @param modelId Model identifier appended to the endpoint path.
   * @param settings Model settings; `topK` is read as the default `top_k`.
   * @param config `{ provider, baseURL, headers, fetch, projectId, region,
   *   googleAuth }` as assembled by `createAnthropicVertex`.
   */
  constructor(modelId, settings, config) {
    this.specificationVersion = "v1";
    this.defaultObjectGenerationMode = "tool";
    this.supportsImageUrls = false;
    this.defaultVersion = "vertex-2023-10-16";
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
    this.path = `/projects/${config.projectId}/locations/${config.region}/publishers/anthropic/models/${modelId}`;
  }

  /** Provider name taken from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Builds the request body and the list of warnings for a call.
   *
   * `frequencyPenalty`, `presencePenalty`, `seed` and a non-text
   * `responseFormat` are not sent; each produces an "unsupported-setting"
   * warning instead.
   *
   * @returns `{ args, warnings }`.
   * @throws UnsupportedFunctionalityError for mode "object-json".
   * @throws Error for an unknown mode type.
   */
  async getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    topK,
    frequencyPenalty,
    presencePenalty,
    stopSequences,
    responseFormat,
    seed
  }) {
    const type = mode.type;
    const warnings = [];
    if (frequencyPenalty != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "frequencyPenalty"
      });
    }
    if (presencePenalty != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "presencePenalty"
      });
    }
    if (seed != null) {
      warnings.push({
        type: "unsupported-setting",
        setting: "seed"
      });
    }
    if (responseFormat != null && responseFormat.type !== "text") {
      warnings.push({
        type: "unsupported-setting",
        setting: "responseFormat",
        details: "JSON response format is not supported."
      });
    }
    const messagesPrompt = convertToAnthropicMessagesPrompt(prompt);
    const baseArgs = {
      anthropic_version: this.defaultVersion,
      // model specific settings:
      top_k: topK != null ? topK : this.settings.topK,
      // standardized settings:
      max_tokens: maxTokens != null ? maxTokens : 4096,
      // 4096: max model output tokens
      temperature,
      top_p: topP,
      stop_sequences: stopSequences,
      // prompt:
      system: messagesPrompt.system,
      messages: messagesPrompt.messages
    };
    switch (type) {
      case "regular": {
        return {
          args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
          warnings
        };
      }
      case "object-json": {
        throw new UnsupportedFunctionalityError2({
          functionality: "json-mode object generation"
        });
      }
      case "object-tool": {
        // Object generation forces a call to the single supplied tool.
        const { name, description, parameters } = mode.tool;
        return {
          args: {
            ...baseArgs,
            tools: [{ name, description, input_schema: parameters }],
            tool_choice: { type: "tool", name }
          },
          warnings
        };
      }
      default: {
        const _exhaustiveCheck = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }

  /**
   * Assembles the headers for one request without touching `options`.
   *
   * Precedence (lowest to highest): `config.headers()`, auth headers obtained
   * from `config.googleAuth`, then the per-call `options.headers`.
   *
   * The merged headers are deliberately NOT written back to
   * `options.headers`: doing so would let auth headers from an earlier call
   * override freshly fetched ones if the caller reuses the options object.
   *
   * @param options Call options; only `options.headers` is read.
   * @returns The combined header record passed to `postJsonToApi`.
   */
  async buildRequestHeaders(options) {
    var _a;
    const authClient = await ((_a = this.config.googleAuth) == null ? void 0 : _a.getClient());
    const authHeaders = await (authClient == null ? void 0 : authClient.getRequestHeaders());
    return combineHeaders(this.config.headers(), {
      ...authHeaders,
      ...options.headers
    });
  }

  /**
   * Performs a non-streaming call against the `:rawPredict` endpoint.
   *
   * @param options Call options (forwarded to `getArgs`; `headers` and
   *   `abortSignal` are also read here).
   * @returns `{ text, toolCalls, finishReason, usage, rawCall, rawResponse,
   *   warnings }`. `text` is the concatenation of all text parts;
   *   `toolCalls` is `undefined` when the response has no `tool_use` part.
   */
  async doGenerate(options) {
    const { args, warnings } = await this.getArgs(options);
    const headers = await this.buildRequestHeaders(options);
    const specifier = "rawPredict";
    const { responseHeaders, value: response } = await postJsonToApi({
      url: `${this.config.baseURL}${this.path}:${specifier}`,
      headers,
      body: args,
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        anthropicMessagesResponseSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages: rawPrompt, ...rawSettings } = args;
    let text = "";
    for (const content of response.content) {
      if (content.type === "text") {
        text += content.text;
      }
    }
    let toolCalls = void 0;
    if (response.content.some((content) => content.type === "tool_use")) {
      toolCalls = [];
      for (const content of response.content) {
        if (content.type === "tool_use") {
          toolCalls.push({
            toolCallType: "function",
            toolCallId: content.id,
            toolName: content.name,
            args: JSON.stringify(content.input)
          });
        }
      }
    }
    return {
      text,
      toolCalls,
      finishReason: mapAnthropicStopReason(response.stop_reason),
      usage: {
        promptTokens: response.usage.input_tokens,
        completionTokens: response.usage.output_tokens
      },
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      warnings
    };
  }

  /**
   * Performs a streaming call against the `:streamRawPredict` endpoint.
   *
   * The event stream is transformed into parts of type "text-delta",
   * "tool-call-delta", "tool-call", "finish" and "error".
   *
   * @param options Call options (forwarded to `getArgs`; `headers` and
   *   `abortSignal` are also read here).
   * @returns `{ stream, rawCall, rawResponse, warnings }`.
   */
  async doStream(options) {
    const { args, warnings } = await this.getArgs(options);
    const headers = await this.buildRequestHeaders(options);
    const specifier = "streamRawPredict";
    const { responseHeaders, value: response } = await postJsonToApi({
      url: `${this.config.baseURL}${this.path}:${specifier}`,
      headers,
      body: {
        ...args,
        stream: true
      },
      failedResponseHandler: anthropicFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        anthropicMessagesChunkSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages: rawPrompt, ...rawSettings } = args;
    // State shared across chunks of this one stream.
    let finishReason = "other";
    const usage = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN
    };
    // Open tool_use blocks keyed by content block index; each accumulates the
    // streamed JSON argument text until its content_block_stop arrives.
    const toolCallContentBlocks = {};
    return {
      stream: response.pipeThrough(
        new TransformStream({
          transform(chunk, controller) {
            if (!chunk.success) {
              controller.enqueue({ type: "error", error: chunk.error });
              return;
            }
            const value = chunk.value;
            switch (value.type) {
              case "ping": {
                return;
              }
              case "content_block_start": {
                const contentBlockType = value.content_block.type;
                switch (contentBlockType) {
                  case "text": {
                    return;
                  }
                  case "tool_use": {
                    toolCallContentBlocks[value.index] = {
                      toolCallId: value.content_block.id,
                      toolName: value.content_block.name,
                      jsonText: ""
                    };
                    return;
                  }
                  default: {
                    const _exhaustiveCheck = contentBlockType;
                    throw new Error(
                      `Unsupported content block type: ${_exhaustiveCheck}`
                    );
                  }
                }
              }
              case "content_block_stop": {
                // Emit the complete tool call once its block closes.
                if (toolCallContentBlocks[value.index] != null) {
                  const contentBlock = toolCallContentBlocks[value.index];
                  controller.enqueue({
                    type: "tool-call",
                    toolCallType: "function",
                    toolCallId: contentBlock.toolCallId,
                    toolName: contentBlock.toolName,
                    args: contentBlock.jsonText
                  });
                  delete toolCallContentBlocks[value.index];
                }
                return;
              }
              case "content_block_delta": {
                const deltaType = value.delta.type;
                switch (deltaType) {
                  case "text_delta": {
                    controller.enqueue({
                      type: "text-delta",
                      textDelta: value.delta.text
                    });
                    return;
                  }
                  case "input_json_delta": {
                    const contentBlock = toolCallContentBlocks[value.index];
                    controller.enqueue({
                      type: "tool-call-delta",
                      toolCallType: "function",
                      toolCallId: contentBlock.toolCallId,
                      toolName: contentBlock.toolName,
                      argsTextDelta: value.delta.partial_json
                    });
                    contentBlock.jsonText += value.delta.partial_json;
                    return;
                  }
                  default: {
                    const _exhaustiveCheck = deltaType;
                    throw new Error(
                      `Unsupported delta type: ${_exhaustiveCheck}`
                    );
                  }
                }
              }
              case "message_start": {
                usage.promptTokens = value.message.usage.input_tokens;
                usage.completionTokens = value.message.usage.output_tokens;
                return;
              }
              case "message_delta": {
                usage.completionTokens = value.usage.output_tokens;
                finishReason = mapAnthropicStopReason(value.delta.stop_reason);
                return;
              }
              case "message_stop": {
                controller.enqueue({ type: "finish", finishReason, usage });
                return;
              }
              default: {
                const _exhaustiveCheck = value;
                throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
              }
            }
          }
        })
      ),
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      warnings
    };
  }
};

/**
 * Schema of a non-streaming response: a "message" with `content` parts of
 * type "text" or "tool_use", an optional/nullable `stop_reason`, and token
 * `usage`.
 */
var anthropicMessagesResponseSchema = z2.object({
  type: z2.literal("message"),
  content: z2.array(
    z2.discriminatedUnion("type", [
      z2.object({
        type: z2.literal("text"),
        text: z2.string()
      }),
      z2.object({
        type: z2.literal("tool_use"),
        id: z2.string(),
        name: z2.string(),
        input: z2.unknown()
      })
    ])
  ),
  stop_reason: z2.string().optional().nullable(),
  usage: z2.object({
    input_tokens: z2.number(),
    output_tokens: z2.number()
  })
});

/**
 * Schema of one streamed event, discriminated on `type`: "message_start",
 * "content_block_start", "content_block_delta", "content_block_stop",
 * "message_delta", "message_stop" or "ping".
 */
var anthropicMessagesChunkSchema = z2.discriminatedUnion("type", [
  z2.object({
    type: z2.literal("message_start"),
    message: z2.object({
      usage: z2.object({
        input_tokens: z2.number(),
        output_tokens: z2.number()
      })
    })
  }),
  z2.object({
    type: z2.literal("content_block_start"),
    index: z2.number(),
    content_block: z2.discriminatedUnion("type", [
      z2.object({
        type: z2.literal("text"),
        text: z2.string()
      }),
      z2.object({
        type: z2.literal("tool_use"),
        id: z2.string(),
        name: z2.string()
      })
    ])
  }),
  z2.object({
    type: z2.literal("content_block_delta"),
    index: z2.number(),
    delta: z2.discriminatedUnion("type", [
      z2.object({
        type: z2.literal("input_json_delta"),
        partial_json: z2.string()
      }),
      z2.object({
        type: z2.literal("text_delta"),
        text: z2.string()
      })
    ])
  }),
  z2.object({
    type: z2.literal("content_block_stop"),
    index: z2.number()
  }),
  z2.object({
    type: z2.literal("message_delta"),
    delta: z2.object({ stop_reason: z2.string().optional().nullable() }),
    usage: z2.object({ output_tokens: z2.number() })
  }),
  z2.object({
    type: z2.literal("message_stop")
  }),
  z2.object({
    type: z2.literal("ping")
  })
]);

/**
 * Translates the tools and tool choice of a "regular" mode into the
 * `tools` / `tool_choice` request fields.
 *
 * @param mode Mode object; `mode.tools` and `mode.toolChoice` are read.
 * @returns `{ tools, tool_choice }`. Both are `undefined` when there are no
 *   tools or when the tool choice is "none"; "required" maps to
 *   `{ type: "any" }`.
 * @throws Error for an unknown tool choice type.
 */
function prepareToolsAndToolChoice(mode) {
  var _a;
  // An empty tools array is treated like no tools at all.
  const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
  if (tools == null) {
    return { tools: void 0, tool_choice: void 0 };
  }
  const mappedTools = tools.map((tool) => ({
    name: tool.name,
    description: tool.description,
    input_schema: tool.parameters
  }));
  const toolChoice = mode.toolChoice;
  if (toolChoice == null) {
    return { tools: mappedTools, tool_choice: void 0 };
  }
  const type = toolChoice.type;
  switch (type) {
    case "auto":
      return { tools: mappedTools, tool_choice: { type: "auto" } };
    case "required":
      return { tools: mappedTools, tool_choice: { type: "any" } };
    case "none":
      // "none" is expressed by sending no tools.
      return { tools: void 0, tool_choice: void 0 };
    case "tool":
      return {
        tools: mappedTools,
        tool_choice: { type: "tool", name: toolChoice.toolName }
      };
    default: {
      const _exhaustiveCheck = type;
      throw new Error(`Unsupported tool choice type: ${_exhaustiveCheck}`);
    }
  }
}

// src/anthropic-vertex-provider.ts
/**
 * Creates a provider function for Anthropic models on Vertex AI.
 *
 * The project id and region are resolved through `loadSetting` each time a
 * model is created (not when the provider is created), from
 * `options.projectId` / `options.region` or the `GOOGLE_VERTEX_PROJECT_ID` /
 * `GOOGLE_VERTEX_REGION` environment variable names handed to `loadSetting`.
 *
 * @param options Optional `{ projectId, region, googleAuth, baseURL, headers,
 *   fetch }`. When `baseURL` is absent,
 *   `https://{region}-aiplatform.googleapis.com/v1` is used; when
 *   `googleAuth` is absent, a `GoogleAuth` with the cloud-platform scope is
 *   created per model.
 * @returns A function `(modelId, settings) => model`, also exposed as
 *   `.languageModel` and `.chat`. It throws when invoked with `new`.
 */
function createAnthropicVertex(options = {}) {
  const getConfig = () => {
    const config = {
      projectId: loadSetting({
        settingValue: options.projectId,
        settingName: "projectId",
        environmentVariableName: "GOOGLE_VERTEX_PROJECT_ID",
        description: "Google Vertex project id"
      }),
      region: loadSetting({
        settingValue: options.region,
        settingName: "region",
        environmentVariableName: "GOOGLE_VERTEX_REGION",
        description: "Google Vertex region"
      }),
      googleAuth: options.googleAuth
    };
    // Reject falsy values (e.g. an empty string) returned by loadSetting.
    if (!config.region) {
      throw new Error(
        "No region was given. The client should be instantiated with the `region` option or the `GOOGLE_VERTEX_REGION` environment variable should be set."
      );
    }
    if (!config.projectId) {
      // The option is spelled `projectId` (see `options.projectId` above).
      throw new Error(
        "No project was given. The client should be instantiated with the `projectId` option or the `GOOGLE_VERTEX_PROJECT_ID` environment variable should be set."
      );
    }
    return config;
  };
  const createChatModel = (modelId, settings = {}) => {
    var _a, _b;
    const config = getConfig();
    const baseURL = (_a = withoutTrailingSlash(options.baseURL)) != null ? _a : `https://${config.region}-aiplatform.googleapis.com/v1`;
    const auth = (_b = options.googleAuth) != null ? _b : new GoogleAuth({
      scopes: "https://www.googleapis.com/auth/cloud-platform"
    });
    return new AnthropicMessagesLanguageModel(modelId, settings, {
      provider: "anthropic.messages",
      baseURL,
      headers: () => ({
        ...options.headers
      }),
      fetch: options.fetch,
      projectId: config.projectId,
      region: config.region,
      googleAuth: auth
    });
  };
  const provider = function(modelId, settings) {
    if (new.target) {
      throw new Error(
        "The Anthropic model function cannot be called with the new keyword."
      );
    }
    return createChatModel(modelId, settings);
  };
  provider.languageModel = createChatModel;
  provider.chat = createChatModel;
  return provider;
}

/** Default provider instance created with no explicit options. */
var anthropicVertex = createAnthropicVertex();
export {
  anthropicVertex,
  createAnthropicVertex
};
//# sourceMappingURL=index.mjs.map