ernie-ai-provider
Version:
Community-built ERNIE AI Provider for Vercel AI SDK - Integrate Baidu's ERNIE models with Vercel's AI application framework
405 lines • 16.3 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.ErnieChatLanguageModel = void 0;
const provider_utils_1 = require("@ai-sdk/provider-utils");
const zod_1 = require("zod");
/**
 * Return the first candidate that is neither `null` nor `undefined`.
 *
 * A setting can come from several places (snake_case model setting,
 * camelCase model setting, per-call option). Callers list the candidates in
 * descending priority and this picks the winner.
 *
 * @param {...unknown} candidates - Values in descending priority.
 * @returns {unknown} The first non-nullish candidate, or `undefined`.
 */
function firstDefined(...candidates) {
    return candidates.find((candidate) => candidate != null);
}
/**
 * Render a tool result as text for the conversation history.
 *
 * Strings are used verbatim. Anything else is JSON-encoded so that object
 * results do not collapse to "[object Object]". Values JSON cannot encode
 * (undefined, BigInt, circular structures, ...) fall back to `String(result)`.
 *
 * @param {unknown} result - The `result` field of an AI SDK tool-result part.
 * @returns {string}
 */
function stringifyToolResult(result) {
    if (typeof result === 'string') {
        return result;
    }
    try {
        const encoded = JSON.stringify(result);
        // JSON.stringify returns undefined (not a string) for undefined/functions/symbols.
        return encoded === undefined ? String(result) : encoded;
    }
    catch {
        return String(result);
    }
}
/**
 * ERNIE chat language model implementing the Vercel AI SDK language model
 * interface (specificationVersion 'v1').
 */
class ErnieChatLanguageModel {
    /**
     * @param {string} modelId - ERNIE model id, sent as the `model` request field.
     * @param {object} [settings] - Model-level settings (snake_case or camelCase).
     *   They take precedence over per-call options. Defaults to `{}`.
     * @param {object} config - Provider config: `provider`, `baseURL`,
     *   `headers()` and `fetch`, forwarded to the HTTP helpers.
     */
    constructor(modelId, settings, config) {
        this.specificationVersion = 'v1';
        this.defaultObjectGenerationMode = 'json';
        this.supportsToolUse = true;
        /**
         * Turn a failed HTTP response into a thrown Error carrying the status
         * code and the raw response body. 401 (authentication) and 429 (rate
         * limit) get dedicated messages.
         */
        this.failedResponseHandler = async ({ response, url, requestBodyValues, }) => {
            const responseBody = await response.text();
            if (response.status === 401) {
                throw new Error(`ERNIE API 认证失败 (${response.status}): ${responseBody}`);
            }
            if (response.status === 429) {
                throw new Error(`ERNIE API 请求频率限制 (${response.status}): ${responseBody}`);
            }
            throw new Error(`ERNIE API 请求失败 (${response.status}): ${responseBody}`);
        };
        this.modelId = modelId;
        this.settings = settings ?? {};
        this.config = config;
    }
    /** Provider identifier taken from the config (e.g. used in AI SDK telemetry). */
    get provider() {
        return this.config.provider;
    }
    /**
     * Map an AI SDK prompt to ERNIE chat messages.
     *
     * - system: content is passed through unchanged.
     * - user: text parts are concatenated; image parts and any other part type throw.
     * - assistant: only text parts are kept (tool-call parts are dropped).
     * - tool: ERNIE has no `tool` role, so the results are replayed as one
     *   assistant message with one `tool: <name> ,result: <value>` line per result.
     *
     * @param {Array<object>} prompt - AI SDK prompt messages.
     * @returns {Array<{role: string, content: string}>}
     * @throws {Error} On unsupported roles or user content parts.
     */
    convertPrompt(prompt) {
        return prompt.map((message) => {
            switch (message.role) {
                case 'system':
                    return { role: 'system', content: message.content };
                case 'user':
                    return {
                        role: 'user',
                        content: typeof message.content === 'string'
                            ? message.content
                            : message.content
                                .map((part) => {
                                switch (part.type) {
                                    case 'text':
                                        return part.text;
                                    case 'image':
                                        throw new Error('Image content is not supported');
                                    default:
                                        throw new Error(`Unsupported content type`);
                                }
                            })
                                .join(''),
                    };
                case 'assistant':
                    return {
                        role: 'assistant',
                        content: message.content
                            .filter((part) => part.type === 'text')
                            .map((part) => part.text)
                            .join(''),
                    };
                case 'tool': {
                    // ERNIE has no `tool` role; replay the results as an assistant
                    // turn so the conversation can continue.
                    const toolResults = message.content
                        .map((part) => part.type === 'tool-result'
                        ? `tool: ${part.toolName} ,result: ${stringifyToolResult(part.result)}`
                        : JSON.stringify(part))
                        .join('\n');
                    return {
                        role: 'assistant',
                        content: toolResults,
                    };
                }
                default:
                    throw new Error(`Unsupported message role`);
            }
        });
    }
    /**
     * Convert AI SDK call options into an ERNIE `/chat/completions` request body.
     *
     * Sampling settings are resolved in priority order:
     * 1. snake_case model setting;
     * 2. camelCase model setting;
     * 3. the matching per-call option (`temperature`, `topP`, `maxTokens`,
     *    `stopSequences`, `seed`, `frequencyPenalty`, `presencePenalty`).
     *
     * Only settings that resolve to a non-nullish value are written.
     *
     * @param {object} options - AI SDK call options: `prompt`, `mode`, and call settings.
     * @returns {object} The request body (without `stream`).
     * @throws {Error} For the unsupported 'object-tool' mode, or for unsupported
     *   message roles / content parts.
     */
    getArgs({ prompt, mode, ...callSettings }) {
        const type = mode.type;
        if (type === 'object-tool') {
            throw new Error(`Object generation mode '${type}' is not supported.`);
        }
        const settings = this.settings;
        const baseArgs = {
            model: this.modelId,
            messages: this.convertPrompt(prompt),
        };
        // Omit unresolved keys entirely so the API applies its own defaults.
        const setIfDefined = (key, value) => {
            if (value != null) {
                baseArgs[key] = value;
            }
        };
        // An empty stop list from the call options is treated as "not set".
        const callStop = callSettings.stopSequences?.length ? callSettings.stopSequences : undefined;
        setIfDefined('temperature', firstDefined(settings.temperature, callSettings.temperature));
        setIfDefined('top_p', firstDefined(settings.top_p, settings.topP, callSettings.topP));
        setIfDefined('max_tokens', firstDefined(settings.max_tokens, settings.maxTokens, callSettings.maxTokens));
        setIfDefined('penalty_score', firstDefined(settings.penalty_score, settings.penaltyScore));
        setIfDefined('stop', firstDefined(settings.stop, callStop));
        setIfDefined('seed', firstDefined(settings.seed, callSettings.seed));
        setIfDefined('frequency_penalty', firstDefined(settings.frequency_penalty, settings.frequencyPenalty, callSettings.frequencyPenalty));
        setIfDefined('presence_penalty', firstDefined(settings.presence_penalty, settings.presencePenalty, callSettings.presencePenalty));
        setIfDefined('repetition_penalty', firstDefined(settings.repetition_penalty, settings.repetitionPenalty));
        // The `web_search` object wins; the legacy `enableSearch` flag maps to `{ enable }`.
        if (settings.web_search != null) {
            baseArgs.web_search = settings.web_search;
        }
        else if (settings.enableSearch != null) {
            baseArgs.web_search = {
                enable: settings.enableSearch
            };
        }
        setIfDefined('system', settings.system);
        setIfDefined('user_id', firstDefined(settings.user_id, settings.userId));
        setIfDefined('tools', settings.tools);
        setIfDefined('response_format', settings.response_format);
        // `stream_options` passes through; camelCase `streamOptions` is translated.
        if (settings.stream_options != null) {
            baseArgs.stream_options = settings.stream_options;
        }
        else if (settings.streamOptions != null) {
            baseArgs.stream_options = {
                include_usage: settings.streamOptions.includeUsage
            };
        }
        // Tools declared on the call (regular mode) replace tools from model settings.
        if (type === 'regular' && mode.tools?.length) {
            baseArgs.tools = mode.tools.map((tool) => ({
                type: 'function',
                function: {
                    name: tool.name,
                    description: tool.description,
                    parameters: tool.parameters || { type: 'object', properties: {} },
                },
            }));
        }
        // JSON object mode: request json_object output and reinforce it in the system prompt.
        if (type === 'object-json') {
            baseArgs.response_format = {
                type: 'json_object'
            };
            const jsonInstruction = 'You must respond with valid JSON only. Do not include any explanatory text outside the JSON.';
            if (baseArgs.messages.length > 0 && baseArgs.messages[0].role === 'system') {
                baseArgs.messages[0].content += '\n\n' + jsonInstruction;
            }
            else {
                baseArgs.messages.unshift({
                    role: 'system',
                    content: jsonInstruction
                });
            }
        }
        return baseArgs;
    }
    /**
     * Generate a completion in a single (non-streaming) request.
     *
     * @param {object} options - AI SDK call options. `headers` and `abortSignal`
     *   are forwarded to the HTTP request.
     * @returns {Promise<object>} Text, tool calls, finish reason, token usage,
     *   and the raw request/response details.
     * @throws {Error} If the API responds with an error status or returns no choices.
     */
    async doGenerate(options) {
        const args = this.getArgs(options);
        const { responseHeaders, value: response } = await (0, provider_utils_1.postJsonToApi)({
            url: `${this.config.baseURL}/chat/completions`,
            headers: (0, provider_utils_1.combineHeaders)(this.config.headers(), options.headers),
            body: args,
            failedResponseHandler: this.failedResponseHandler,
            successfulResponseHandler: (0, provider_utils_1.createJsonResponseHandler)(ernieResponseSchema),
            abortSignal: options.abortSignal,
            fetch: this.config.fetch,
        });
        const { messages: rawPrompt, ...rawSettings } = args;
        const choice = response.choices[0];
        // An empty `choices` array would otherwise surface as an opaque TypeError below.
        if (choice == null) {
            throw new Error('ERNIE API returned no choices');
        }
        return {
            text: choice.message.content ?? '',
            toolCalls: choice.message.tool_calls?.map((toolCall) => ({
                toolCallType: 'function',
                toolCallId: toolCall.id,
                toolName: toolCall.function.name,
                args: toolCall.function.arguments,
            })) ?? [],
            finishReason: this.mapFinishReason(choice.finish_reason),
            usage: {
                promptTokens: response.usage.prompt_tokens,
                completionTokens: response.usage.completion_tokens,
            },
            rawCall: { rawPrompt, rawSettings },
            rawResponse: { headers: responseHeaders },
            warnings: [],
        };
    }
    /**
     * Generate a completion as a stream of AI SDK stream parts:
     * - `text-delta` for each content fragment;
     * - `error` for chunks that fail schema validation (this also sets the
     *   finish reason to 'error');
     * - one `tool-call` per fully assembled tool call;
     * - a final `finish` part with the finish reason and token usage. Usage
     *   stays NaN unless a chunk reports it.
     *
     * @param {object} options - AI SDK call options.
     * @returns {Promise<object>} The part stream plus the raw request/response details.
     */
    async doStream(options) {
        const args = this.getArgs(options);
        args.stream = true;
        const { responseHeaders, value: response } = await (0, provider_utils_1.postJsonToApi)({
            url: `${this.config.baseURL}/chat/completions`,
            headers: (0, provider_utils_1.combineHeaders)(this.config.headers(), options.headers),
            body: args,
            failedResponseHandler: this.failedResponseHandler,
            successfulResponseHandler: (0, provider_utils_1.createEventSourceResponseHandler)(ernieStreamChunkSchema),
            abortSignal: options.abortSignal,
            fetch: this.config.fetch,
        });
        const { messages: rawPrompt, ...rawSettings } = args;
        let finishReason = 'other';
        let usage = {
            promptTokens: Number.NaN,
            completionTokens: Number.NaN,
        };
        // A tool call's arguments may be split over many chunks, so fragments are
        // collected here and each call is emitted once, complete, at flush time.
        const pendingToolCalls = new Map();
        let lastToolCallKey;
        return {
            // Arrow callbacks keep `this` bound to the model, not the transformer object.
            stream: response.pipeThrough(new TransformStream({
                transform: (chunk, controller) => {
                    if (!chunk.success) {
                        finishReason = 'error';
                        controller.enqueue({ type: 'error', error: chunk.error });
                        return;
                    }
                    const value = chunk.value;
                    if (value.usage) {
                        usage = {
                            promptTokens: value.usage.prompt_tokens,
                            completionTokens: value.usage.completion_tokens,
                        };
                    }
                    const choice = value.choices?.[0];
                    if (choice == null) {
                        return;
                    }
                    if (choice.finish_reason) {
                        finishReason = this.mapFinishReason(choice.finish_reason);
                    }
                    const delta = choice.delta;
                    if (delta?.content) {
                        controller.enqueue({
                            type: 'text-delta',
                            textDelta: delta.content,
                        });
                    }
                    for (const fragment of delta?.tool_calls ?? []) {
                        // Group fragments by `index` (OpenAI-style streams); fall back to
                        // `id`, then to the call seen most recently.
                        const key = fragment.index ?? (fragment.id || lastToolCallKey) ?? 0;
                        lastToolCallKey = key;
                        let pending = pendingToolCalls.get(key);
                        if (pending === undefined) {
                            pending = { toolCallId: undefined, toolName: undefined, args: '' };
                            pendingToolCalls.set(key, pending);
                        }
                        if (pending.toolCallId === undefined && fragment.id) {
                            pending.toolCallId = fragment.id;
                        }
                        if (pending.toolName === undefined && fragment.function?.name) {
                            pending.toolName = fragment.function.name;
                        }
                        pending.args += fragment.function?.arguments ?? '';
                    }
                },
                flush: (controller) => {
                    for (const pending of pendingToolCalls.values()) {
                        // A call without a name cannot be dispatched; skip it as before.
                        if (!pending.toolName) {
                            continue;
                        }
                        controller.enqueue({
                            type: 'tool-call',
                            toolCallType: 'function',
                            toolCallId: pending.toolCallId ?? (0, provider_utils_1.generateId)(),
                            toolName: pending.toolName,
                            args: pending.args,
                        });
                    }
                    controller.enqueue({
                        type: 'finish',
                        finishReason,
                        usage,
                    });
                },
            })),
            rawCall: { rawPrompt, rawSettings },
            rawResponse: { headers: responseHeaders },
            warnings: [],
        };
    }
    /**
     * Map an ERNIE `finish_reason` string to the AI SDK finish reason.
     * Unknown or missing values map to 'other'.
     *
     * @param {string|null|undefined} finishReason
     * @returns {'stop'|'length'|'tool-calls'|'content-filter'|'other'}
     */
    mapFinishReason(finishReason) {
        switch (finishReason) {
            case 'stop':
                return 'stop';
            case 'length':
                return 'length';
            case 'tool_calls':
                return 'tool-calls';
            case 'content_filter':
                return 'content-filter';
            default:
                return 'other';
        }
    }
}
exports.ErnieChatLanguageModel = ErnieChatLanguageModel;
// Zod schema for a non-streaming ERNIE `/chat/completions` response.
// Built from named sub-schemas inside an IIFE so that the helper names stay
// local to this definition.
const ernieResponseSchema = (() => {
    const { z } = zod_1;
    // A function call requested by the model; `arguments` is a string
    // (presumably JSON-encoded; it is forwarded to the AI SDK as-is).
    const toolCall = z.object({
        id: z.string(),
        type: z.string(),
        function: z.object({
            name: z.string(),
            arguments: z.string(),
        }),
    });
    // The generated message; `content` may be null (e.g. when only tool calls are returned).
    const message = z.object({
        role: z.string(),
        content: z.string().nullable(),
        tool_calls: z.array(toolCall).optional(),
    });
    const choice = z.object({
        index: z.number(),
        message,
        finish_reason: z.string().nullable(),
    });
    const tokenUsage = z.object({
        prompt_tokens: z.number(),
        completion_tokens: z.number(),
        total_tokens: z.number(),
    });
    return z.object({
        id: z.string(),
        object: z.string(),
        created: z.number(),
        model: z.string(),
        choices: z.array(choice),
        usage: tokenUsage,
    });
})();
// Zod schema for one server-sent event of a streaming ERNIE `/chat/completions`
// response.
// The delta fields are deliberately lenient. In OpenAI-style streams (which the
// ERNIE v2 API presumably follows — confirm against the ERNIE docs):
// - `content` is sent as null alongside tool calls;
// - only the first fragment of a tool call carries `id`/`type`/`function.name`;
// - later fragments carry just `index` and a piece of `function.arguments`.
// A strict schema would turn those chunks into stream `error` parts.
const ernieStreamChunkSchema = zod_1.z.object({
    id: zod_1.z.string(),
    object: zod_1.z.string(),
    created: zod_1.z.number(),
    model: zod_1.z.string(),
    choices: zod_1.z
        .array(zod_1.z.object({
        index: zod_1.z.number(),
        delta: zod_1.z.object({
            role: zod_1.z.string().nullish(),
            content: zod_1.z.string().nullish(),
            tool_calls: zod_1.z
                .array(zod_1.z.object({
                // Position of the tool call within the response; used to group fragments.
                index: zod_1.z.number().optional(),
                id: zod_1.z.string().optional(),
                type: zod_1.z.string().optional(),
                function: zod_1.z
                    .object({
                    name: zod_1.z.string().optional(),
                    arguments: zod_1.z.string().optional(),
                })
                    .optional(),
            }))
                .optional(),
        }),
        finish_reason: zod_1.z.string().nullable().optional(),
    }))
        .optional(),
    // Token usage, when the API reports it (typically on the last chunk).
    usage: zod_1.z
        .object({
        prompt_tokens: zod_1.z.number(),
        completion_tokens: zod_1.z.number(),
        total_tokens: zod_1.z.number(),
    })
        .optional(),
});
//# sourceMappingURL=ernie-chat-language-model.js.map