@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
132 lines (110 loc) • 4.39 kB
text/typescript
import createClient, { ModelClient } from '@azure-rest/ai-inference';
import { AzureKeyCredential } from '@azure/core-auth';
import OpenAI from 'openai';
import { systemToUserModels } from '@/const/models';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatMethodOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
import { StreamingResponse } from '../utils/response';
import { OpenAIStream, createSSEDataExtractor } from '../utils/streams';
interface AzureAIParams {
apiKey?: string;
apiVersion?: string;
baseURL?: string;
}
/**
 * Runtime for Azure AI (model inference) endpoints, using the
 * `@azure-rest/ai-inference` REST client. The chat-completions response is
 * OpenAI-compatible, so responses are adapted through the shared OpenAI
 * stream utilities.
 */
export class LobeAzureAI implements LobeRuntimeAI {
  client: ModelClient;

  baseURL: string;

  /**
   * @param params - endpoint configuration; `apiKey` and `baseURL` are required.
   * @throws InvalidProviderAPIKey error when either `apiKey` or `baseURL` is missing.
   */
  constructor(params?: AzureAIParams) {
    if (!params?.apiKey || !params?.baseURL)
      throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

    // Forward apiVersion to the client when provided. Previously the param
    // was accepted by AzureAIParams but silently ignored.
    this.client = createClient(
      params.baseURL,
      new AzureKeyCredential(params.apiKey),
      params.apiVersion ? { apiVersion: params.apiVersion } : undefined,
    );
    this.baseURL = params.baseURL;
  }

  /**
   * Send a chat-completions request, streaming by default.
   *
   * @param payload - chat payload (messages, model, sampling params).
   * @param options - callbacks and response headers.
   * @returns a streaming `Response` in both streaming and non-streaming modes
   *   (non-streaming results are wrapped into a stream for a uniform interface).
   * @throws a normalized AgentRuntimeError with the endpoint's subdomain masked.
   */
  async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
    const { messages, model, temperature, top_p, ...params } = payload;

    // o1 series models on Azure OpenAI do not support streaming currently
    const enableStreaming = model.includes('o1') ? false : (params.stream ?? true);

    const updatedMessages = messages.map((message) => ({
      ...message,
      role:
        // Convert 'system' role to 'user' or 'developer' based on the model
        (model.includes('o1') || model.includes('o3')) && message.role === 'system'
          ? [...systemToUserModels].some((sub) => model.includes(sub))
            ? 'user'
            : 'developer'
          : message.role,
    }));

    try {
      const response = this.client.path('/chat/completions').post({
        body: {
          messages: updatedMessages as OpenAI.ChatCompletionMessageParam[],
          model,
          ...params,
          stream: enableStreaming,
          // o3/o4 reasoning models reject sampling parameters; strip them
          temperature: model.includes('o3') || model.includes('o4') ? undefined : temperature,
          tool_choice: params.tools ? 'auto' : undefined,
          top_p: model.includes('o3') || model.includes('o4') ? undefined : top_p,
        },
      });

      if (enableStreaming) {
        const stream = await response.asBrowserStream();

        // tee so one branch can be consumed for debug logging without
        // disturbing the branch returned to the caller
        const [prod, debug] = stream.body!.tee();

        if (process.env.DEBUG_AZURE_AI_CHAT_COMPLETION === '1') {
          debugStream(debug).catch(console.error);
        }

        return StreamingResponse(
          OpenAIStream(prod.pipeThrough(createSSEDataExtractor()), {
            callbacks: options?.callback,
          }),
          {
            headers: options?.headers,
          },
        );
      } else {
        const res = await response;

        // the Azure AI inference response is OpenAI-compatible
        const stream = transformResponseToStream(res.body as OpenAI.ChatCompletion);

        return StreamingResponse(OpenAIStream(stream, { callbacks: options?.callback }), {
          headers: options?.headers,
        });
      }
    } catch (e) {
      let error = e as { [key: string]: any; code: string; message: string };

      if (error.code) {
        switch (error.code) {
          case 'DeploymentNotFound': {
            // surface the model name so the UI can point at the bad deployment
            error = { ...error, deployId: model };
          }
        }
      } else {
        // plain Error without a service code: keep only the safe fields
        error = {
          cause: error.cause,
          message: error.message,
          name: error.name,
        } as any;
      }

      const errorType = error.code
        ? AgentRuntimeErrorType.ProviderBizError
        : AgentRuntimeErrorType.AgentRuntimeError;

      throw AgentRuntimeError.chat({
        endpoint: this.maskSensitiveUrl(this.baseURL),
        error,
        errorType,
        provider: ModelProvider.Azure,
      });
    }
  }

  /**
   * Mask the resource subdomain in an Azure endpoint URL before it is attached
   * to an error, e.g. `https://my-res.cognitiveservices.azure.com/...` →
   * `https://***.cognitiveservices.azure.com/...`.
   *
   * Generalized beyond `*.cognitiveservices.azure.com` so Azure AI Foundry
   * (`*.services.ai.azure.com`) and Azure OpenAI (`*.openai.azure.com`)
   * endpoints are masked as well instead of leaking the resource name.
   */
  private maskSensitiveUrl = (url: string) => {
    // first label after `https://` is the resource name; everything up to
    // `.azure.com` (any intermediate labels) plus the path is preserved
    const regex = /^(https:\/\/)([^.]+)((?:\.[^./]+)*\.azure\.com\/?.*)$/;

    return url.replace(regex, (_match, protocol, _subdomain, rest) => `${protocol}***${rest}`);
  };
}