UNPKG

ernie-ai-provider

Version:

Community-built ERNIE AI provider for the Vercel AI SDK — integrates Baidu's ERNIE models with Vercel's AI application framework.

405 lines 16.3 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.ErnieChatLanguageModel = void 0; const provider_utils_1 = require("@ai-sdk/provider-utils"); const zod_1 = require("zod"); /** * ERNIE 聊天语言模型实现 */ class ErnieChatLanguageModel { constructor(modelId, settings, config) { this.specificationVersion = 'v1'; this.defaultObjectGenerationMode = 'json'; this.supportsToolUse = true; /** * 处理失败的响应 */ this.failedResponseHandler = async ({ response, url, requestBodyValues, }) => { const responseBody = await response.text(); if (response.status === 401) { throw new Error(`ERNIE API 认证失败 (${response.status}): ${responseBody}`); } if (response.status === 429) { throw new Error(`ERNIE API 请求频率限制 (${response.status}): ${responseBody}`); } throw new Error(`ERNIE API 请求失败 (${response.status}): ${responseBody}`); }; this.modelId = modelId; this.settings = settings; this.config = config; } get provider() { return this.config.provider; } /** * 将 AI SDK 的参数转换为 ERNIE API 格式 */ getArgs({ prompt, mode, ...settings }) { const type = mode.type; const warnings = []; if (mode.type === 'object-tool') { throw new Error(`Object generation mode '${type}' is not supported.`); } const baseArgs = { model: this.modelId, messages: prompt.map((message) => { switch (message.role) { case 'system': return { role: 'system', content: message.content }; case 'user': return { role: 'user', content: typeof message.content === 'string' ? 
message.content : message.content .map((part) => { switch (part.type) { case 'text': return part.text; case 'image': throw new Error('Image content is not supported'); default: throw new Error(`Unsupported content type`); } }) .join(''), }; case 'assistant': return { role: 'assistant', content: message.content.filter((part) => part.type === 'text') .map((part) => part.text) .join(''), }; case 'tool': // ERNIE API 不直接支持 tool 消息类型 // 但我们需要将工具结果转换为 assistant 消息来继续对话 const toolResults = message.content.map((content) => { if (content.type === 'tool-result') { return `tool: ${content.toolName} ,result: ${content.result}`; } return JSON.stringify(content); }).join('\n'); return { role: 'assistant', content: toolResults, }; default: throw new Error(`Unsupported message role`); } }), }; // 添加模型特定的设置 - 下划线命名优先级高于骆驼命名 if (this.settings.temperature != null) { baseArgs.temperature = this.settings.temperature; } // top_p 参数处理:下划线优先 if (this.settings.top_p != null) { baseArgs.top_p = this.settings.top_p; } else if (this.settings.topP != null) { baseArgs.top_p = this.settings.topP; } // max_tokens 参数处理:下划线优先 if (this.settings.max_tokens != null) { baseArgs.max_tokens = this.settings.max_tokens; } else if (this.settings.maxTokens != null) { baseArgs.max_tokens = this.settings.maxTokens; } // penalty_score 参数处理:下划线优先 if (this.settings.penalty_score != null) { baseArgs.penalty_score = this.settings.penalty_score; } else if (this.settings.penaltyScore != null) { baseArgs.penalty_score = this.settings.penaltyScore; } if (this.settings.stop != null) { baseArgs.stop = this.settings.stop; } if (this.settings.seed != null) { baseArgs.seed = this.settings.seed; } // frequency_penalty 参数处理:下划线优先 if (this.settings.frequency_penalty != null) { baseArgs.frequency_penalty = this.settings.frequency_penalty; } else if (this.settings.frequencyPenalty != null) { baseArgs.frequency_penalty = this.settings.frequencyPenalty; } // presence_penalty 参数处理:下划线优先 if (this.settings.presence_penalty != null) { 
baseArgs.presence_penalty = this.settings.presence_penalty; } else if (this.settings.presencePenalty != null) { baseArgs.presence_penalty = this.settings.presencePenalty; } // repetition_penalty 参数处理:下划线优先 if (this.settings.repetition_penalty != null) { baseArgs.repetition_penalty = this.settings.repetition_penalty; } else if (this.settings.repetitionPenalty != null) { baseArgs.repetition_penalty = this.settings.repetitionPenalty; } // 处理网络搜索参数 - 支持新的web_search对象格式 if (this.settings.web_search != null) { baseArgs.web_search = this.settings.web_search; } else { // AI SDK兼容性:处理旧的搜索相关参数 if (this.settings.enableSearch != null) { baseArgs.web_search = { enable: this.settings.enableSearch }; } } // 处理系统参数 if (this.settings.system != null) { baseArgs.system = this.settings.system; } if (this.settings.user_id != null) { baseArgs.user_id = this.settings.user_id; } else if (this.settings.userId != null) { // AI SDK兼容性 baseArgs.user_id = this.settings.userId; } // 处理工具调用 if (this.settings.tools != null) { baseArgs.tools = this.settings.tools; } // 处理响应格式 if (this.settings.response_format != null) { baseArgs.response_format = this.settings.response_format; } // 处理流式响应选项 if (this.settings.stream_options != null) { baseArgs.stream_options = this.settings.stream_options; } else if (this.settings.streamOptions != null) { // AI SDK兼容性 baseArgs.stream_options = { include_usage: this.settings.streamOptions.includeUsage }; } // 处理工具调用 if (mode.type === 'regular' && mode.tools?.length) { baseArgs.tools = mode.tools.map((tool) => ({ type: 'function', function: { name: tool.name, description: tool.description, parameters: tool.parameters || { type: 'object', properties: {} }, }, })); } // 处理对象生成模式 if (mode.type === 'object-json') { baseArgs.response_format = { type: 'json_object' }; // 在系统消息中添加JSON格式要求 const jsonInstruction = 'You must respond with valid JSON only. 
Do not include any explanatory text outside the JSON.'; if (baseArgs.messages.length > 0 && baseArgs.messages[0].role === 'system') { baseArgs.messages[0].content += '\n\n' + jsonInstruction; } else { baseArgs.messages.unshift({ role: 'system', content: jsonInstruction }); } } return baseArgs; } /** * 生成文本(非流式) */ async doGenerate(options) { const args = this.getArgs(options); const { responseHeaders, value: response } = await (0, provider_utils_1.postJsonToApi)({ url: `${this.config.baseURL}/chat/completions`, headers: (0, provider_utils_1.combineHeaders)(this.config.headers(), options.headers), body: args, failedResponseHandler: this.failedResponseHandler, successfulResponseHandler: (0, provider_utils_1.createJsonResponseHandler)(ernieResponseSchema), abortSignal: options.abortSignal, fetch: this.config.fetch, }); const { messages: rawPrompt, ...rawSettings } = args; const choice = response.choices[0]; return { text: choice.message.content ?? '', toolCalls: choice.message.tool_calls?.map((toolCall) => ({ toolCallType: 'function', toolCallId: toolCall.id, toolName: toolCall.function.name, args: toolCall.function.arguments, })) ?? 
[], finishReason: this.mapFinishReason(choice.finish_reason), usage: { promptTokens: response.usage.prompt_tokens, completionTokens: response.usage.completion_tokens, }, rawCall: { rawPrompt, rawSettings }, rawResponse: { headers: responseHeaders }, warnings: [], }; } /** * 生成文本(流式) */ async doStream(options) { const args = this.getArgs(options); args.stream = true; const { responseHeaders, value: response } = await (0, provider_utils_1.postJsonToApi)({ url: `${this.config.baseURL}/chat/completions`, headers: (0, provider_utils_1.combineHeaders)(this.config.headers(), options.headers), body: args, failedResponseHandler: this.failedResponseHandler, successfulResponseHandler: (0, provider_utils_1.createEventSourceResponseHandler)(ernieStreamChunkSchema), abortSignal: options.abortSignal, fetch: this.config.fetch, }); const { messages: rawPrompt, ...rawSettings } = args; let finishReason = 'other'; let usage = { promptTokens: Number.NaN, completionTokens: Number.NaN, }; const self = this; return { stream: response.pipeThrough(new TransformStream({ transform(chunk, controller) { if (!chunk.success) { controller.enqueue({ type: 'error', error: chunk.error }); return; } const value = chunk.value; if (value.choices?.[0]?.delta?.content) { controller.enqueue({ type: 'text-delta', textDelta: value.choices[0].delta.content, }); } if (value.choices?.[0]?.finish_reason) { finishReason = self.mapFinishReason(value.choices[0].finish_reason); } if (value.usage) { usage = { promptTokens: value.usage.prompt_tokens, completionTokens: value.usage.completion_tokens, }; } if (value.choices?.[0]?.delta?.tool_calls) { // 处理工具调用流式响应 const toolCalls = value.choices[0].delta.tool_calls; for (const toolCall of toolCalls) { if (toolCall.function?.name) { controller.enqueue({ type: 'tool-call', toolCallType: 'function', toolCallId: toolCall.id, toolName: toolCall.function.name, args: toolCall.function.arguments || '', }); } } } }, flush(controller) { controller.enqueue({ type: 'finish', 
finishReason, usage, }); }, })), rawCall: { rawPrompt, rawSettings }, rawResponse: { headers: responseHeaders }, warnings: [], }; } /** * 映射完成原因 */ mapFinishReason(finishReason) { switch (finishReason) { case 'stop': return 'stop'; case 'length': return 'length'; case 'tool_calls': return 'tool-calls'; case 'content_filter': return 'content-filter'; default: return 'other'; } } } exports.ErnieChatLanguageModel = ErnieChatLanguageModel; // ERNIE API 响应模式定义 const ernieResponseSchema = zod_1.z.object({ id: zod_1.z.string(), object: zod_1.z.string(), created: zod_1.z.number(), model: zod_1.z.string(), choices: zod_1.z.array(zod_1.z.object({ index: zod_1.z.number(), message: zod_1.z.object({ role: zod_1.z.string(), content: zod_1.z.string().nullable(), tool_calls: zod_1.z .array(zod_1.z.object({ id: zod_1.z.string(), type: zod_1.z.string(), function: zod_1.z.object({ name: zod_1.z.string(), arguments: zod_1.z.string(), }), })) .optional(), }), finish_reason: zod_1.z.string().nullable(), })), usage: zod_1.z.object({ prompt_tokens: zod_1.z.number(), completion_tokens: zod_1.z.number(), total_tokens: zod_1.z.number(), }), }); const ernieStreamChunkSchema = zod_1.z.object({ id: zod_1.z.string(), object: zod_1.z.string(), created: zod_1.z.number(), model: zod_1.z.string(), choices: zod_1.z .array(zod_1.z.object({ index: zod_1.z.number(), delta: zod_1.z.object({ role: zod_1.z.string().optional(), content: zod_1.z.string().optional(), tool_calls: zod_1.z .array(zod_1.z.object({ id: zod_1.z.string(), type: zod_1.z.string(), function: zod_1.z.object({ name: zod_1.z.string().optional(), arguments: zod_1.z.string().optional(), }), })) .optional(), }), finish_reason: zod_1.z.string().nullable().optional(), })) .optional(), usage: zod_1.z .object({ prompt_tokens: zod_1.z.number(), completion_tokens: zod_1.z.number(), total_tokens: zod_1.z.number(), }) .optional(), }); //# sourceMappingURL=ernie-chat-language-model.js.map