UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports free one-click deployment of your private ChatGPT/LLM web application.

166 lines (145 loc) 5.45 kB
import OpenAI, { AzureOpenAI } from 'openai';
import type { Stream } from 'openai/streaming';

import { systemToUserModels } from '@/const/models';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import {
  ChatMethodOptions,
  ChatStreamPayload,
  Embeddings,
  EmbeddingsOptions,
  EmbeddingsPayload,
  ModelProvider,
} from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
import { convertOpenAIMessages } from '../utils/openaiHelpers';
import { StreamingResponse } from '../utils/response';
import { OpenAIStream } from '../utils/streams';

/**
 * Agent runtime for Azure OpenAI, built on the `AzureOpenAI` client from the `openai` SDK.
 *
 * Chat responses are always returned as a `StreamingResponse`. Non-streaming completions
 * are converted into a stream first.
 */
export class LobeAzureOpenAI implements LobeRuntimeAI {
  client: AzureOpenAI;

  /** Endpoint the client was created with; it is masked before being put into errors. */
  baseURL: string;

  /**
   * @param params.apiKey - Azure OpenAI API key (required).
   * @param params.apiVersion - API version passed through to the SDK client.
   * @param params.baseURL - Azure resource endpoint (required); used as the SDK `endpoint`.
   * @throws AgentRuntimeError of type `InvalidProviderAPIKey` when `apiKey` or `baseURL` is missing.
   */
  constructor(params: { apiKey?: string; apiVersion?: string; baseURL?: string } = {}) {
    if (!params.apiKey || !params.baseURL)
      throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

    this.client = new AzureOpenAI({
      apiKey: params.apiKey,
      apiVersion: params.apiVersion,
      dangerouslyAllowBrowser: true,
      endpoint: params.baseURL,
    });

    this.baseURL = params.baseURL;
  }

  /**
   * Sends a chat completion request and wraps the result in a streaming HTTP response.
   *
   * @param payload - Chat payload; `messages` and `model` are handled explicitly, and every
   *   other field is spread into the SDK request.
   * @param options - `callback` is wired into the stream. `headers` are used as *response*
   *   headers. `signal` is forwarded so aborting the chat aborts the upstream request.
   * @returns A `StreamingResponse` for both streaming and non-streaming completions.
   * @throws AgentRuntimeError (via `handleError`) when the SDK call fails.
   */
  async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
    const { messages, model, ...params } = payload;
    // o1 series models on Azure OpenAI do not support streaming currently
    const enableStreaming = model.includes('o1') ? false : (params.stream ?? true);

    // For o1/o3 models, rewrite the 'system' role: to 'user' when the model matches an entry
    // of `systemToUserModels`, otherwise to 'developer'. Other roles/models are untouched.
    const updatedMessages = messages.map((message) => ({
      ...message,
      role:
        (model.includes('o1') || model.includes('o3')) && message.role === 'system'
          ? [...systemToUserModels].some((sub) => model.includes(sub))
            ? 'user'
            : 'developer'
          : message.role,
    }));

    try {
      const response = await this.client.chat.completions.create(
        {
          messages: await convertOpenAIMessages(
            updatedMessages as OpenAI.ChatCompletionMessageParam[],
          ),
          model,
          ...params,
          max_completion_tokens: undefined,
          stream: enableStreaming,
          tool_choice: params.tools ? 'auto' : undefined,
        },
        // Without this, cancelling the chat leaves the Azure request running to completion.
        { signal: options?.signal },
      );

      if (enableStreaming) {
        const stream = response as Stream<OpenAI.ChatCompletionChunk>;
        // tee() so one copy can be dumped for debugging while the other is returned
        const [prod, debug] = stream.tee();

        if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') {
          debugStream(debug.toReadableStream()).catch(console.error);
        }

        return StreamingResponse(OpenAIStream(prod, { callbacks: options?.callback }), {
          headers: options?.headers,
        });
      }

      // Non-streaming completion: convert it into a stream so callers get a uniform response.
      const stream = transformResponseToStream(response as OpenAI.ChatCompletion);

      return StreamingResponse(OpenAIStream(stream, { callbacks: options?.callback }), {
        headers: options?.headers,
      });
    } catch (e) {
      return this.handleError(e, model);
    }
  }

  /**
   * Creates embeddings for the payload, requesting float encoding.
   *
   * @param payload - Embeddings request; `model` is also used for error reporting.
   * @param options - `user`, request `headers`, and abort `signal` forwarded to the SDK.
   * @returns One embedding vector per item in the response `data` array.
   * @throws AgentRuntimeError (via `handleError`) when the SDK call fails.
   */
  async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
    try {
      const res = await this.client.embeddings.create(
        { ...payload, encoding_format: 'float', user: options?.user },
        { headers: options?.headers, signal: options?.signal },
      );

      return res.data.map((item) => item.embedding);
    } catch (error) {
      return this.handleError(error, payload.model);
    }
  }

  /**
   * Normalizes an SDK/transport error and rethrows it as an AgentRuntimeError.
   *
   * Errors that carry a `code` are kept as-is and typed `ProviderBizError`. A
   * `DeploymentNotFound` error is additionally annotated with `deployId: model`. Errors
   * without a `code` are reduced to `{ cause, message, name }` and typed `AgentRuntimeError`.
   *
   * @param e - The caught value; may be anything that was thrown or rejected.
   * @param model - Model/deployment name, attached to `DeploymentNotFound` errors.
   * @throws Always throws; never returns.
   */
  protected handleError(e: any, model?: string): never {
    // `?? {}` keeps a nullish rejection value from crashing the property reads below,
    // which would otherwise mask the real failure with a TypeError.
    let error = (e ?? {}) as { [key: string]: any; code: string; message: string };

    if (error.code) {
      if (error.code === 'DeploymentNotFound') {
        error = { ...error, deployId: model };
      }
    } else {
      error = {
        cause: error.cause,
        message: error.message,
        name: error.name,
      } as any;
    }

    const errorType = error.code
      ? AgentRuntimeErrorType.ProviderBizError
      : AgentRuntimeErrorType.AgentRuntimeError;

    throw AgentRuntimeError.chat({
      endpoint: this.maskSensitiveUrl(this.baseURL),
      error,
      errorType,
      provider: ModelProvider.Azure,
    });
  }

  // Convert object keys to camel case, copied from `@azure/openai` in `node_modules/@azure/openai/dist/index.cjs`.
  // Mutates `obj` in place (arrays are mapped to new arrays) and recurses into nested values.
  private camelCaseKeys = (obj: any): any => {
    if (typeof obj !== 'object' || !obj) return obj;

    if (Array.isArray(obj)) {
      return obj.map((v) => this.camelCaseKeys(v));
    }

    for (const key of Object.keys(obj)) {
      const value = obj[key];
      const newKey = this.tocamelCase(key);
      if (newKey !== key) {
        delete obj[key];
      }
      // Test `value`, not `obj[newKey]`: for a renamed key the latter is undefined after the
      // delete above, which would skip recursion into nested objects.
      obj[newKey] = typeof value === 'object' ? this.camelCaseKeys(value) : value;
    }

    return obj;
  };

  // Lower-cases the whole string, then turns each `_x` into `X` (snake_case -> camelCase).
  private tocamelCase = (str: string) => {
    return str
      .toLowerCase()
      .replaceAll(/(_[a-z])/g, (group) => group.toUpperCase().replace('_', ''));
  };

  // Replaces the Azure resource subdomain (between 'https://' and '.openai.azure.com') with
  // '***' so it does not leak into error payloads. The trailing path is optional, so endpoints
  // given without a trailing slash are masked too. Non-matching URLs are returned unchanged.
  private maskSensitiveUrl = (url: string) => {
    const regex = /^(https:\/\/)([^.]+)(\.openai\.azure\.com(?:\/.*)?)$/;

    return url.replace(regex, (_match, protocol, _subdomain, rest) => `${protocol}***${rest}`);
  };
}