@langchain/azure-openai
Azure SDK for OpenAI integrations for LangChain.js
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.AzureOpenAI = void 0;
const openai_1 = require("@azure/openai");
const base_1 = require("@langchain/core/language_models/base");
const llms_1 = require("@langchain/core/language_models/llms");
const chunk_array_1 = require("@langchain/core/utils/chunk_array");
const outputs_1 = require("@langchain/core/outputs");
const env_1 = require("@langchain/core/utils/env");
const constants_js_1 = require("./constants.cjs");
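/*
 * Example usage, a minimal sketch: the endpoint, deployment name, and
 * environment variable below are placeholder values, not part of this module.
 *
 *   const { AzureOpenAI } = require("@langchain/azure-openai");
 *   const model = new AzureOpenAI({
 *     azureOpenAIEndpoint: "https://<resource>.openai.azure.com/",
 *     azureOpenAIApiDeploymentName: "<deployment-name>",
 *     azureOpenAIApiKey: process.env.AZURE_OPENAI_API_KEY,
 *   });
 *   // Inside an async function:
 *   const text = await model.invoke("Write a haiku about the sea.");
 */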
/** @deprecated Import from "@langchain/openai" instead. */
class AzureOpenAI extends llms_1.BaseLLM {
static lc_name() {
return "AzureOpenAI";
}
get callKeys() {
return [...super.callKeys, "options"];
}
get lc_secrets() {
return {
apiKey: "AZURE_OPENAI_API_KEY",
openAIApiKey: "OPENAI_API_KEY",
azureOpenAIApiKey: "AZURE_OPENAI_API_KEY",
azureOpenAIEndpoint: "AZURE_OPENAI_API_ENDPOINT",
azureOpenAIApiDeploymentName: "AZURE_OPENAI_API_DEPLOYMENT_NAME",
};
}
get lc_aliases() {
return {
modelName: "model",
openAIApiKey: "openai_api_key",
azureOpenAIApiKey: "azure_openai_api_key",
azureOpenAIEndpoint: "azure_openai_api_endpoint",
azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name",
};
}
constructor(fields) {
super(fields ?? {});
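        // Compiled class-field initializers: each Object.defineProperty call
        // below declares an instance property with its default value.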
Object.defineProperty(this, "lc_serializable", {
enumerable: true,
configurable: true,
writable: true,
value: true
});
Object.defineProperty(this, "temperature", {
enumerable: true,
configurable: true,
writable: true,
value: 0.7
});
Object.defineProperty(this, "maxTokens", {
enumerable: true,
configurable: true,
writable: true,
value: 256
});
Object.defineProperty(this, "topP", {
enumerable: true,
configurable: true,
writable: true,
value: 1
});
Object.defineProperty(this, "frequencyPenalty", {
enumerable: true,
configurable: true,
writable: true,
value: 0
});
Object.defineProperty(this, "presencePenalty", {
enumerable: true,
configurable: true,
writable: true,
value: 0
});
Object.defineProperty(this, "n", {
enumerable: true,
configurable: true,
writable: true,
value: 1
});
Object.defineProperty(this, "bestOf", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "logitBias", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "modelName", {
enumerable: true,
configurable: true,
writable: true,
value: "gpt-3.5-turbo-instruct"
});
Object.defineProperty(this, "model", {
enumerable: true,
configurable: true,
writable: true,
value: "gpt-3.5-turbo-instruct"
});
Object.defineProperty(this, "modelKwargs", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "batchSize", {
enumerable: true,
configurable: true,
writable: true,
value: 20
});
Object.defineProperty(this, "timeout", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "stop", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "stopSequences", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "user", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "streaming", {
enumerable: true,
configurable: true,
writable: true,
value: false
});
Object.defineProperty(this, "azureOpenAIApiKey", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "apiKey", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "azureOpenAIEndpoint", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "azureOpenAIApiDeploymentName", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "logprobs", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "echo", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "client", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.azureOpenAIEndpoint =
fields?.azureOpenAIEndpoint ??
(0, env_1.getEnvironmentVariable)("AZURE_OPENAI_API_ENDPOINT");
this.azureOpenAIApiDeploymentName =
fields?.azureOpenAIApiDeploymentName ??
(0, env_1.getEnvironmentVariable)("AZURE_OPENAI_API_DEPLOYMENT_NAME");
const openAiApiKey = fields?.apiKey ??
fields?.openAIApiKey ??
(0, env_1.getEnvironmentVariable)("OPENAI_API_KEY");
this.azureOpenAIApiKey =
fields?.apiKey ??
fields?.azureOpenAIApiKey ??
(0, env_1.getEnvironmentVariable)("AZURE_OPENAI_API_KEY") ??
openAiApiKey;
this.apiKey = this.azureOpenAIApiKey;
const azureCredential = fields?.credentials ??
(this.apiKey === openAiApiKey
? new openai_1.OpenAIKeyCredential(this.apiKey ?? "")
: new openai_1.AzureKeyCredential(this.apiKey ?? ""));
// eslint-disable-next-line no-instanceof/no-instanceof
const isOpenAIApiKey = azureCredential instanceof openai_1.OpenAIKeyCredential;
if (!this.apiKey && !fields?.credentials) {
throw new Error("Azure OpenAI API key not found");
}
if (!this.azureOpenAIEndpoint && !isOpenAIApiKey) {
throw new Error("Azure OpenAI Endpoint not found");
}
if (!this.azureOpenAIApiDeploymentName && !isOpenAIApiKey) {
throw new Error("Azure OpenAI Deployment name not found");
}
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.temperature = fields?.temperature ?? this.temperature;
this.topP = fields?.topP ?? this.topP;
this.logitBias = fields?.logitBias;
this.user = fields?.user;
this.n = fields?.n ?? this.n;
this.logprobs = fields?.logprobs;
this.echo = fields?.echo;
this.stop = fields?.stopSequences ?? fields?.stop;
this.stopSequences = this.stop;
this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
this.bestOf = fields?.bestOf ?? this.bestOf;
this.modelName = fields?.model ?? fields?.modelName ?? this.model;
this.model = this.modelName;
this.modelKwargs = fields?.modelKwargs ?? {};
this.streaming = fields?.streaming ?? false;
this.batchSize = fields?.batchSize ?? this.batchSize;
if (this.streaming && this.bestOf && this.bestOf > 1) {
throw new Error("Cannot stream results when bestOf > 1");
}
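        // Pick the client flavor. Besides API keys, a TokenCredential can be
        // supplied via `fields.credentials` for Microsoft Entra ID auth, e.g.
        // (an illustrative sketch, not part of this module):
        //   const { DefaultAzureCredential } = require("@azure/identity");
        //   new AzureOpenAI({
        //     azureOpenAIEndpoint: "https://<resource>.openai.azure.com/",
        //     azureOpenAIApiDeploymentName: "<deployment-name>",
        //     credentials: new DefaultAzureCredential(),
        //   });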
const options = {
userAgentOptions: { userAgentPrefix: constants_js_1.USER_AGENT_PREFIX },
};
        if (isOpenAIApiKey) {
            // A plain OpenAI key targets the public OpenAI endpoint directly,
            // so no Azure endpoint is passed.
            this.client = new openai_1.OpenAIClient(azureCredential);
        }
        else {
            // AzureKeyCredential and TokenCredential clients are constructed
            // identically against the Azure endpoint.
            this.client = new openai_1.OpenAIClient(this.azureOpenAIEndpoint ?? "", azureCredential, options);
        }
}
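    /**
     * Streams completion chunks for a single prompt, yielding one
     * GenerationChunk per streamed choice delta. Typically reached through
     * the public `.stream()` API, e.g. (a sketch, inside an async function):
     *
     *   for await (const chunk of await model.stream("Tell me a joke")) {
     *     process.stdout.write(chunk);
     *   }
     */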
async *_streamResponseChunks(input, options, runManager) {
const deploymentName = this.azureOpenAIApiDeploymentName || this.model;
const stream = await this.caller.call(() => this.client.streamCompletions(deploymentName, [input], {
maxTokens: this.maxTokens,
temperature: this.temperature,
topP: this.topP,
logitBias: this.logitBias,
user: this.user,
n: this.n,
logprobs: this.logprobs,
echo: this.echo,
stop: this.stopSequences,
presencePenalty: this.presencePenalty,
frequencyPenalty: this.frequencyPenalty,
bestOf: this.bestOf,
requestOptions: {
timeout: options?.timeout ?? this.timeout,
},
abortSignal: options?.signal ?? undefined,
...this.modelKwargs,
}));
for await (const data of stream) {
const choice = data?.choices[0];
if (!choice) {
continue;
}
const chunk = new outputs_1.GenerationChunk({
text: choice.text,
generationInfo: {
finishReason: choice.finishReason,
},
});
yield chunk;
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(chunk.text ?? "");
}
if (options.signal?.aborted) {
throw new Error("AbortError");
}
}
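    /**
     * Generates completions for a batch of prompts. Prompts are split into
     * batches of `batchSize` via chunkArray, e.g. chunkArray(["a", "b", "c"], 2)
     * yields [["a", "b"], ["c"]], and each batch is sent as one request.
     */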
async _generate(prompts, options, runManager) {
const deploymentName = this.azureOpenAIApiDeploymentName || this.model;
if (this.maxTokens === -1) {
if (prompts.length !== 1) {
throw new Error("max_tokens set to -1 not supported for multiple inputs");
}
this.maxTokens = await (0, base_1.calculateMaxTokens)({
prompt: prompts[0],
            // The model name may fall outside the union of models known to calculateMaxTokens
modelName: this.model,
});
}
const subPrompts = (0, chunk_array_1.chunkArray)(prompts, this.batchSize);
if (this.streaming) {
const choices = [];
for (let i = 0; i < subPrompts.length; i += 1) {
let response;
const stream = await this.caller.call(() => this.client.streamCompletions(deploymentName, subPrompts[i], {
maxTokens: this.maxTokens,
temperature: this.temperature,
topP: this.topP,
logitBias: this.logitBias,
user: this.user,
n: this.n,
logprobs: this.logprobs,
echo: this.echo,
stop: this.stopSequences,
presencePenalty: this.presencePenalty,
frequencyPenalty: this.frequencyPenalty,
bestOf: this.bestOf,
requestOptions: {
timeout: options?.timeout ?? this.timeout,
},
abortSignal: options?.signal ?? undefined,
...this.modelKwargs,
}));
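                // With `n` completions per prompt, streamed choice index i
                // belongs to prompt floor(i / n), completion i % n.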
for await (const message of stream) {
if (!response) {
response = {
id: message.id,
created: message.created,
promptFilterResults: message.promptFilterResults,
};
}
                    // Merge each streamed choice delta into its accumulated choice
for (const part of message.choices) {
if (!choices[part.index]) {
choices[part.index] = part;
}
else {
const choice = choices[part.index];
choice.text += part.text;
choice.finishReason = part.finishReason;
choice.logprobs = part.logprobs;
}
void runManager?.handleLLMNewToken(part.text, {
prompt: Math.floor(part.index / this.n),
completion: part.index % this.n,
});
}
}
if (options.signal?.aborted) {
throw new Error("AbortError");
}
}
const generations = (0, chunk_array_1.chunkArray)(choices, this.n).map((promptChoices) => promptChoices.map((choice) => ({
text: choice.text ?? "",
generationInfo: {
finishReason: choice.finishReason,
logprobs: choice.logprobs,
},
})));
return {
generations,
llmOutput: {
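                    // Token usage is not aggregated on the streaming path, so
                    // the counts are reported as undefined.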
tokenUsage: {
completionTokens: undefined,
promptTokens: undefined,
totalTokens: undefined,
},
},
};
}
else {
const tokenUsage = {};
const choices = [];
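            // Non-streaming path: one getCompletions call per batch of
            // prompts, with token usage summed across batches.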
for (let i = 0; i < subPrompts.length; i += 1) {
const data = await this.caller.call(() => this.client.getCompletions(deploymentName, prompts, {
maxTokens: this.maxTokens,
temperature: this.temperature,
topP: this.topP,
logitBias: this.logitBias,
user: this.user,
n: this.n,
logprobs: this.logprobs,
echo: this.echo,
stop: this.stopSequences,
presencePenalty: this.presencePenalty,
frequencyPenalty: this.frequencyPenalty,
bestOf: this.bestOf,
requestOptions: {
timeout: options?.timeout ?? this.timeout,
},
abortSignal: options?.signal ?? undefined,
...this.modelKwargs,
}));
choices.push(...data.choices);
tokenUsage.completionTokens =
(tokenUsage.completionTokens ?? 0) + data.usage.completionTokens;
tokenUsage.promptTokens =
(tokenUsage.promptTokens ?? 0) + data.usage.promptTokens;
tokenUsage.totalTokens =
(tokenUsage.totalTokens ?? 0) + data.usage.totalTokens;
}
const generations = (0, chunk_array_1.chunkArray)(choices, this.n).map((promptChoices) => promptChoices.map((choice) => {
void runManager?.handleLLMNewToken(choice.text, {
prompt: Math.floor(choice.index / this.n),
completion: choice.index % this.n,
});
return {
text: choice.text ?? "",
generationInfo: {
finishReason: choice.finishReason,
logprobs: choice.logprobs,
},
};
}));
return {
generations,
llmOutput: {
tokenUsage: {
completionTokens: tokenUsage.completionTokens,
promptTokens: tokenUsage.promptTokens,
totalTokens: tokenUsage.totalTokens,
},
},
};
}
}
_llmType() {
return "azure_openai";
}
}
exports.AzureOpenAI = AzureOpenAI;