UNPKG

@friendliai/ai-provider

Version:

Learn how to use the FriendliAI provider for the Vercel AI SDK.

705 lines (699 loc) 23.4 kB
// src/friendli-provider.ts
import { NoSuchModelError } from "@ai-sdk/provider";
import { loadApiKey, withoutTrailingSlash } from "@ai-sdk/provider-utils";
import { OpenAICompatibleCompletionLanguageModel } from "@ai-sdk/openai-compatible";

// src/friendli-settings.ts
/**
 * Model ids that `createFriendli` routes to the serverless endpoints when the
 * endpoint setting is "auto".
 */
var FriendliAIServerlessModelIds = [
  "meta-llama-3.1-8b-instruct",
  "meta-llama-3.1-70b-instruct",
  "meta-llama-3.3-70b-instruct",
  "deepseek-r1"
];

// src/friendli-chat-language-model.ts
import {
  InvalidResponseDataError,
  UnsupportedFunctionalityError as UnsupportedFunctionalityError2
} from "@ai-sdk/provider";
import {
  combineHeaders,
  createEventSourceResponseHandler,
  createJsonErrorResponseHandler as createJsonErrorResponseHandler2,
  createJsonResponseHandler,
  generateId,
  isParsableJson,
  postJsonToApi
} from "@ai-sdk/provider-utils";
import {
  convertToOpenAICompatibleChatMessages,
  getResponseMetadata,
  mapOpenAICompatibleFinishReason
} from "@ai-sdk/openai-compatible/internal";
import { z as z2 } from "zod";

// src/friendli-error.ts
import { z } from "zod";
import { createJsonErrorResponseHandler } from "@ai-sdk/provider-utils";

/** Shape of an error payload: an object carrying a string `message`. */
var friendliaiErrorSchema = z.object({
  message: z.string()
});

/** Error schema plus the function that extracts the message from a parsed error. */
var friendliaiErrorStructure = {
  errorSchema: friendliaiErrorSchema,
  errorToMessage: (data) => data.message
};

/** Failed-response handler built from `friendliaiErrorStructure`. */
var friendliaiFailedResponseHandler = createJsonErrorResponseHandler(
  friendliaiErrorStructure
);

// src/friendli-prepare-tools.ts
import { UnsupportedFunctionalityError } from "@ai-sdk/provider";

/**
 * Builds the `tools` and `tool_choice` request fields.
 *
 * @param mode  Call mode; `mode.tools` and `mode.toolChoice` are read.
 * @param tools Hosted tools from the model settings; only each tool's `type`
 *              is forwarded.
 * @returns `{ tools, tool_choice, toolWarnings }`. `tools` is `undefined` only
 *          when neither function tools nor hosted tools were supplied.
 *          Provider-defined tools are not forwarded; each one produces an
 *          "unsupported-tool" warning instead.
 * @throws UnsupportedFunctionalityError for an unrecognised tool choice type.
 */
function prepareTools({
  mode,
  tools: hostedTools
}) {
  var _a;
  // An empty `mode.tools` array is treated the same as no tools at all.
  const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
  const toolWarnings = [];
  if (tools == null && hostedTools == null) {
    return { tools: void 0, tool_choice: void 0, toolWarnings };
  }
  const toolChoice = mode.toolChoice;
  const mappedTools = [];
  if (tools) {
    for (const tool of tools) {
      if (tool.type === "provider-defined") {
        toolWarnings.push({ type: "unsupported-tool", tool });
      } else {
        mappedTools.push({
          type: "function",
          function: {
            name: tool.name,
            description: tool.description,
            parameters: tool.parameters
          }
        });
      }
    }
  }
  const mappedHostedTools = hostedTools == null ? void 0 : hostedTools.map((tool) => {
    return {
      type: tool.type
    };
  });
  // Function tools first, then hosted tools; shared by every branch below.
  const requestTools = [
    ...mappedTools,
    ...mappedHostedTools != null ? mappedHostedTools : []
  ];
  if (toolChoice == null) {
    return {
      tools: requestTools,
      tool_choice: void 0,
      toolWarnings
    };
  }
  const type = toolChoice.type;
  switch (type) {
    case "auto":
    case "none":
    case "required":
      return {
        tools: requestTools,
        tool_choice: type,
        toolWarnings
      };
    case "tool":
      return {
        tools: requestTools,
        tool_choice: {
          type: "function",
          function: {
            name: toolChoice.toolName
          }
        },
        toolWarnings
      };
    default: {
      const _exhaustiveCheck = type;
      throw new UnsupportedFunctionalityError({
        functionality: `Unsupported tool choice type: ${_exhaustiveCheck}`
      });
    }
  }
}

// src/friendli-chat-language-model.ts
/**
 * Chat language model (specification version "v1") that posts to the
 * `/chat/completions` path of the URL produced by `config.url`.
 */
var FriendliAIChatLanguageModel = class {
  /**
   * @param modelId  Model id sent as `model` in every request.
   * @param settings Model settings; `user`, `parallelToolCalls`, `regex` and
   *                 `tools` are read by this class.
   * @param config   Provider wiring: `provider`, `url`, `headers`, `fetch`,
   *                 and optionally `supportsStructuredOutputs` (defaults to
   *                 true) and `defaultObjectGenerationMode`.
   */
  constructor(modelId, settings, config) {
    this.specificationVersion = "v1";
    var _a;
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
    this.failedResponseHandler = createJsonErrorResponseHandler2(
      friendliaiErrorStructure
    );
    this.supportsStructuredOutputs = (_a = config.supportsStructuredOutputs) != null ? _a : true;
  }

  /** Configured object generation mode, falling back to "json". */
  get defaultObjectGenerationMode() {
    var _a;
    return (_a = this.config.defaultObjectGenerationMode) != null ? _a : "json";
  }

  /** Provider name taken from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Translates standardized call options into the request arguments.
   *
   * @returns `{ args, warnings }` where `args` is the request payload without
   *          the `stream` flag.
   * @throws UnsupportedFunctionalityError when `settings.regex` is set and the
   *         mode is not "regular", or when regex and tools are combined.
   */
  getArgs({
    mode,
    prompt,
    maxTokens,
    temperature,
    topP,
    topK,
    frequencyPenalty,
    presencePenalty,
    stopSequences,
    responseFormat,
    seed
  }) {
    const type = mode.type;
    const warnings = [];
    if ((responseFormat == null ? void 0 : responseFormat.type) === "json" && responseFormat.schema != null && !this.supportsStructuredOutputs) {
      warnings.push({
        type: "unsupported-setting",
        setting: "responseFormat",
        details: "JSON response format schema is only supported with structuredOutputs"
      });
    }
    const baseArgs = {
      // model id:
      model: this.modelId,
      // model specific settings:
      user: this.settings.user,
      parallel_tool_calls: this.settings.parallelToolCalls,
      // standardized settings:
      max_tokens: maxTokens,
      temperature,
      top_p: topP,
      top_k: topK,
      frequency_penalty: frequencyPenalty,
      presence_penalty: presencePenalty,
      // A JSON schema is only forwarded when structured outputs are supported;
      // otherwise a JSON request degrades to plain "json_object".
      response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? this.supportsStructuredOutputs === true && responseFormat.schema != null ? {
        type: "json_schema",
        json_schema: {
          schema: responseFormat.schema,
          description: responseFormat.description
        }
      } : { type: "json_object" } : void 0,
      stop: stopSequences,
      seed,
      // messages:
      messages: convertToOpenAICompatibleChatMessages(prompt)
    };
    if (this.settings.regex != null && type !== "regular") {
      throw new UnsupportedFunctionalityError2({
        // Fixed: the message previously started with "egular" (missing "R").
        functionality: "Regular expression is only supported with regular mode (generateText, streamText)"
      });
    }
    switch (type) {
      case "regular": {
        if (this.settings.regex != null) {
          if (this.settings.tools != null || mode.tools != null) {
            throw new UnsupportedFunctionalityError2({
              functionality: "Regular expression and tools cannot be used together. Use either regular expression or tools."
            });
          }
          // The regex replaces whatever response_format baseArgs computed.
          return {
            args: {
              ...baseArgs,
              response_format: {
                type: "regex",
                schema: this.settings.regex.source
              }
            },
            warnings
          };
        }
        const { tools, tool_choice, toolWarnings } = prepareTools({
          mode,
          tools: this.settings.tools
        });
        return {
          args: { ...baseArgs, tools, tool_choice },
          warnings: [...warnings, ...toolWarnings]
        };
      }
      case "object-json": {
        return {
          args: {
            ...baseArgs,
            response_format: this.supportsStructuredOutputs === true && mode.schema != null ? {
              type: "json_schema",
              json_schema: {
                schema: mode.schema,
                description: mode.description
              }
            } : { type: "json_object" }
          },
          warnings
        };
      }
      case "object-tool": {
        // Force the model to call the single tool that describes the object.
        return {
          args: {
            ...baseArgs,
            tool_choice: {
              type: "function",
              function: { name: mode.tool.name }
            },
            tools: [
              {
                type: "function",
                function: {
                  name: mode.tool.name,
                  description: mode.tool.description,
                  parameters: mode.tool.parameters
                }
              }
            ]
          },
          warnings
        };
      }
      default: {
        const _exhaustiveCheck = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }

  /**
   * Performs a non-streaming chat completion.
   *
   * @param options Call options; `headers` and `abortSignal` are forwarded to
   *                the request, the rest goes through `getArgs`.
   * @returns Text, tool calls, finish reason, usage (NaN when the response
   *          omits a count), raw call/response data, warnings, and the
   *          serialized request body.
   */
  async doGenerate(options) {
    var _a, _b, _c, _d, _e, _f;
    const { args, warnings } = this.getArgs({ ...options });
    // Build the payload once so the reported `request.body` is exactly what is sent.
    const requestBody = { ...args, stream: false };
    const body = JSON.stringify(requestBody);
    const { responseHeaders, value: response } = await postJsonToApi({
      url: this.config.url({
        path: "/chat/completions",
        modelId: this.modelId
      }),
      headers: combineHeaders(this.config.headers(), options.headers),
      body: requestBody,
      failedResponseHandler: this.failedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        friendliAIChatResponseSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages: rawPrompt, ...rawSettings } = args;
    const choice = response.choices[0];
    return {
      text: (_a = choice.message.content) != null ? _a : void 0,
      toolCalls: (_b = choice.message.tool_calls) == null ? void 0 : _b.map((toolCall) => {
        var _a2;
        return {
          toolCallType: "function",
          toolCallId: (_a2 = toolCall.id) != null ? _a2 : generateId(),
          toolName: toolCall.function.name,
          // Arguments may arrive as a string or as an already-parsed value.
          args: typeof toolCall.function.arguments === "string" ? toolCall.function.arguments : JSON.stringify(toolCall.function.arguments)
        };
      }),
      finishReason: mapOpenAICompatibleFinishReason(choice.finish_reason),
      usage: {
        promptTokens: (_d = (_c = response.usage) == null ? void 0 : _c.prompt_tokens) != null ? _d : NaN,
        completionTokens: (_f = (_e = response.usage) == null ? void 0 : _e.completion_tokens) != null ? _f : NaN
      },
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      response: getResponseMetadata(response),
      warnings,
      request: { body }
    };
  }

  /**
   * Performs a streaming chat completion.
   *
   * @param options Call options; `headers` and `abortSignal` are forwarded to
   *                the request, the rest goes through `getArgs`.
   * @returns `stream` of parts ("response-metadata", "text-delta",
   *          "tool-call-delta", "tool-call", "error", and a final "finish"),
   *          plus raw call/response data, warnings, and the serialized
   *          request body.
   */
  async doStream(options) {
    const { args, warnings } = this.getArgs({ ...options });
    // Build the payload once. Previously the reported `request.body` omitted
    // `stream_options` even though it was sent.
    const requestBody = {
      ...args,
      stream: true,
      stream_options: {
        include_usage: true
      }
    };
    const body = JSON.stringify(requestBody);
    const { responseHeaders, value: response } = await postJsonToApi({
      url: this.config.url({
        path: "/chat/completions",
        modelId: this.modelId
      }),
      headers: combineHeaders(this.config.headers(), options.headers),
      body: requestBody,
      // Same handler instance as doGenerate, for consistency.
      failedResponseHandler: this.failedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        friendliaiChatChunkSchema
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch
    });
    const { messages: rawPrompt, ...rawSettings } = args;
    // Tool calls being assembled, indexed by the `index` field of their deltas.
    const toolCalls = [];
    let finishReason = "unknown";
    let usage = {
      promptTokens: void 0,
      completionTokens: void 0
    };
    let isFirstChunk = true;
    let providerMetadata;
    return {
      stream: response.pipeThrough(
        new TransformStream({
          transform(chunk, controller) {
            var _a, _b, _c, _d, _e, _f, _g, _h, _i, _j, _k, _l, _m, _n;
            // Chunk failed parsing/validation: surface the error and stop.
            if (!chunk.success) {
              finishReason = "error";
              controller.enqueue({ type: "error", error: chunk.error });
              return;
            }
            const value = chunk.value;
            // Hosted tool status event (second member of the chunk schema union).
            if ("status" in value) {
              switch (value.status) {
                case "STARTED":
                  break;
                case "UPDATING":
                  break;
                case "ENDED":
                  break;
                case "ERRORED":
                  finishReason = "error";
                  break;
                default:
                  finishReason = "error";
                  controller.enqueue({
                    type: "error",
                    error: new Error(
                      `Unsupported tool call status: ${value.status}`
                    )
                  });
              }
              return;
            }
            // Error payload (third member of the chunk schema union).
            if ("message" in value) {
              console.error("Error chunk:", value);
              finishReason = "error";
              controller.enqueue({ type: "error", error: value.message });
              return;
            }
            if (isFirstChunk) {
              isFirstChunk = false;
              controller.enqueue({
                type: "response-metadata",
                ...getResponseMetadata(value)
              });
            }
            if (value.usage != null) {
              usage = {
                promptTokens: (_a = value.usage.prompt_tokens) != null ? _a : void 0,
                completionTokens: (_b = value.usage.completion_tokens) != null ? _b : void 0
              };
            }
            const choice = value.choices[0];
            if ((choice == null ? void 0 : choice.finish_reason) != null) {
              finishReason = mapOpenAICompatibleFinishReason(
                choice.finish_reason
              );
            }
            if ((choice == null ? void 0 : choice.delta) == null) {
              return;
            }
            const delta = choice.delta;
            if (delta.content != null) {
              controller.enqueue({
                type: "text-delta",
                textDelta: delta.content
              });
            }
            if (delta.tool_calls != null) {
              for (const toolCallDelta of delta.tool_calls) {
                const index = toolCallDelta.index;
                // First delta for this index: it must carry type, id and name.
                if (toolCalls[index] == null) {
                  if (toolCallDelta.type !== "function") {
                    throw new InvalidResponseDataError({
                      data: toolCallDelta,
                      message: `Expected 'function' type.`
                    });
                  }
                  if (toolCallDelta.id == null) {
                    throw new InvalidResponseDataError({
                      data: toolCallDelta,
                      message: `Expected 'id' to be a string.`
                    });
                  }
                  if (((_c = toolCallDelta.function) == null ? void 0 : _c.name) == null) {
                    throw new InvalidResponseDataError({
                      data: toolCallDelta,
                      message: `Expected 'function.name' to be a string.`
                    });
                  }
                  toolCalls[index] = {
                    id: toolCallDelta.id,
                    type: "function",
                    function: {
                      name: toolCallDelta.function.name,
                      arguments: (_d = toolCallDelta.function.arguments) != null ? _d : ""
                    },
                    // Set once the complete "tool-call" part has been emitted.
                    hasFinished: false
                  };
                  const toolCall2 = toolCalls[index];
                  if (((_e = toolCall2.function) == null ? void 0 : _e.name) != null && ((_f = toolCall2.function) == null ? void 0 : _f.arguments) != null) {
                    if (toolCall2.function.arguments.length > 0) {
                      controller.enqueue({
                        type: "tool-call-delta",
                        toolCallType: "function",
                        toolCallId: toolCall2.id,
                        toolName: toolCall2.function.name,
                        argsTextDelta: toolCall2.function.arguments
                      });
                    }
                    // The whole call may arrive in a single chunk.
                    if (isParsableJson(toolCall2.function.arguments)) {
                      controller.enqueue({
                        type: "tool-call",
                        toolCallType: "function",
                        toolCallId: (_g = toolCall2.id) != null ? _g : generateId(),
                        toolName: toolCall2.function.name,
                        args: toolCall2.function.arguments
                      });
                      toolCall2.hasFinished = true;
                    }
                  }
                  continue;
                }
                const toolCall = toolCalls[index];
                // Fixed: without this guard, any delta arriving after the
                // arguments became parsable re-emitted a duplicate "tool-call".
                if (toolCall.hasFinished) {
                  continue;
                }
                if (((_h = toolCallDelta.function) == null ? void 0 : _h.arguments) != null) {
                  toolCall.function.arguments += (_j = (_i = toolCallDelta.function) == null ? void 0 : _i.arguments) != null ? _j : "";
                }
                controller.enqueue({
                  type: "tool-call-delta",
                  toolCallType: "function",
                  toolCallId: toolCall.id,
                  toolName: toolCall.function.name,
                  argsTextDelta: (_k = toolCallDelta.function.arguments) != null ? _k : ""
                });
                // Accumulated arguments form valid JSON: the call is complete.
                if (((_l = toolCall.function) == null ? void 0 : _l.name) != null && ((_m = toolCall.function) == null ? void 0 : _m.arguments) != null && isParsableJson(toolCall.function.arguments)) {
                  controller.enqueue({
                    type: "tool-call",
                    toolCallType: "function",
                    toolCallId: (_n = toolCall.id) != null ? _n : generateId(),
                    toolName: toolCall.function.name,
                    args: toolCall.function.arguments
                  });
                  toolCall.hasFinished = true;
                }
              }
            }
          },
          flush(controller) {
            var _a, _b;
            // Usage counts that never arrived are reported as NaN.
            controller.enqueue({
              type: "finish",
              finishReason,
              usage: {
                promptTokens: (_a = usage.promptTokens) != null ? _a : NaN,
                completionTokens: (_b = usage.completionTokens) != null ? _b : NaN
              },
              ...providerMetadata != null ? { providerMetadata } : {}
            });
          }
        })
      ),
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      warnings,
      request: { body }
    };
  }
};

/** Validation schema for a non-streaming chat completion response. */
var friendliAIChatResponseSchema = z2.object({
  id: z2.string().nullish(),
  created: z2.number().nullish(),
  model: z2.string().nullish(),
  choices: z2.array(
    z2.object({
      message: z2.object({
        role: z2.literal("assistant").nullish(),
        content: z2.string().nullish(),
        tool_calls: z2.array(
          z2.object({
            id: z2.string().nullish(),
            type: z2.literal("function"),
            function: z2.object({
              name: z2.string(),
              arguments: z2.union([z2.string(), z2.any()]).nullish()
            })
          })
        ).nullish()
      }),
      finish_reason: z2.string().nullish()
    })
  ),
  usage: z2.object({
    prompt_tokens: z2.number().nullish(),
    completion_tokens: z2.number().nullish()
  }).nullish()
});

/**
 * Validation schema for one streamed chunk. A chunk is one of: a chat
 * completion delta, a hosted tool status event, or an error payload.
 */
var friendliaiChatChunkSchema = z2.union([
  z2.object({
    id: z2.string().nullish(),
    created: z2.number().nullish(),
    model: z2.string().nullish(),
    choices: z2.array(
      z2.object({
        delta: z2.object({
          role: z2.enum(["assistant"]).nullish(),
          content: z2.string().nullish(),
          tool_calls: z2.array(
            z2.object({
              index: z2.number(),
              id: z2.string().nullish(),
              type: z2.literal("function").optional(),
              function: z2.object({
                name: z2.string().nullish(),
                arguments: z2.string().nullish()
              })
            })
          ).nullish()
        }).nullish(),
        finish_reason: z2.string().nullish()
      })
    ),
    usage: z2.object({
      prompt_tokens: z2.number().nullish(),
      completion_tokens: z2.number().nullish()
    }).nullish()
  }),
  z2.object({
    name: z2.string(),
    status: z2.enum(["ENDED", "STARTED", "ERRORED", "UPDATING"]),
    message: z2.null(),
    parameters: z2.array(
      z2.object({
        name: z2.string(),
        value: z2.string()
      })
    ),
    result: z2.string().nullable(),
    error: z2.object({
      type: z2.enum(["INVALID_PARAMETER", "UNKNOWN"]),
      msg: z2.string()
    }).nullable(),
    timestamp: z2.number(),
    usage: z2.null(),
    tool_call_id: z2.string().nullable()
    // temporary fix for "file:text" tool calls
  }),
  friendliaiErrorSchema
]);

// src/friendli-provider.ts
/**
 * Creates a FriendliAI provider.
 *
 * @param options Optional `apiKey` (otherwise loaded from the `FRIENDLI_TOKEN`
 *                environment variable), `teamId`, extra `headers`, a custom
 *                `baseURL`, and a custom `fetch`.
 * @returns A function `(modelId, settings)` that creates a chat model, with
 *          `beta`, `chat`, `chatModel`, `completion`, `completionModel`,
 *          `embedding` and `textEmbeddingModel` factories attached. The two
 *          embedding factories always throw `NoSuchModelError`.
 */
function createFriendli(options = {}) {
  // Evaluated per request, so the API key is resolved when headers are built.
  const getHeaders = () => ({
    Authorization: `Bearer ${loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: "FRIENDLI_TOKEN",
      description: "FRIENDLI_TOKEN"
    })}`,
    "X-Friendli-Team": options.teamId,
    ...options.headers
  });

  // Picks the base URL and endpoint type. Precedence: an explicit custom
  // baseURL, then "beta", then serverless (tools variant when hosted tools are
  // configured), otherwise dedicated.
  const baseURLAutoSelect = (modelId, endpoint, baseURL, tools) => {
    const customBaseURL = withoutTrailingSlash(baseURL);
    if (typeof customBaseURL === "string") {
      return { baseURL: customBaseURL, type: "custom" };
    }
    const FriendliBaseURL = {
      beta: "https://api.friendli.ai/serverless/beta",
      serverless: "https://api.friendli.ai/serverless/v1",
      tools: "https://api.friendli.ai/serverless/tools/v1",
      dedicated: "https://api.friendli.ai/dedicated/v1"
    };
    if (endpoint === "beta") {
      return {
        baseURL: FriendliBaseURL.beta,
        type: "beta"
      };
    }
    if (
      // The endpoint setting is "serverless", or it is "auto" and the model id
      // is one of the known serverless models.
      endpoint === "serverless" || (endpoint === "auto" && FriendliAIServerlessModelIds.includes(modelId))
    ) {
      if (tools && tools.length > 0) {
        return {
          baseURL: FriendliBaseURL.tools,
          type: "tools"
        };
      }
      return {
        baseURL: FriendliBaseURL.serverless,
        type: "serverless"
      };
    } else {
      return {
        baseURL: FriendliBaseURL.dedicated,
        type: "dedicated"
      };
    }
  };

  const createChatModel = (modelId, settings = {}) => {
    const { baseURL, type } = baseURLAutoSelect(
      modelId,
      settings.endpoint || "auto",
      options.baseURL,
      settings.tools
    );
    return new FriendliAIChatLanguageModel(modelId, settings, {
      provider: `friendliai.${type}.chat`,
      url: ({ path }) => `${baseURL}${path}`,
      headers: getHeaders,
      fetch: options.fetch,
      defaultObjectGenerationMode: "json"
    });
  };

  const createCompletionModel = (modelId, settings = {}) => {
    const { baseURL, type } = baseURLAutoSelect(
      modelId,
      settings.endpoint || "auto",
      options.baseURL
    );
    return new OpenAICompatibleCompletionLanguageModel(modelId, settings, {
      provider: `friendliai.${type}.completion`,
      url: ({ path }) => `${baseURL}${path}`,
      headers: getHeaders,
      fetch: options.fetch,
      errorStructure: friendliaiErrorStructure
    });
  };

  // Same as createChatModel but always requests the "beta" endpoint.
  const createBetaModel = (modelId, settings = {}) => {
    const { baseURL, type } = baseURLAutoSelect(
      modelId,
      "beta",
      options.baseURL
    );
    return new FriendliAIChatLanguageModel(modelId, settings, {
      provider: `friendliai.${type}.chat`,
      url: ({ path }) => `${baseURL}${path}`,
      headers: getHeaders,
      fetch: options.fetch,
      defaultObjectGenerationMode: "json"
    });
  };

  const createTextEmbeddingModel = (modelId) => {
    throw new NoSuchModelError({ modelId, modelType: "textEmbeddingModel" });
  };

  const provider = function(modelId, settings) {
    return createChatModel(modelId, settings);
  };
  provider.beta = createBetaModel;
  provider.chat = createChatModel;
  provider.chatModel = createChatModel;
  provider.completion = createCompletionModel;
  provider.completionModel = createCompletionModel;
  provider.embedding = createTextEmbeddingModel;
  provider.textEmbeddingModel = createTextEmbeddingModel;
  return provider;
}

/** Default provider instance created with no options. */
var friendli = createFriendli({});
export {
  createFriendli,
  friendli
};
//# sourceMappingURL=index.mjs.map