@lobehub/chat
Lobe Chat - an open-source, high-performance chatbot framework with speech synthesis, multimodal support, and an extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
import OpenAI, { AzureOpenAI } from 'openai';
import type { Stream } from 'openai/streaming';
import { systemToUserModels } from '@/const/models';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import {
ChatMethodOptions,
ChatStreamPayload,
Embeddings,
EmbeddingsOptions,
EmbeddingsPayload,
ModelProvider,
} from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
import { convertOpenAIMessages } from '../utils/openaiHelpers';
import { StreamingResponse } from '../utils/response';
import { OpenAIStream } from '../utils/streams';
export class LobeAzureOpenAI implements LobeRuntimeAI {
client: AzureOpenAI;
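  // Both an API key and the Azure endpoint (baseURL) are required; apiVersion is forwarded
  // to the official SDK. Missing credentials fail fast with an InvalidProviderAPIKey error.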
constructor(params: { apiKey?: string; apiVersion?: string; baseURL?: string } = {}) {
if (!params.apiKey || !params.baseURL)
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
this.client = new AzureOpenAI({
apiKey: params.apiKey,
apiVersion: params.apiVersion,
dangerouslyAllowBrowser: true,
endpoint: params.baseURL,
});
this.baseURL = params.baseURL;
}
baseURL: string;
async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
const { messages, model, ...params } = payload;
    // o1 series models on Azure OpenAI do not currently support streaming
const enableStreaming = model.includes('o1') ? false : (params.stream ?? true);
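    // Azure reasoning deployments (o1/o3) handle the 'system' role differently, so it is
    // remapped below: to 'user' for models listed in systemToUserModels, to 'developer' otherwise.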
const updatedMessages = messages.map((message) => ({
...message,
role:
// Convert 'system' role to 'user' or 'developer' based on the model
(model.includes('o1') || model.includes('o3')) && message.role === 'system'
? [...systemToUserModels].some((sub) => model.includes(sub))
? 'user'
: 'developer'
: message.role,
}));
try {
const response = await this.client.chat.completions.create({
messages: await convertOpenAIMessages(
updatedMessages as OpenAI.ChatCompletionMessageParam[],
),
model,
...params,
max_completion_tokens: undefined,
stream: enableStreaming,
tool_choice: params.tools ? 'auto' : undefined,
});
if (enableStreaming) {
const stream = response as Stream<OpenAI.ChatCompletionChunk>;
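        // Tee the stream: one copy goes to the caller, the other can be drained by
        // debugStream when DEBUG_AZURE_CHAT_COMPLETION=1 is set.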
const [prod, debug] = stream.tee();
if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') {
debugStream(debug.toReadableStream()).catch(console.error);
}
return StreamingResponse(OpenAIStream(prod, { callbacks: options?.callback }), {
headers: options?.headers,
});
} else {
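        // Non-streaming responses are wrapped into a one-shot stream so both branches
        // return a StreamingResponse with the same shape.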
const stream = transformResponseToStream(response as OpenAI.ChatCompletion);
return StreamingResponse(OpenAIStream(stream, { callbacks: options?.callback }), {
headers: options?.headers,
});
}
} catch (e) {
return this.handleError(e, model);
}
}
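  // Embeddings use the same OpenAI-compatible endpoint; responses are unwrapped to
  // plain float vectors.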
async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
try {
const res = await this.client.embeddings.create(
{ ...payload, encoding_format: 'float', user: options?.user },
{ headers: options?.headers, signal: options?.signal },
);
return res.data.map((item) => item.embedding);
} catch (error) {
return this.handleError(error, payload.model);
}
}
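  // Normalizes SDK and service errors into AgentRuntimeError.chat: known Azure error codes
  // (e.g. DeploymentNotFound) are reported as provider business errors, everything else as a
  // generic runtime error, and the endpoint is masked before reporting.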
protected handleError(e: any, model?: string): never {
let error = e as { [key: string]: any; code: string; message: string };
if (error.code) {
switch (error.code) {
case 'DeploymentNotFound': {
error = { ...error, deployId: model };
}
}
} else {
error = {
cause: error.cause,
message: error.message,
name: error.name,
} as any;
}
const errorType = error.code
? AgentRuntimeErrorType.ProviderBizError
: AgentRuntimeErrorType.AgentRuntimeError;
throw AgentRuntimeError.chat({
endpoint: this.maskSensitiveUrl(this.baseURL),
error,
errorType,
provider: ModelProvider.Azure,
});
}
  // Convert object keys to camelCase; copied from `@azure/openai` (`node_modules/@azure/openai/dist/index.cjs`)
private camelCaseKeys = (obj: any): any => {
if (typeof obj !== 'object' || !obj) return obj;
if (Array.isArray(obj)) {
return obj.map((v) => this.camelCaseKeys(v));
} else {
for (const key of Object.keys(obj)) {
const value = obj[key];
const newKey = this.tocamelCase(key);
if (newKey !== key) {
delete obj[key];
}
        // Recurse on the original value (obj[key] may already have been deleted above)
        obj[newKey] = typeof value === 'object' ? this.camelCaseKeys(value) : value;
}
return obj;
}
};
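  // Lower-cases the input and converts snake_case segments to camelCase,
  // e.g. 'content_filter_results' -> 'contentFilterResults'.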
private tocamelCase = (str: string) => {
return str
.toLowerCase()
.replaceAll(/(_[a-z])/g, (group) => group.toUpperCase().replace('_', ''));
};
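  // Hides the Azure resource name in error reports,
  // e.g. 'https://my-resource.openai.azure.com/openai' -> 'https://***.openai.azure.com/openai'
  // (hostname above is illustrative).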
private maskSensitiveUrl = (url: string) => {
    // Match the part between 'https://' and '.openai.azure.com/' with a regular expression
const regex = /^(https:\/\/)([^.]+)(\.openai\.azure\.com\/.*)$/;
    // Replace via a callback
return url.replace(regex, (match, protocol, subdomain, rest) => {
      // Replace the subdomain with '***'
return `${protocol}***${rest}`;
});
};
}
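
// Usage sketch (illustrative, not part of the original file): assuming a valid Azure endpoint
// and API key, the runtime could be driven roughly like this. The endpoint, deployment name,
// and message below are placeholders.
//
// const runtime = new LobeAzureOpenAI({
//   apiKey: process.env.AZURE_API_KEY,
//   apiVersion: '2024-06-01',
//   baseURL: 'https://my-resource.openai.azure.com/',
// });
//
// const response = await runtime.chat(
//   { messages: [{ content: 'Hello!', role: 'user' }], model: 'gpt-4o', temperature: 1 },
//   { headers: { Accept: 'text/event-stream' } },
// );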