@friendliai/ai-provider
Version:
Learn how to use the FriendliAI provider for the Vercel AI SDK.
705 lines (699 loc) • 23.4 kB
JavaScript
// src/friendli-provider.ts
import { NoSuchModelError } from "@ai-sdk/provider";
import {
loadApiKey,
withoutTrailingSlash
} from "@ai-sdk/provider-utils";
import { OpenAICompatibleCompletionLanguageModel } from "@ai-sdk/openai-compatible";
// src/friendli-settings.ts
/**
 * Model ids that the provider treats as serverless models.
 *
 * When a model is created with the "auto" endpoint setting, membership in this
 * list is what routes the request to the serverless base URL instead of the
 * dedicated one.
 */
var FriendliAIServerlessModelIds = [
  // Llama 3.1 / 3.3 instruct models
  "meta-llama-3.1-8b-instruct",
  "meta-llama-3.1-70b-instruct",
  "meta-llama-3.3-70b-instruct",
  // DeepSeek
  "deepseek-r1"
];
// src/friendli-chat-language-model.ts
import {
InvalidResponseDataError,
UnsupportedFunctionalityError as UnsupportedFunctionalityError2
} from "@ai-sdk/provider";
import {
combineHeaders,
createEventSourceResponseHandler,
createJsonErrorResponseHandler as createJsonErrorResponseHandler2,
createJsonResponseHandler,
generateId,
isParsableJson,
postJsonToApi
} from "@ai-sdk/provider-utils";
import {
convertToOpenAICompatibleChatMessages,
getResponseMetadata,
mapOpenAICompatibleFinishReason
} from "@ai-sdk/openai-compatible/internal";
import { z as z2 } from "zod";
// src/friendli-error.ts
import { z } from "zod";
import { createJsonErrorResponseHandler } from "@ai-sdk/provider-utils";
// Schema for error payloads: an object carrying a string `message`.
var friendliaiErrorSchema = z.object({
message: z.string()
});
// Error structure passed to the AI SDK error-handler factories: pairs the
// schema above with a function that extracts the human-readable message.
var friendliaiErrorStructure = {
errorSchema: friendliaiErrorSchema,
errorToMessage: (data) => data.message
};
// Module-level failed-response handler built from `friendliaiErrorStructure`.
var friendliaiFailedResponseHandler = createJsonErrorResponseHandler(
friendliaiErrorStructure
);
// src/friendli-prepare-tools.ts
import {
UnsupportedFunctionalityError
} from "@ai-sdk/provider";
function prepareTools({
mode,
tools: hostedTools
}) {
var _a;
const tools = ((_a = mode.tools) == null ? void 0 : _a.length) ? mode.tools : void 0;
const toolWarnings = [];
if (tools == null && hostedTools == null) {
return { tools: void 0, tool_choice: void 0, toolWarnings };
}
const toolChoice = mode.toolChoice;
const mappedTools = [];
if (tools) {
for (const tool of tools) {
if (tool.type === "provider-defined") {
toolWarnings.push({ type: "unsupported-tool", tool });
} else {
mappedTools.push({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters
}
});
}
}
}
const mappedHostedTools = hostedTools == null ? void 0 : hostedTools.map((tool) => {
return {
type: tool.type
};
});
if (toolChoice == null) {
return {
tools: [...mappedTools != null ? mappedTools : [], ...mappedHostedTools != null ? mappedHostedTools : []],
tool_choice: void 0,
toolWarnings
};
}
const type = toolChoice.type;
switch (type) {
case "auto":
case "none":
case "required":
return {
tools: [...mappedTools != null ? mappedTools : [], ...mappedHostedTools != null ? mappedHostedTools : []],
tool_choice: type,
toolWarnings
};
case "tool":
return {
tools: [...mappedTools != null ? mappedTools : [], ...mappedHostedTools != null ? mappedHostedTools : []],
tool_choice: {
type: "function",
function: {
name: toolChoice.toolName
}
},
toolWarnings
};
default: {
const _exhaustiveCheck = type;
throw new UnsupportedFunctionalityError({
functionality: `Unsupported tool choice type: ${_exhaustiveCheck}`
});
}
}
}
// src/friendli-chat-language-model.ts
var FriendliAIChatLanguageModel = class {
constructor(modelId, settings, config) {
this.specificationVersion = "v1";
var _a;
this.modelId = modelId;
this.settings = settings;
this.config = config;
this.failedResponseHandler = createJsonErrorResponseHandler2(
friendliaiErrorStructure
);
this.supportsStructuredOutputs = (_a = config.supportsStructuredOutputs) != null ? _a : true;
}
get defaultObjectGenerationMode() {
var _a;
return (_a = this.config.defaultObjectGenerationMode) != null ? _a : "json";
}
get provider() {
return this.config.provider;
}
getArgs({
mode,
prompt,
maxTokens,
temperature,
topP,
topK,
frequencyPenalty,
presencePenalty,
stopSequences,
responseFormat,
seed
}) {
const type = mode.type;
const warnings = [];
if ((responseFormat == null ? void 0 : responseFormat.type) === "json" && responseFormat.schema != null && !this.supportsStructuredOutputs) {
warnings.push({
type: "unsupported-setting",
setting: "responseFormat",
details: "JSON response format schema is only supported with structuredOutputs"
});
}
const baseArgs = {
// model id:
model: this.modelId,
// model specific settings:
user: this.settings.user,
parallel_tool_calls: this.settings.parallelToolCalls,
// standardized settings:
max_tokens: maxTokens,
temperature,
top_p: topP,
top_k: topK,
frequency_penalty: frequencyPenalty,
presence_penalty: presencePenalty,
response_format: (responseFormat == null ? void 0 : responseFormat.type) === "json" ? this.supportsStructuredOutputs === true && responseFormat.schema != null ? {
type: "json_schema",
json_schema: {
schema: responseFormat.schema,
description: responseFormat.description
}
} : { type: "json_object" } : void 0,
stop: stopSequences,
seed,
// messages:
messages: convertToOpenAICompatibleChatMessages(prompt)
};
if (this.settings.regex != null && type !== "regular") {
throw new UnsupportedFunctionalityError2({
functionality: "egular expression is only supported with regular mode (generateText, streamText)"
});
}
switch (type) {
case "regular": {
if (this.settings.regex != null) {
if (this.settings.tools != null || mode.tools != null) {
throw new UnsupportedFunctionalityError2({
functionality: "Regular expression and tools cannot be used together. Use either regular expression or tools."
});
}
return {
args: {
...baseArgs,
response_format: {
type: "regex",
schema: this.settings.regex.source
}
},
warnings
};
}
const { tools, tool_choice, toolWarnings } = prepareTools({
mode,
tools: this.settings.tools
});
return {
args: { ...baseArgs, tools, tool_choice },
warnings: [...warnings, ...toolWarnings]
};
}
case "object-json": {
return {
args: {
...baseArgs,
response_format: this.supportsStructuredOutputs === true && mode.schema != null ? {
type: "json_schema",
json_schema: {
schema: mode.schema,
description: mode.description
}
} : { type: "json_object" }
},
warnings
};
}
case "object-tool": {
return {
args: {
...baseArgs,
tool_choice: {
type: "function",
function: { name: mode.tool.name }
},
tools: [
{
type: "function",
function: {
name: mode.tool.name,
description: mode.tool.description,
parameters: mode.tool.parameters
}
}
]
},
warnings
};
}
default: {
const _exhaustiveCheck = type;
throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
}
}
}
async doGenerate(options) {
var _a, _b, _c, _d, _e, _f;
const { args, warnings } = this.getArgs({ ...options });
const body = JSON.stringify({ ...args, stream: false });
const { responseHeaders, value: response } = await postJsonToApi({
url: this.config.url({
path: "/chat/completions",
modelId: this.modelId
}),
headers: combineHeaders(this.config.headers(), options.headers),
body: {
...args,
stream: false
},
failedResponseHandler: this.failedResponseHandler,
successfulResponseHandler: createJsonResponseHandler(
friendliAIChatResponseSchema
),
abortSignal: options.abortSignal,
fetch: this.config.fetch
});
const { messages: rawPrompt, ...rawSettings } = args;
const choice = response.choices[0];
return {
text: (_a = choice.message.content) != null ? _a : void 0,
toolCalls: (_b = choice.message.tool_calls) == null ? void 0 : _b.map((toolCall) => {
var _a2;
return {
toolCallType: "function",
toolCallId: (_a2 = toolCall.id) != null ? _a2 : generateId(),
toolName: toolCall.function.name,
args: typeof toolCall.function.arguments === "string" ? toolCall.function.arguments : JSON.stringify(toolCall.function.arguments)
};
}),
finishReason: mapOpenAICompatibleFinishReason(choice.finish_reason),
usage: {
promptTokens: (_d = (_c = response.usage) == null ? void 0 : _c.prompt_tokens) != null ? _d : NaN,
completionTokens: (_f = (_e = response.usage) == null ? void 0 : _e.completion_tokens) != null ? _f : NaN
},
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
response: getResponseMetadata(response),
warnings,
request: { body }
};
}
async doStream(options) {
const { args, warnings } = this.getArgs({ ...options });
const body = JSON.stringify({ ...args, stream: true });
const { responseHeaders, value: response } = await postJsonToApi({
url: this.config.url({
path: "/chat/completions",
modelId: this.modelId
}),
headers: combineHeaders(this.config.headers(), options.headers),
body: {
...args,
stream: true,
stream_options: { include_usage: true }
},
failedResponseHandler: friendliaiFailedResponseHandler,
successfulResponseHandler: createEventSourceResponseHandler(
friendliaiChatChunkSchema
),
abortSignal: options.abortSignal,
fetch: this.config.fetch
});
const { messages: rawPrompt, ...rawSettings } = args;
const toolCalls = [];
let finishReason = "unknown";
let usage = {
promptTokens: void 0,
completionTokens: void 0
};
let isFirstChunk = true;
let providerMetadata;
return {
stream: response.pipeThrough(
new TransformStream({
transform(chunk, controller) {
var _a, _b, _c, _d, _e, _f, _g, _h, _i, _j, _k, _l, _m, _n;
if (!chunk.success) {
finishReason = "error";
controller.enqueue({ type: "error", error: chunk.error });
return;
}
const value = chunk.value;
if ("status" in value) {
switch (value.status) {
case "STARTED":
break;
case "UPDATING":
break;
case "ENDED":
break;
case "ERRORED":
finishReason = "error";
break;
default:
finishReason = "error";
controller.enqueue({
type: "error",
error: new Error(
`Unsupported tool call status: ${value.status}`
)
});
}
return;
}
if ("message" in value) {
console.error("Error chunk:", value);
finishReason = "error";
controller.enqueue({ type: "error", error: value.message });
return;
}
if (isFirstChunk) {
isFirstChunk = false;
controller.enqueue({
type: "response-metadata",
...getResponseMetadata(value)
});
}
if (value.usage != null) {
usage = {
promptTokens: (_a = value.usage.prompt_tokens) != null ? _a : void 0,
completionTokens: (_b = value.usage.completion_tokens) != null ? _b : void 0
};
}
const choice = value.choices[0];
if ((choice == null ? void 0 : choice.finish_reason) != null) {
finishReason = mapOpenAICompatibleFinishReason(
choice.finish_reason
);
}
if ((choice == null ? void 0 : choice.delta) == null) {
return;
}
const delta = choice.delta;
if (delta.content != null) {
controller.enqueue({
type: "text-delta",
textDelta: delta.content
});
}
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
if (toolCalls[index] == null) {
if (toolCallDelta.type !== "function") {
throw new InvalidResponseDataError({
data: toolCallDelta,
message: `Expected 'function' type.`
});
}
if (toolCallDelta.id == null) {
throw new InvalidResponseDataError({
data: toolCallDelta,
message: `Expected 'id' to be a string.`
});
}
if (((_c = toolCallDelta.function) == null ? void 0 : _c.name) == null) {
throw new InvalidResponseDataError({
data: toolCallDelta,
message: `Expected 'function.name' to be a string.`
});
}
toolCalls[index] = {
id: toolCallDelta.id,
type: "function",
function: {
name: toolCallDelta.function.name,
arguments: (_d = toolCallDelta.function.arguments) != null ? _d : ""
}
};
const toolCall2 = toolCalls[index];
if (((_e = toolCall2.function) == null ? void 0 : _e.name) != null && ((_f = toolCall2.function) == null ? void 0 : _f.arguments) != null) {
if (toolCall2.function.arguments.length > 0) {
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: toolCall2.id,
toolName: toolCall2.function.name,
argsTextDelta: toolCall2.function.arguments
});
}
if (isParsableJson(toolCall2.function.arguments)) {
controller.enqueue({
type: "tool-call",
toolCallType: "function",
toolCallId: (_g = toolCall2.id) != null ? _g : generateId(),
toolName: toolCall2.function.name,
args: toolCall2.function.arguments
});
}
}
continue;
}
const toolCall = toolCalls[index];
if (((_h = toolCallDelta.function) == null ? void 0 : _h.arguments) != null) {
toolCall.function.arguments += (_j = (_i = toolCallDelta.function) == null ? void 0 : _i.arguments) != null ? _j : "";
}
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: toolCall.id,
toolName: toolCall.function.name,
argsTextDelta: (_k = toolCallDelta.function.arguments) != null ? _k : ""
});
if (((_l = toolCall.function) == null ? void 0 : _l.name) != null && ((_m = toolCall.function) == null ? void 0 : _m.arguments) != null && isParsableJson(toolCall.function.arguments)) {
controller.enqueue({
type: "tool-call",
toolCallType: "function",
toolCallId: (_n = toolCall.id) != null ? _n : generateId(),
toolName: toolCall.function.name,
args: toolCall.function.arguments
});
}
}
}
},
flush(controller) {
var _a, _b;
controller.enqueue({
type: "finish",
finishReason,
usage: {
promptTokens: (_a = usage.promptTokens) != null ? _a : NaN,
completionTokens: (_b = usage.completionTokens) != null ? _b : NaN
},
...providerMetadata != null ? { providerMetadata } : {}
});
}
})
),
rawCall: { rawPrompt, rawSettings },
rawResponse: { headers: responseHeaders },
warnings,
request: { body }
};
}
};
// Schema for a non-streaming chat completion response; doGenerate validates
// the response body against it via createJsonResponseHandler.
// Most fields are nullish, so a response that omits them still validates.
var friendliAIChatResponseSchema = z2.object({
id: z2.string().nullish(),
created: z2.number().nullish(),
model: z2.string().nullish(),
choices: z2.array(
z2.object({
message: z2.object({
role: z2.literal("assistant").nullish(),
content: z2.string().nullish(),
tool_calls: z2.array(
z2.object({
id: z2.string().nullish(),
type: z2.literal("function"),
function: z2.object({
name: z2.string(),
// Accepts a string or any other value; doGenerate JSON-stringifies
// non-string arguments before returning them.
arguments: z2.union([z2.string(), z2.any()]).nullish()
})
})
).nullish()
}),
finish_reason: z2.string().nullish()
})
),
// Token usage is optional; doGenerate reports NaN for missing counts.
usage: z2.object({
prompt_tokens: z2.number().nullish(),
completion_tokens: z2.number().nullish()
}).nullish()
});
// Schema for one parsed event of a streaming chat completion. A chunk is one
// of three shapes, which doStream tells apart with `"status" in value` and
// `"message" in value`:
//   1. a regular chat completion chunk (`choices` with a `delta`),
//   2. a tool-call status event (`name`, `status`, `parameters`, ...),
//   3. an error payload (`friendliaiErrorSchema`).
var friendliaiChatChunkSchema = z2.union([
// 1. Regular chat completion chunk.
z2.object({
id: z2.string().nullish(),
created: z2.number().nullish(),
model: z2.string().nullish(),
choices: z2.array(
z2.object({
delta: z2.object({
role: z2.enum(["assistant"]).nullish(),
content: z2.string().nullish(),
tool_calls: z2.array(
z2.object({
// Position of the tool call this delta belongs to; doStream uses it
// as the key when accumulating arguments.
index: z2.number(),
id: z2.string().nullish(),
type: z2.literal("function").optional(),
function: z2.object({
name: z2.string().nullish(),
arguments: z2.string().nullish()
})
})
).nullish()
}).nullish(),
finish_reason: z2.string().nullish()
})
),
usage: z2.object({
prompt_tokens: z2.number().nullish(),
completion_tokens: z2.number().nullish()
}).nullish()
}),
// 2. Tool-call status event. `message` and `usage` must be null here.
z2.object({
name: z2.string(),
status: z2.enum(["ENDED", "STARTED", "ERRORED", "UPDATING"]),
message: z2.null(),
parameters: z2.array(
z2.object({
name: z2.string(),
value: z2.string()
})
),
result: z2.string().nullable(),
error: z2.object({
type: z2.enum(["INVALID_PARAMETER", "UNKNOWN"]),
msg: z2.string()
}).nullable(),
timestamp: z2.number(),
usage: z2.null(),
tool_call_id: z2.string().nullable()
// temporary fix for "file:text" tool calls
}),
// 3. Error payload with a string `message`.
friendliaiErrorSchema
]);
// src/friendli-provider.ts
function createFriendli(options = {}) {
const getHeaders = () => ({
Authorization: `Bearer ${loadApiKey({
apiKey: options.apiKey,
environmentVariableName: "FRIENDLI_TOKEN",
description: "FRIENDLI_TOKEN"
})}`,
"X-Friendli-Team": options.teamId,
...options.headers
});
const baseURLAutoSelect = (modelId, endpoint, baseURL, tools) => {
const customBaseURL = withoutTrailingSlash(baseURL);
if (typeof customBaseURL === "string") {
return { baseURL: customBaseURL, type: "custom" };
}
const FriendliBaseURL = {
beta: "https://api.friendli.ai/serverless/beta",
serverless: "https://api.friendli.ai/serverless/v1",
tools: "https://api.friendli.ai/serverless/tools/v1",
dedicated: "https://api.friendli.ai/dedicated/v1"
};
if (endpoint === "beta") {
return {
baseURL: FriendliBaseURL.beta,
type: "beta"
};
}
if (
// If the endpoint setting is serverless or auto and the model is floating on serverless,
endpoint === "serverless" || endpoint === "auto" && Object.values(FriendliAIServerlessModelIds).includes(
modelId
)
) {
if (tools && tools.length > 0) {
return {
baseURL: FriendliBaseURL.tools,
type: "tools"
};
}
return {
baseURL: FriendliBaseURL.serverless,
type: "serverless"
};
} else {
return {
baseURL: FriendliBaseURL.dedicated,
type: "dedicated"
};
}
};
const createChatModel = (modelId, settings = {}) => {
const { baseURL, type } = baseURLAutoSelect(
modelId,
settings.endpoint || "auto",
options.baseURL,
settings.tools
);
return new FriendliAIChatLanguageModel(modelId, settings, {
provider: `friendliai.${type}.chat`,
url: ({ path }) => `${baseURL}${path}`,
headers: getHeaders,
fetch: options.fetch,
defaultObjectGenerationMode: "json"
});
};
const createCompletionModel = (modelId, settings = {}) => {
const { baseURL, type } = baseURLAutoSelect(
modelId,
settings.endpoint || "auto",
options.baseURL
);
return new OpenAICompatibleCompletionLanguageModel(modelId, settings, {
provider: `friendliai.${type}.completion`,
url: ({ path }) => `${baseURL}${path}`,
headers: getHeaders,
fetch: options.fetch,
errorStructure: friendliaiErrorStructure
});
};
const createBetaModel = (modelId, settings = {}) => {
const { baseURL, type } = baseURLAutoSelect(
modelId,
"beta",
options.baseURL
);
return new FriendliAIChatLanguageModel(modelId, settings, {
provider: `friendliai.${type}.chat`,
url: ({ path }) => `${baseURL}${path}`,
headers: getHeaders,
fetch: options.fetch,
defaultObjectGenerationMode: "json"
});
};
const createTextEmbeddingModel = (modelId) => {
throw new NoSuchModelError({ modelId, modelType: "textEmbeddingModel" });
};
const provider = function(modelId, settings) {
return createChatModel(modelId, settings);
};
provider.beta = createBetaModel;
provider.chat = createChatModel;
provider.chatModel = createChatModel;
provider.completion = createCompletionModel;
provider.completionModel = createCompletionModel;
provider.embedding = createTextEmbeddingModel;
provider.textEmbeddingModel = createTextEmbeddingModel;
return provider;
}
// Default provider instance created with empty options: no explicit apiKey,
// teamId, headers, baseURL or fetch. With no apiKey option, the key lookup is
// left to loadApiKey, which is given the FRIENDLI_TOKEN environment variable name.
var friendli = createFriendli({});
export {
createFriendli,
friendli
};
//# sourceMappingURL=index.mjs.map