anthropic-vertex-ai
Version:
[nalaso/anthropic-vertex-ai](https://github.com/nalaso/anthropic-vertex-ai) is a community provider that uses Anthropic models through Vertex AI to provide language model support for the Vercel AI SDK.
670 lines (664 loc) • 21.3 kB
JavaScript
// src/anthropic-vertex-provider.ts
import { loadSetting, withoutTrailingSlash } from "@ai-sdk/provider-utils";
import { GoogleAuth } from "google-auth-library";
// src/anthropic-messages-language-model.ts
import {
UnsupportedFunctionalityError as UnsupportedFunctionalityError2
} from "@ai-sdk/provider";
import {
combineHeaders,
createEventSourceResponseHandler,
createJsonResponseHandler,
postJsonToApi
} from "@ai-sdk/provider-utils";
import { z as z2 } from "zod";
// src/anthropic-error.ts
import { createJsonErrorResponseHandler } from "@ai-sdk/provider-utils";
import { z } from "zod";
/**
 * Zod schema for the JSON body of a failed Anthropic request:
 * `{ type: "error", error: { type, message } }`.
 */
var anthropicErrorDataSchema = z.object({
  type: z.literal("error"),
  error: z.object({ type: z.string(), message: z.string() })
});
/**
 * Failed-response handler for `postJsonToApi`: parses the body with
 * `anthropicErrorDataSchema` and surfaces `error.message` as the error text.
 */
var anthropicFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: anthropicErrorDataSchema,
  errorToMessage: ({ error }) => error.message
});
// src/convert-to-anthropic-messages-prompt.ts
import {
UnsupportedFunctionalityError
} from "@ai-sdk/provider";
import { convertUint8ArrayToBase64 } from "@ai-sdk/provider-utils";
function convertToAnthropicMessagesPrompt(prompt) {
var _a;
const blocks = groupIntoBlocks(prompt);
let system = void 0;
const messages = [];
for (let i = 0; i < blocks.length; i++) {
const block = blocks[i];
const type = block.type;
switch (type) {
case "system": {
if (system != null) {
throw new UnsupportedFunctionalityError({
functionality: "Multiple system messages that are separated by user/assistant messages"
});
}
system = block.messages.map(({ content }) => content).join("\n");
break;
}
case "user": {
const anthropicContent = [];
for (const { role, content } of block.messages) {
switch (role) {
case "user": {
for (const part of content) {
switch (part.type) {
case "text": {
anthropicContent.push({ type: "text", text: part.text });
break;
}
case "image": {
if (part.image instanceof URL) {
throw new UnsupportedFunctionalityError({
functionality: "Image URLs in user messages"
});
}
anthropicContent.push({
type: "image",
source: {
type: "base64",
media_type: (_a = part.mimeType) != null ? _a : "image/jpeg",
data: convertUint8ArrayToBase64(part.image)
}
});
break;
}
}
}
break;
}
case "tool": {
for (const part of content) {
anthropicContent.push({
type: "tool_result",
tool_use_id: part.toolCallId,
content: JSON.stringify(part.result),
is_error: part.isError
});
}
break;
}
default: {
const _exhaustiveCheck = role;
throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
}
}
}
messages.push({ role: "user", content: anthropicContent });
break;
}
case "assistant": {
if (block.messages.length > 1) {
throw new UnsupportedFunctionalityError({
functionality: "Multiple assistant messages in block"
});
}
const { content } = block.messages[0];
messages.push({
role: "assistant",
content: content.map((part, j) => {
switch (part.type) {
case "text": {
if (i === blocks.length - 1 && j === block.messages.length - 1) {
return { type: "text", text: part.text.trim() };
}
return { type: "text", text: part.text };
}
case "tool-call": {
return {
type: "tool_use",
id: part.toolCallId,
name: part.toolName,
input: part.args
};
}
}
})
});
break;
}
default: {
const _exhaustiveCheck = type;
throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
}
}
}
return {
system,
messages
};
}
function groupIntoBlocks(prompt) {
const blocks = [];
let currentBlock = void 0;
for (const { role, content } of prompt) {
switch (role) {
case "system": {
if ((currentBlock == null ? void 0 : currentBlock.type) !== "system") {
currentBlock = { type: "system", messages: [] };
blocks.push(currentBlock);
}
currentBlock.messages.push({ role, content });
break;
}
case "assistant": {
if ((currentBlock == null ? void 0 : currentBlock.type) !== "assistant") {
currentBlock = { type: "assistant", messages: [] };
blocks.push(currentBlock);
}
currentBlock.messages.push({ role, content });
break;
}
case "user": {
if ((currentBlock == null ? void 0 : currentBlock.type) !== "user") {
currentBlock = { type: "user", messages: [] };
blocks.push(currentBlock);
}
currentBlock.messages.push({ role, content });
break;
}
case "tool": {
if ((currentBlock == null ? void 0 : currentBlock.type) !== "user") {
currentBlock = { type: "user", messages: [] };
blocks.push(currentBlock);
}
currentBlock.messages.push({ role, content });
break;
}
default: {
const _exhaustiveCheck = role;
throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
}
}
}
return blocks;
}
// src/map-anthropic-stop-reason.ts
/**
 * Translates an Anthropic `stop_reason` into an AI SDK finish reason.
 *
 * @param finishReason - Anthropic stop reason (may be null/undefined).
 * @returns "stop" for end_turn/stop_sequence, "tool-calls" for tool_use,
 *   "length" for max_tokens, and "other" for anything else.
 */
function mapAnthropicStopReason(finishReason) {
  if (finishReason === "end_turn" || finishReason === "stop_sequence") {
    return "stop";
  }
  if (finishReason === "tool_use") {
    return "tool-calls";
  }
  if (finishReason === "max_tokens") {
    return "length";
  }
  return "other";
}
// src/anthropic-messages-language-model.ts
var AnthropicMessagesLanguageModel = class {
constructor(modelId, settings, config) {
this.specificationVersion = "v1";
this.defaultObjectGenerationMode = "tool";
this.supportsImageUrls = false;
this.defaultVersion = "vertex-2023-10-16";
this.modelId = modelId;
this.settings = settings;
this.config = config;
this.path = `/projects/${config.projectId}/locations/${config.region}/publishers/anthropic/models/${modelId}`;
}
get provider() {
return this.config.provider;
}
async getArgs({
mode,
prompt,
maxTokens,
temperature,
topP,
topK,
frequencyPenalty,
presencePenalty,
stopSequences,
responseFormat,
seed
}) {
const type = mode.type;
const warnings = [];
if (frequencyPenalty != null) {
warnings.push({
type: "unsupported-setting",
setting: "frequencyPenalty"
});
}
if (presencePenalty != null) {
warnings.push({
type: "unsupported-setting",
setting: "presencePenalty"
});
}
if (seed != null) {
warnings.push({
type: "unsupported-setting",
setting: "seed"
});
}
if (responseFormat != null && responseFormat.type !== "text") {
warnings.push({
type: "unsupported-setting",
setting: "responseFormat",
details: "JSON response format is not supported."
});
}
const messagesPrompt = convertToAnthropicMessagesPrompt(prompt);
const baseArgs = {
anthropic_version: this.defaultVersion,
// model specific settings:
top_k: topK != null ? topK : this.settings.topK,
// standardized settings:
max_tokens: maxTokens != null ? maxTokens : 4096,
// 4096: max model output tokens
temperature,
top_p: topP,
stop_sequences: stopSequences,
// prompt:
system: messagesPrompt.system,
messages: messagesPrompt.messages
};
switch (type) {
case "regular": {
return {
args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
warnings
};
}
case "object-json": {
throw new UnsupportedFunctionalityError2({
functionality: "json-mode object generation"
});
}
case "object-tool": {
const { name, description, parameters } = mode.tool;
return {
args: {
...baseArgs,
tools: [{ name, description, input_schema: parameters }],
tool_choice: { type: "tool", name }
},
warnings
};
}
default: {
const _exhaustiveCheck = type;
throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
}
}
}
async doGenerate(options) {
var _a;
const { args, warnings } = await this.getArgs(options);
const authClient = await ((_a = this.config.googleAuth) == null ? void 0 : _a.getClient());
const authHeaders = await (authClient == null ? void 0 : authClient.getRequestHeaders());
options.headers = { ...authHeaders, ...options.headers };
const specifier = "rawPredict";
const { responseHeaders, value: response } = await postJsonToApi({
url: `${this.config.baseURL}${this.path}:${specifier}`,
headers: combineHeaders(this.config.headers(), options.headers),
body: args,
failedResponseHandler: anthropicFailedResponseHandler,
successfulResponseHandler: createJsonResponseHandler(
anthropicMessagesResponseSchema
),
abortSignal: options.abortSignal,
fetch: this.config.fetch
});
const { messages: rawPrompt, ...rawSettings } = args;
let text = "";
for (const content of response.content) {
if (content.type === "text") {
text += content.text;
}
}
let toolCalls = void 0;
if (response.content.some((content) => content.type === "tool_use")) {
toolCalls = [];
for (const content of response.content) {
if (content.type === "tool_use") {
toolCalls.push({
toolCallType: "function",
toolCallId: content.id,
toolName: content.name,
args: JSON.stringify(content.input)
});
}
}
}
return {
text,
toolCalls,
finishReason: mapAnthropicStopReason(response.stop_reason),
usage: {
promptTokens: response.usage.input_tokens,
completionTokens: response.usage.output_tokens
},
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings
};
}
async doStream(options) {
var _a;
const { args, warnings } = await this.getArgs(options);
const authClient = await ((_a = this.config.googleAuth) == null ? void 0 : _a.getClient());
const authHeaders = await (authClient == null ? void 0 : authClient.getRequestHeaders());
options.headers = { ...authHeaders, ...options.headers };
const specifier = "streamRawPredict";
const { responseHeaders, value: response } = await postJsonToApi({
url: `${this.config.baseURL}${this.path}:${specifier}`,
headers: combineHeaders(this.config.headers(), options.headers),
body: {
...args,
stream: true
},
failedResponseHandler: anthropicFailedResponseHandler,
successfulResponseHandler: createEventSourceResponseHandler(
anthropicMessagesChunkSchema
),
abortSignal: options.abortSignal,
fetch: this.config.fetch
});
const { messages: rawPrompt, ...rawSettings } = args;
let finishReason = "other";
const usage = {
promptTokens: Number.NaN,
completionTokens: Number.NaN
};
const toolCallContentBlocks = {};
return {
stream: response.pipeThrough(
new TransformStream({
transform(chunk, controller) {
if (!chunk.success) {
controller.enqueue({ type: "error", error: chunk.error });
return;
}
const value = chunk.value;
switch (value.type) {
case "ping": {
return;
}
case "content_block_start": {
const contentBlockType = value.content_block.type;
switch (contentBlockType) {
case "text": {
return;
}
case "tool_use": {
toolCallContentBlocks[value.index] = {
toolCallId: value.content_block.id,
toolName: value.content_block.name,
jsonText: ""
};
return;
}
default: {
const _exhaustiveCheck = contentBlockType;
throw new Error(
`Unsupported content block type: ${_exhaustiveCheck}`
);
}
}
}
case "content_block_stop": {
if (toolCallContentBlocks[value.index] != null) {
const contentBlock = toolCallContentBlocks[value.index];
controller.enqueue({
type: "tool-call",
toolCallType: "function",
toolCallId: contentBlock.toolCallId,
toolName: contentBlock.toolName,
args: contentBlock.jsonText
});
delete toolCallContentBlocks[value.index];
}
return;
}
case "content_block_delta": {
const deltaType = value.delta.type;
switch (deltaType) {
case "text_delta": {
controller.enqueue({
type: "text-delta",
textDelta: value.delta.text
});
return;
}
case "input_json_delta": {
const contentBlock = toolCallContentBlocks[value.index];
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: contentBlock.toolCallId,
toolName: contentBlock.toolName,
argsTextDelta: value.delta.partial_json
});
contentBlock.jsonText += value.delta.partial_json;
return;
}
default: {
const _exhaustiveCheck = deltaType;
throw new Error(
`Unsupported delta type: ${_exhaustiveCheck}`
);
}
}
}
case "message_start": {
usage.promptTokens = value.message.usage.input_tokens;
usage.completionTokens = value.message.usage.output_tokens;
return;
}
case "message_delta": {
usage.completionTokens = value.usage.output_tokens;
finishReason = mapAnthropicStopReason(value.delta.stop_reason);
return;
}
case "message_stop": {
controller.enqueue({ type: "finish", finishReason, usage });
return;
}
default: {
const _exhaustiveCheck = value;
throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
}
}
}
})
),
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings
};
}
};
/**
 * Zod schema for a non-streaming Anthropic Messages response (`rawPredict`).
 *
 * Declares the `type: "message"` tag, the content parts (text or tool_use),
 * `stop_reason`, and token `usage`. zod's default object parsing strips
 * undeclared keys.
 */
var anthropicMessagesResponseSchema = (() => {
  const textPart = z2.object({
    type: z2.literal("text"),
    text: z2.string()
  });
  const toolUsePart = z2.object({
    type: z2.literal("tool_use"),
    id: z2.string(),
    name: z2.string(),
    input: z2.unknown()
  });
  const tokenUsage = z2.object({
    input_tokens: z2.number(),
    output_tokens: z2.number()
  });
  return z2.object({
    type: z2.literal("message"),
    content: z2.array(z2.discriminatedUnion("type", [textPart, toolUsePart])),
    stop_reason: z2.string().optional().nullable(),
    usage: tokenUsage
  });
})();
/**
 * Zod schema for one server-sent event of a streaming Anthropic Messages
 * response (`streamRawPredict`), discriminated on `type`.
 *
 * Event types: message_start, content_block_start, content_block_delta,
 * content_block_stop, message_delta, message_stop, ping. Any other type fails
 * validation.
 */
var anthropicMessagesChunkSchema = (() => {
  const messageStart = z2.object({
    type: z2.literal("message_start"),
    message: z2.object({
      usage: z2.object({
        input_tokens: z2.number(),
        output_tokens: z2.number()
      })
    })
  });
  const contentBlockStart = z2.object({
    type: z2.literal("content_block_start"),
    index: z2.number(),
    content_block: z2.discriminatedUnion("type", [
      z2.object({ type: z2.literal("text"), text: z2.string() }),
      z2.object({ type: z2.literal("tool_use"), id: z2.string(), name: z2.string() })
    ])
  });
  const contentBlockDelta = z2.object({
    type: z2.literal("content_block_delta"),
    index: z2.number(),
    delta: z2.discriminatedUnion("type", [
      z2.object({ type: z2.literal("input_json_delta"), partial_json: z2.string() }),
      z2.object({ type: z2.literal("text_delta"), text: z2.string() })
    ])
  });
  const contentBlockStop = z2.object({
    type: z2.literal("content_block_stop"),
    index: z2.number()
  });
  const messageDelta = z2.object({
    type: z2.literal("message_delta"),
    delta: z2.object({ stop_reason: z2.string().optional().nullable() }),
    usage: z2.object({ output_tokens: z2.number() })
  });
  const messageStop = z2.object({ type: z2.literal("message_stop") });
  const ping = z2.object({ type: z2.literal("ping") });
  return z2.discriminatedUnion("type", [
    messageStart,
    contentBlockStart,
    contentBlockDelta,
    contentBlockStop,
    messageDelta,
    messageStop,
    ping
  ]);
})();
/**
 * Converts the AI SDK "regular" mode tools and tool choice into Anthropic's
 * `tools` / `tool_choice` request fields.
 *
 * No tools (missing or empty list) and toolChoice "none" both yield undefined
 * for both fields. "required" maps to Anthropic's `{ type: "any" }`.
 *
 * @param mode - AI SDK regular-mode object with optional `tools` and `toolChoice`.
 * @returns `{ tools, tool_choice }`.
 * @throws Error for an unknown tool choice type.
 */
function prepareToolsAndToolChoice(mode) {
  const requested = mode.tools;
  if (requested == null || !requested.length) {
    return { tools: void 0, tool_choice: void 0 };
  }
  const anthropicTools = requested.map(({ name, description, parameters }) => ({
    name,
    description,
    input_schema: parameters
  }));
  const { toolChoice } = mode;
  if (toolChoice == null) {
    return { tools: anthropicTools, tool_choice: void 0 };
  }
  const choiceType = toolChoice.type;
  if (choiceType === "none") {
    return { tools: void 0, tool_choice: void 0 };
  }
  let anthropicChoice;
  if (choiceType === "auto") {
    anthropicChoice = { type: "auto" };
  } else if (choiceType === "required") {
    anthropicChoice = { type: "any" };
  } else if (choiceType === "tool") {
    anthropicChoice = { type: "tool", name: toolChoice.toolName };
  } else {
    throw new Error(`Unsupported tool choice type: ${choiceType}`);
  }
  return { tools: anthropicTools, tool_choice: anthropicChoice };
}
// src/anthropic-vertex-provider.ts
function createAnthropicVertex(options = {}) {
const getConfig = () => {
const config = {
projectId: loadSetting({
settingValue: options.projectId,
settingName: "projectId",
environmentVariableName: "GOOGLE_VERTEX_PROJECT_ID",
description: "Google Vertex project id"
}),
region: loadSetting({
settingValue: options.region,
settingName: "region",
environmentVariableName: "GOOGLE_VERTEX_REGION",
description: "Google Vertex region"
}),
googleAuth: options.googleAuth
};
if (!config.region) {
throw new Error(
"No region was given. The client should be instantiated with the `region` option or the `GOOGLE_VERTEX_REGION` environment variable should be set."
);
}
if (!config.projectId) {
throw new Error(
"No project was given. The client should be instantiated with the `projectID` option or the `GOOGLE_VERTEX_PROJECT_ID` environment variable should be set."
);
}
return config;
};
const createChatModel = (modelId, settings = {}) => {
var _a, _b;
const config = getConfig();
const baseURL = (_a = withoutTrailingSlash(options.baseURL)) != null ? _a : `https://${config.region}-aiplatform.googleapis.com/v1`;
const auth = (_b = options.googleAuth) != null ? _b : new GoogleAuth({ scopes: "https://www.googleapis.com/auth/cloud-platform" });
return new AnthropicMessagesLanguageModel(modelId, settings, {
provider: "anthropic.messages",
baseURL,
headers: () => ({
...options.headers
}),
fetch: options.fetch,
projectId: config.projectId,
region: config.region,
googleAuth: auth
});
};
const provider = function(modelId, settings) {
if (new.target) {
throw new Error(
"The Anthropic model function cannot be called with the new keyword."
);
}
return createChatModel(modelId, settings);
};
provider.languageModel = createChatModel;
provider.chat = createChatModel;
return provider;
}
var anthropicVertex = createAnthropicVertex();
export {
anthropicVertex,
createAnthropicVertex
};
//# sourceMappingURL=index.mjs.map