UNPKG

@langchain/azure-openai


Azure SDK for OpenAI integrations for LangChain.js

598 lines (597 loc) 23.8 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.AzureChatOpenAI = exports.messageToOpenAIRole = void 0; const openai_1 = require("@azure/openai"); const chat_models_1 = require("@langchain/core/language_models/chat_models"); const messages_1 = require("@langchain/core/messages"); const outputs_1 = require("@langchain/core/outputs"); const env_1 = require("@langchain/core/utils/env"); const core_auth_1 = require("@azure/core-auth"); const openai_format_fndef_js_1 = require("./utils/openai-format-fndef.cjs"); const constants_js_1 = require("./constants.cjs"); function _convertDeltaToMessageChunk(delta, defaultRole) { const role = delta.role ?? defaultRole; const content = delta.content ?? ""; let additional_kwargs; if (delta.functionCall) { additional_kwargs = { function_call: delta.functionCall, }; } else if (delta.toolCalls) { additional_kwargs = { tool_calls: delta.toolCalls, }; } else { additional_kwargs = {}; } if (role === "user") { return new messages_1.HumanMessageChunk({ content }); } else if (role === "assistant") { return new messages_1.AIMessageChunk({ content, additional_kwargs }); } else if (role === "system") { return new messages_1.SystemMessageChunk({ content }); } else if (role === "function") { return new messages_1.FunctionMessageChunk({ content, additional_kwargs, name: delta.role, }); } else if (role === "tool") { return new messages_1.ToolMessageChunk({ content, additional_kwargs, tool_call_id: delta.toolCalls[0].id, }); } else { return new messages_1.ChatMessageChunk({ content, role }); } } function openAIResponseToChatMessage(message) { switch (message.role) { case "assistant": return new messages_1.AIMessage(message.content || "", { function_call: message.functionCall, tool_calls: message.toolCalls, }); default: return new messages_1.ChatMessage(message.content || "", message.role ?? "unknown"); } } function extractGenericMessageCustomRole(message) { if (message.role !== "system" && message.role !== "assistant" && message.role !== "user" && message.role !== "function" && message.role !== "tool") { console.warn(`Unknown message role: ${message.role}`); } return message.role; } function messageToOpenAIRole(message) { const type = message._getType(); switch (type) { case "system": return "system"; case "ai": return "assistant"; case "human": return "user"; case "function": return "function"; case "tool": return "tool"; case "generic": { if (!messages_1.ChatMessage.isInstance(message)) throw new Error("Invalid generic chat message"); return extractGenericMessageCustomRole(message); } default: throw new Error(`Unknown message type: ${type}`); } } exports.messageToOpenAIRole = messageToOpenAIRole; /** @deprecated Import from "@langchain/openai" instead. */ class AzureChatOpenAI extends chat_models_1.BaseChatModel { static lc_name() { return "AzureChatOpenAI"; } get callKeys() { return [ ...super.callKeys, "options", "function_call", "functions", "tools", "tool_choice", "promptIndex", "response_format", "seed", ]; } get lc_secrets() { return { openAIApiKey: "OPENAI_API_KEY", azureOpenAIApiKey: "AZURE_OPENAI_API_KEY", azureOpenAIEndpoint: "AZURE_OPENAI_API_ENDPOINT", azureOpenAIApiDeploymentName: "AZURE_OPENAI_API_DEPLOYMENT_NAME", }; } get lc_aliases() { return { modelName: "model", openAIApiKey: "openai_api_key", azureOpenAIApiKey: "azure_openai_api_key", azureOpenAIEndpoint: "azure_openai_api_endpoint", azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name", }; } constructor(fields) { super(fields ?? 
{}); Object.defineProperty(this, "lc_serializable", { enumerable: true, configurable: true, writable: true, value: true }); Object.defineProperty(this, "azureExtensionOptions", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "maxTokens", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "temperature", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "topP", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "logitBias", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "user", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "n", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "presencePenalty", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "frequencyPenalty", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "stop", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "stopSequences", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "streaming", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "modelName", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "model", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "modelKwargs", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "timeout", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "azureOpenAIEndpoint", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "azureOpenAIApiKey", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "apiKey", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "azureOpenAIApiDeploymentName", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "client", { enumerable: true, configurable: true, writable: true, value: void 0 }); this.azureOpenAIEndpoint = fields?.azureOpenAIEndpoint ?? (0, env_1.getEnvironmentVariable)("AZURE_OPENAI_API_ENDPOINT"); this.azureOpenAIApiDeploymentName = (fields?.azureOpenAIEmbeddingsApiDeploymentName || fields?.azureOpenAIApiDeploymentName) ?? (0, env_1.getEnvironmentVariable)("AZURE_OPENAI_API_DEPLOYMENT_NAME"); const openAiApiKey = fields?.apiKey ?? fields?.openAIApiKey ?? (0, env_1.getEnvironmentVariable)("OPENAI_API_KEY"); this.azureOpenAIApiKey = fields?.apiKey ?? fields?.azureOpenAIApiKey ?? (0, env_1.getEnvironmentVariable)("AZURE_OPENAI_API_KEY") ?? openAiApiKey; this.apiKey = this.azureOpenAIApiKey; const azureCredential = fields?.credentials ?? (this.apiKey === openAiApiKey ? new openai_1.OpenAIKeyCredential(this.apiKey ?? "") : new openai_1.AzureKeyCredential(this.apiKey ?? 
"")); // eslint-disable-next-line no-instanceof/no-instanceof const isOpenAIApiKey = azureCredential instanceof openai_1.OpenAIKeyCredential; if (!this.apiKey && !fields?.credentials) { throw new Error("Azure OpenAI API key not found"); } if (!this.azureOpenAIEndpoint && !isOpenAIApiKey) { throw new Error("Azure OpenAI Endpoint not found"); } if (!this.azureOpenAIApiDeploymentName && !isOpenAIApiKey) { throw new Error("Azure OpenAI Deployment name not found"); } this.modelName = fields?.model ?? fields?.modelName ?? this.model; this.model = this.modelName; this.modelKwargs = fields?.modelKwargs ?? {}; this.timeout = fields?.timeout; this.temperature = fields?.temperature ?? this.temperature; this.topP = fields?.topP ?? this.topP; this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; this.maxTokens = fields?.maxTokens; this.n = fields?.n ?? this.n; this.logitBias = fields?.logitBias; this.stop = fields?.stopSequences ?? fields?.stop; this.stopSequences = this.stop; this.user = fields?.user; this.azureExtensionOptions = fields?.azureExtensionOptions; this.streaming = fields?.streaming ?? false; const options = { userAgentOptions: { userAgentPrefix: constants_js_1.USER_AGENT_PREFIX }, }; if (isOpenAIApiKey) { this.client = new openai_1.OpenAIClient(azureCredential); } else if ((0, core_auth_1.isTokenCredential)(azureCredential)) { this.client = new openai_1.OpenAIClient(this.azureOpenAIEndpoint ?? "", azureCredential, options); } else { this.client = new openai_1.OpenAIClient(this.azureOpenAIEndpoint ?? "", azureCredential, options); } } formatMessages(messages) { return messages.map((message) => ({ role: messageToOpenAIRole(message), content: message.content, name: message.name, toolCalls: message.additional_kwargs.tool_calls, functionCall: message.additional_kwargs.function_call, toolCallId: message.tool_call_id, })); } async _streamChatCompletionsWithRetry(azureOpenAIMessages, options) { return this.caller.call(async () => { const deploymentName = this.azureOpenAIApiDeploymentName || this.model; const res = await this.client.streamChatCompletions(deploymentName, azureOpenAIMessages, { functions: options?.functions, functionCall: options?.function_call, maxTokens: this.maxTokens, temperature: this.temperature, topP: this.topP, logitBias: this.logitBias, user: this.user, n: this.n, stop: this.stopSequences, presencePenalty: this.presencePenalty, frequencyPenalty: this.frequencyPenalty, azureExtensionOptions: this.azureExtensionOptions, requestOptions: { timeout: options?.timeout ?? this.timeout, }, abortSignal: options?.signal ?? undefined, tools: options?.tools, toolChoice: options?.tool_choice, responseFormat: options?.response_format, seed: options?.seed, ...this.modelKwargs, }); return res; }); } async *_streamResponseChunks(messages, options, runManager) { const azureOpenAIMessages = this.formatMessages(messages); let defaultRole; const streamIterable = await this._streamChatCompletionsWithRetry(azureOpenAIMessages, options); for await (const data of streamIterable) { const choice = data?.choices[0]; if (!choice) { continue; } const { delta } = choice; if (!delta) { continue; } const chunk = _convertDeltaToMessageChunk(delta, defaultRole); defaultRole = delta.role ?? defaultRole; const newTokenIndices = { prompt: options.promptIndex ?? 0, completion: choice.index ?? 0, }; if (typeof chunk.content !== "string") { console.log("[WARNING]: Received non-string content from OpenAI. 
This is currently not supported."); continue; } const generationChunk = new outputs_1.ChatGenerationChunk({ message: chunk, text: chunk.content, generationInfo: newTokenIndices, }); yield generationChunk; // eslint-disable-next-line no-void void runManager?.handleLLMNewToken(generationChunk.text ?? "", newTokenIndices, undefined, undefined, undefined, { chunk: generationChunk }); } if (options.signal?.aborted) { throw new Error("AbortError"); } } async _generate(messages, options, runManager) { const deploymentName = this.azureOpenAIApiDeploymentName || this.model; const tokenUsage = {}; const azureOpenAIMessages = this.formatMessages(messages); if (!this.streaming) { const data = await this.caller.call(() => this.client.getChatCompletions(deploymentName, azureOpenAIMessages, { functions: options?.functions, functionCall: options?.function_call, maxTokens: this.maxTokens, temperature: this.temperature, topP: this.topP, logitBias: this.logitBias, user: this.user, n: this.n, stop: this.stopSequences, presencePenalty: this.presencePenalty, frequencyPenalty: this.frequencyPenalty, azureExtensionOptions: this.azureExtensionOptions, requestOptions: { timeout: options?.timeout ?? this.timeout, }, abortSignal: options?.signal ?? undefined, tools: options?.tools, toolChoice: options?.tool_choice, responseFormat: options?.response_format, seed: options?.seed, ...this.modelKwargs, })); const { completionTokens, promptTokens, totalTokens } = data?.usage ?? {}; if (completionTokens) { tokenUsage.completionTokens = (tokenUsage.completionTokens ?? 0) + completionTokens; } if (promptTokens) { tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens; } if (totalTokens) { tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens; } const generations = []; for (const part of data?.choices ?? []) { const text = part.message?.content ?? ""; const generation = { text, message: openAIResponseToChatMessage(part.message ?? { role: "assistant", content: text, toolCalls: [], }), }; generation.generationInfo = { ...(part.finishReason ? { finish_reason: part.finishReason } : {}), }; generations.push(generation); } return { generations, llmOutput: { tokenUsage }, }; } else { const stream = this._streamResponseChunks(messages, options, runManager); const finalChunks = {}; for await (const chunk of stream) { const index = chunk.generationInfo?.completion ?? 0; if (finalChunks[index] === undefined) { finalChunks[index] = chunk; } else { finalChunks[index] = finalChunks[index].concat(chunk); } } const generations = Object.entries(finalChunks) .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10)) .map(([_, value]) => value); const promptTokenUsage = await this.getEstimatedTokenCountFromPrompt(messages, options?.functions, options?.function_call); const completionTokenUsage = await this.getNumTokensFromGenerations(generations); tokenUsage.promptTokens = promptTokenUsage; tokenUsage.completionTokens = completionTokenUsage; tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage; return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }; } } /** * Estimate the number of tokens an array of generations have used. 
*/ async getNumTokensFromGenerations(generations) { const generationUsages = await Promise.all(generations.map(async (generation) => { if (generation.message.additional_kwargs?.function_call) { return (await this.getNumTokensFromMessages([generation.message])) .countPerMessage[0]; } else { return await this.getNumTokens(generation.message.content); } })); return generationUsages.reduce((a, b) => a + b, 0); } _llmType() { return "azure-openai"; } /** * Estimate the number of tokens a prompt will use. * Modified from: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts */ async getEstimatedTokenCountFromPrompt(messages, functions, function_call) { // It appears that if functions are present, the first system message is padded with a trailing newline. This // was inferred by trying lots of combinations of messages and functions and seeing what the token counts were. let tokens = (await this.getNumTokensFromMessages(messages)).totalCount; // If there are functions, add the function definitions as they count towards token usage if (functions && function_call !== "auto") { const promptDefinitions = (0, openai_format_fndef_js_1.formatFunctionDefinitions)(functions); tokens += await this.getNumTokens(promptDefinitions); tokens += 9; // Add nine per completion } // If there's a system message _and_ functions are present, subtract four tokens. I assume this is because // functions typically add a system message, but reuse the first one if it's already there. This offsets // the extra 9 tokens added by the function definitions. if (functions && messages.find((m) => m._getType() === "system")) { tokens -= 4; } // If function_call is 'none', add one token. // If it's a FunctionCall object, add 4 + the number of tokens in the function name. // If it's undefined or 'auto', don't add anything. if (function_call === "none") { tokens += 1; } else if (typeof function_call === "object") { tokens += (await this.getNumTokens(function_call.name)) + 4; } return tokens; } async getNumTokensFromMessages(messages) { let totalCount = 0; let tokensPerMessage = 0; let tokensPerName = 0; // From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb if (this.model === "gpt-3.5-turbo-0301") { tokensPerMessage = 4; tokensPerName = -1; } else { tokensPerMessage = 3; tokensPerName = 1; } const countPerMessage = await Promise.all(messages.map(async (message) => { const textCount = await this.getNumTokens(message.content); const roleCount = await this.getNumTokens(messageToOpenAIRole(message)); const nameCount = message.name !== undefined ? 
tokensPerName + (await this.getNumTokens(message.name)) : 0; let count = textCount + tokensPerMessage + roleCount + nameCount; // From: https://github.com/hmarr/openai-chat-tokens/blob/main/src/index.ts messageTokenEstimate const openAIMessage = message; if (openAIMessage._getType() === "function") { count -= 2; } if (openAIMessage.additional_kwargs?.function_call) { count += 3; } if (openAIMessage?.additional_kwargs.function_call?.name) { count += await this.getNumTokens(openAIMessage.additional_kwargs.function_call?.name); } if (openAIMessage.additional_kwargs.function_call?.arguments) { count += await this.getNumTokens( // Remove newlines and spaces JSON.stringify(JSON.parse(openAIMessage.additional_kwargs.function_call?.arguments))); } totalCount += count; return count; })); totalCount += 3; // every reply is primed with <|start|>assistant<|message|> return { totalCount, countPerMessage }; } /** @ignore */ _combineLLMOutput(...llmOutputs) { return llmOutputs.reduce((acc, llmOutput) => { if (llmOutput && llmOutput.tokenUsage) { acc.tokenUsage.completionTokens += llmOutput.tokenUsage.completionTokens ?? 0; acc.tokenUsage.promptTokens += llmOutput.tokenUsage.promptTokens ?? 0; acc.tokenUsage.totalTokens += llmOutput.tokenUsage.totalTokens ?? 0; } return acc; }, { tokenUsage: { completionTokens: 0, promptTokens: 0, totalTokens: 0, }, }); } } exports.AzureChatOpenAI = AzureChatOpenAI;
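
For context, a minimal usage sketch based on the constructor above. The endpoint, API key, and deployment name are placeholder values, and per the @deprecated tag in the source, new code should import from "@langchain/openai" instead.

const { AzureChatOpenAI } = require("@langchain/azure-openai");
const { HumanMessage } = require("@langchain/core/messages");

// Placeholder credentials; if omitted, the constructor falls back to the
// AZURE_OPENAI_API_ENDPOINT, AZURE_OPENAI_API_KEY, and
// AZURE_OPENAI_API_DEPLOYMENT_NAME environment variables.
const model = new AzureChatOpenAI({
    azureOpenAIEndpoint: "https://<resource>.openai.azure.com",
    azureOpenAIApiKey: "<api-key>",
    azureOpenAIApiDeploymentName: "<deployment>",
    temperature: 0,
});

async function main() {
    // invoke() is inherited from BaseChatModel and routes through _generate;
    // passing streaming: true (or calling .stream()) routes through
    // _streamResponseChunks instead.
    const res = await model.invoke([new HumanMessage("What is LangChain?")]);
    console.log(res.content);
}

main();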