ernie-ai-provider
Version:
Community-built ERNIE AI Provider for Vercel AI SDK - Integrate Baidu's ERNIE models with Vercel's AI application framework
310 lines (309 loc) • 12.1 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.ErnieChatLanguageModel = void 0;
const provider_utils_1 = require("@ai-sdk/provider-utils");
const zod_1 = require("zod");
/**
 * Convert one AI SDK prompt message into an ERNIE chat message with plain-string content.
 *
 * @param {object} message - A LanguageModelV1 prompt message (`role` plus `content`).
 * @returns {{ role: string, content: string }} The ERNIE request message.
 * @throws {Error} For image parts, unknown user part types, tool messages, or unknown roles.
 */
function convertToErnieMessage(message) {
    switch (message.role) {
        case 'system':
            return { role: 'system', content: message.content };
        case 'user':
            return {
                role: 'user',
                content: typeof message.content === 'string'
                    ? message.content
                    : message.content
                        .map((part) => {
                        switch (part.type) {
                            case 'text':
                                return part.text;
                            case 'image':
                                throw new Error('Image content is not supported');
                            default:
                                throw new Error(`Unsupported content type`);
                        }
                    })
                        .join(''),
            };
        case 'assistant':
            // Only text parts are forwarded; any non-text parts in assistant history are dropped.
            return {
                role: 'assistant',
                content: message.content
                    .filter((part) => part.type === 'text')
                    .map((part) => part.text)
                    .join(''),
            };
        case 'tool':
            // NOTE(review): tool results cannot be sent back, so a conversation that already
            // contains tool results cannot be continued through this provider.
            throw new Error('Tool messages are not supported');
        default:
            throw new Error(`Unsupported message role`);
    }
}
/**
 * ERNIE chat language model implementing the AI SDK LanguageModelV1 interface
 * (`specificationVersion: 'v1'`). Requests go to `${config.baseURL}/chat/completions`.
 */
class ErnieChatLanguageModel {
    /**
     * @param {string} modelId - Model identifier, sent as the `model` request field.
     * @param {object} settings - Model-level defaults (temperature, topP, maxTokens, penaltyScore,
     *   stop, seed, frequencyPenalty, presencePenalty, repetitionPenalty).
     * @param {object} config - `provider` name, `baseURL`, a `headers()` function and optional `fetch`.
     */
    constructor(modelId, settings, config) {
        this.specificationVersion = 'v1';
        // Advertises JSON-mode object generation; getArgs accepts the matching 'object-json' mode.
        this.defaultObjectGenerationMode = 'json';
        /**
         * Handle a failed (non-2xx) HTTP response by throwing an Error that includes the
         * status code and raw response body, with dedicated messages for 401 and 429.
         *
         * NOTE(review): postJsonToApi from @ai-sdk/provider-utils appears to expect this handler
         * to return `{ value, responseHeaders }`; an error thrown here is likely wrapped by the
         * SDK with this Error as its `cause` — confirm against the provider-utils version in use.
         */
        this.failedResponseHandler = async ({ response, url, requestBodyValues, }) => {
            const responseBody = await response.text();
            if (response.status === 401) {
                throw new Error(`ERNIE API 认证失败 (${response.status}): ${responseBody}`);
            }
            if (response.status === 429) {
                throw new Error(`ERNIE API 请求频率限制 (${response.status}): ${responseBody}`);
            }
            throw new Error(`ERNIE API 请求失败 (${response.status}): ${responseBody}`);
        };
        this.modelId = modelId;
        this.settings = settings;
        this.config = config;
    }
    get provider() {
        return this.config.provider;
    }
    /**
     * Convert AI SDK call options into an ERNIE `/chat/completions` request body.
     *
     * Sampling values configured on the model take precedence; call-level options (such as a
     * `temperature` passed per request) fill in whatever the model settings leave unset.
     * The 'object-json' mode is sent as a regular chat request without tools — presumably the
     * AI SDK has already placed JSON instructions in the prompt (verify for the SDK version).
     *
     * @param {object} options - LanguageModelV1 call options.
     * @returns {object} Request body with snake_case keys; optional keys appear only when set.
     * @throws {Error} For the 'object-tool' mode or prompt content that cannot be converted.
     */
    getArgs({ prompt, mode, maxTokens, temperature, topP, frequencyPenalty, presencePenalty, stopSequences, seed, }) {
        const type = mode.type;
        if (type === 'object-tool') {
            throw new Error(`Object generation mode '${type}' is not supported.`);
        }
        const modelSettings = this.settings ?? {};
        const baseArgs = {
            model: this.modelId,
            messages: prompt.map(convertToErnieMessage),
        };
        const stop = modelSettings.stop ?? (stopSequences?.length ? stopSequences : undefined);
        // [request key, value] pairs; null/undefined values are omitted from the body.
        const optionalFields = [
            ['temperature', modelSettings.temperature ?? temperature],
            ['top_p', modelSettings.topP ?? topP],
            ['max_tokens', modelSettings.maxTokens ?? maxTokens],
            ['penalty_score', modelSettings.penaltyScore],
            ['stop', stop],
            ['seed', modelSettings.seed ?? seed],
            ['frequency_penalty', modelSettings.frequencyPenalty ?? frequencyPenalty],
            ['presence_penalty', modelSettings.presencePenalty ?? presencePenalty],
            ['repetition_penalty', modelSettings.repetitionPenalty],
        ];
        for (const [key, value] of optionalFields) {
            if (value != null) {
                baseArgs[key] = value;
            }
        }
        // Tool definitions are only sent for regular generation.
        if (mode.type === 'regular' && mode.tools?.length) {
            baseArgs.tools = mode.tools.map((tool) => ({
                type: 'function',
                function: {
                    name: tool.name,
                    description: tool.description,
                    parameters: tool.parameters,
                },
            }));
        }
        return baseArgs;
    }
    /**
     * Collect warnings for call options this provider does not forward to the API.
     *
     * @param {object} options - LanguageModelV1 call options.
     * @returns {Array<object>} Warning objects (empty when nothing is ignored).
     */
    getWarnings({ topK }) {
        const warnings = [];
        if (topK != null) {
            warnings.push({ type: 'unsupported-setting', setting: 'topK' });
        }
        return warnings;
    }
    /**
     * Generate a complete (non-streaming) response.
     *
     * @param {object} options - LanguageModelV1 call options.
     * @returns {Promise<object>} Text, tool calls, finish reason, token usage and raw call data.
     * @throws {Error} When the API response contains no choices.
     */
    async doGenerate(options) {
        const args = this.getArgs(options);
        const { responseHeaders, value: response } = await (0, provider_utils_1.postJsonToApi)({
            url: `${this.config.baseURL}/chat/completions`,
            headers: (0, provider_utils_1.combineHeaders)(this.config.headers(), options.headers),
            body: args,
            failedResponseHandler: this.failedResponseHandler,
            successfulResponseHandler: (0, provider_utils_1.createJsonResponseHandler)(ernieResponseSchema),
            abortSignal: options.abortSignal,
            fetch: this.config.fetch,
        });
        const { messages: rawPrompt, ...rawSettings } = args;
        const choice = response.choices[0];
        if (choice == null) {
            throw new Error('ERNIE API returned no choices');
        }
        return {
            text: choice.message.content ?? '',
            toolCalls: choice.message.tool_calls?.map((toolCall) => ({
                toolCallType: 'function',
                toolCallId: toolCall.id,
                toolName: toolCall.function.name,
                args: toolCall.function.arguments,
            })) ?? [],
            finishReason: this.mapFinishReason(choice.finish_reason),
            usage: {
                promptTokens: response.usage.prompt_tokens,
                completionTokens: response.usage.completion_tokens,
            },
            rawCall: { rawPrompt, rawSettings },
            rawResponse: { headers: responseHeaders },
            warnings: this.getWarnings(options),
        };
    }
    /**
     * Generate a streaming response.
     *
     * @param {object} options - LanguageModelV1 call options.
     * @returns {Promise<object>} A stream of text-delta / tool-call / error / finish parts plus raw call data.
     */
    async doStream(options) {
        const args = { ...this.getArgs(options), stream: true };
        const { responseHeaders, value: response } = await (0, provider_utils_1.postJsonToApi)({
            url: `${this.config.baseURL}/chat/completions`,
            headers: (0, provider_utils_1.combineHeaders)(this.config.headers(), options.headers),
            body: args,
            failedResponseHandler: this.failedResponseHandler,
            successfulResponseHandler: (0, provider_utils_1.createEventSourceResponseHandler)(ernieStreamChunkSchema),
            abortSignal: options.abortSignal,
            fetch: this.config.fetch,
        });
        const { messages: rawPrompt, ...rawSettings } = args;
        return {
            stream: response.pipeThrough(this.createStreamTransform()),
            rawCall: { rawPrompt, rawSettings },
            rawResponse: { headers: responseHeaders },
            warnings: this.getWarnings(options),
        };
    }
    /**
     * Build the TransformStream that turns parsed SSE chunks (`{ success, value | error }`)
     * into AI SDK stream parts.
     *
     * Tool-call fragments are merged per tool call (keyed by `index`, falling back to the
     * fragment's position) and emitted as complete `tool-call` parts just before `finish`.
     *
     * @returns {TransformStream} Transformer emitting text-delta, error, tool-call and finish parts.
     */
    createStreamTransform() {
        // Transformer callbacks run with `this` bound to the transformer object, not to this
        // model, so the finish-reason mapper is captured here.
        const mapFinishReason = (reason) => this.mapFinishReason(reason);
        let finishReason = 'other';
        let usage = {
            promptTokens: Number.NaN,
            completionTokens: Number.NaN,
        };
        const pendingToolCalls = [];
        const pendingByIndex = new Map();
        return new TransformStream({
            transform(chunk, controller) {
                if (!chunk.success) {
                    controller.enqueue({ type: 'error', error: chunk.error });
                    return;
                }
                const value = chunk.value;
                if (value.usage) {
                    usage = {
                        promptTokens: value.usage.prompt_tokens,
                        completionTokens: value.usage.completion_tokens,
                    };
                }
                const choice = value.choices?.[0];
                if (choice == null) {
                    return;
                }
                const delta = choice.delta;
                if (delta?.content) {
                    controller.enqueue({ type: 'text-delta', textDelta: delta.content });
                }
                if (choice.finish_reason) {
                    finishReason = mapFinishReason(choice.finish_reason);
                }
                for (const [position, toolCallDelta] of (delta?.tool_calls ?? []).entries()) {
                    const key = toolCallDelta.index ?? position;
                    let pending = pendingByIndex.get(key);
                    // A different id at the same slot means a new tool call rather than a continuation.
                    const startsNewCall = pending == null
                        || (toolCallDelta.id != null && pending.id != null && toolCallDelta.id !== pending.id);
                    if (startsNewCall) {
                        pending = { id: undefined, name: undefined, args: '' };
                        pendingByIndex.set(key, pending);
                        pendingToolCalls.push(pending);
                    }
                    if (pending.id == null) {
                        pending.id = toolCallDelta.id;
                    }
                    if (toolCallDelta.function?.name) {
                        pending.name = toolCallDelta.function.name;
                    }
                    if (toolCallDelta.function?.arguments) {
                        pending.args += toolCallDelta.function.arguments;
                    }
                }
            },
            flush(controller) {
                for (const toolCall of pendingToolCalls) {
                    // Fragments that never named a function cannot be reported as a tool call.
                    if (!toolCall.name) {
                        continue;
                    }
                    controller.enqueue({
                        type: 'tool-call',
                        toolCallType: 'function',
                        toolCallId: toolCall.id ?? (0, provider_utils_1.generateId)(),
                        toolName: toolCall.name,
                        args: toolCall.args,
                    });
                }
                controller.enqueue({
                    type: 'finish',
                    finishReason,
                    usage,
                });
            },
        });
    }
    /**
     * Map an ERNIE `finish_reason` string to the AI SDK finish reason; unknown or
     * missing values map to 'other'.
     */
    mapFinishReason(finishReason) {
        switch (finishReason) {
            case 'stop':
                return 'stop';
            case 'length':
                return 'length';
            case 'tool_calls':
                return 'tool-calls';
            case 'content_filter':
                return 'content-filter';
            default:
                return 'other';
        }
    }
}
// CommonJS export of the chat model class; replaces the `void 0` placeholder assigned at the top of the module.
exports.ErnieChatLanguageModel = ErnieChatLanguageModel;
// Validation schemas for the JSON body of a non-streaming ERNIE chat completion,
// composed from smaller pieces; ernieResponseSchema is the one passed to the JSON response handler.
const ernieResponseToolCallSchema = zod_1.z.object({
    id: zod_1.z.string(),
    type: zod_1.z.string(),
    function: zod_1.z.object({
        name: zod_1.z.string(),
        // Arguments arrive as a raw string and are passed through unparsed.
        arguments: zod_1.z.string(),
    }),
});
const ernieResponseMessageSchema = zod_1.z.object({
    role: zod_1.z.string(),
    content: zod_1.z.string().nullable(),
    tool_calls: zod_1.z.array(ernieResponseToolCallSchema).optional(),
});
const ernieResponseChoiceSchema = zod_1.z.object({
    index: zod_1.z.number(),
    message: ernieResponseMessageSchema,
    finish_reason: zod_1.z.string().nullable(),
});
const ernieResponseUsageSchema = zod_1.z.object({
    prompt_tokens: zod_1.z.number(),
    completion_tokens: zod_1.z.number(),
    total_tokens: zod_1.z.number(),
});
const ernieResponseSchema = zod_1.z.object({
    id: zod_1.z.string(),
    object: zod_1.z.string(),
    created: zod_1.z.number(),
    model: zod_1.z.string(),
    choices: zod_1.z.array(ernieResponseChoiceSchema),
    usage: ernieResponseUsageSchema,
});
// Validation schema for each server-sent-event chunk of a streaming ERNIE chat completion
// (passed to createEventSourceResponseHandler in doStream). Fields that incremental deltas
// may omit or send as null are lenient so such chunks are not turned into error parts.
const ernieStreamChunkSchema = zod_1.z.object({
    id: zod_1.z.string(),
    object: zod_1.z.string(),
    created: zod_1.z.number(),
    model: zod_1.z.string(),
    choices: zod_1.z
        .array(zod_1.z.object({
        index: zod_1.z.number(),
        delta: zod_1.z.object({
            role: zod_1.z.string().nullable().optional(),
            content: zod_1.z.string().nullable().optional(),
            tool_calls: zod_1.z
                .array(zod_1.z.object({
                // Slot of the tool call in the response. In OpenAI-style streaming, follow-up
                // argument fragments carry only this index (no id/type/name) — presumably
                // ERNIE behaves the same way; verify against the API.
                index: zod_1.z.number().optional(),
                id: zod_1.z.string().nullable().optional(),
                type: zod_1.z.string().nullable().optional(),
                function: zod_1.z
                    .object({
                    name: zod_1.z.string().nullable().optional(),
                    arguments: zod_1.z.string().nullable().optional(),
                })
                    .optional(),
            }))
                .nullable()
                .optional(),
        }),
        finish_reason: zod_1.z.string().nullable().optional(),
    }))
        .nullable()
        .optional(),
    usage: zod_1.z
        .object({
        prompt_tokens: zod_1.z.number(),
        completion_tokens: zod_1.z.number(),
        total_tokens: zod_1.z.number(),
    })
        .nullable()
        .optional(),
});