@sap-ai-sdk/langchain
SAP Cloud SDK for AI is the official Software Development Kit (SDK) for **SAP AI Core**, **SAP Generative AI Hub**, and **Orchestration Service**.
import { AzureOpenAiChatClient as AzureOpenAiChatClientBase } from '@sap-ai-sdk/foundation-models';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatGenerationChunk } from '@langchain/core/outputs';
import { RunnableLambda, RunnablePassthrough, RunnableSequence } from '@langchain/core/runnables';
import { getSchemaDescription, isInteropZodSchema } from '@langchain/core/utils/types';
import { JsonOutputParser, StructuredOutputParser } from '@langchain/core/output_parsers';
import { toJsonSchema } from '@langchain/core/utils/json_schema';
import { JsonOutputKeyToolsParser } from '@langchain/core/output_parsers/openai_tools';
import { mapAzureOpenAiChunkToLangChainMessageChunk, mapLangChainToAiClient, mapOutputToChatResult, mapToolToOpenAiTool } from './util.js';
/**
* LangChain chat client for Azure OpenAI consumption on SAP BTP.
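*
* @example
* // Minimal usage sketch; the model name is illustrative and assumes a
* // matching deployment exists in SAP AI Core.
* const client = new AzureOpenAiChatClient({ modelName: 'gpt-4o' });
* const response = await client.invoke('Hello!');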
*/
export class AzureOpenAiChatClient extends BaseChatModel {
temperature;
top_p;
logit_bias;
user;
presence_penalty;
frequency_penalty;
stop;
max_tokens;
supportsStrictToolCalling;
modelName;
openAiChatClient;
constructor(fields, destination) {
super(fields);
this.openAiChatClient = new AzureOpenAiChatClientBase(fields, destination);
this.modelName = fields.modelName;
this.temperature = fields.temperature;
this.top_p = fields.top_p;
this.logit_bias = fields.logit_bias;
this.user = fields.user;
this.stop = fields.stop;
this.presence_penalty = fields.presence_penalty;
this.frequency_penalty = fields.frequency_penalty;
this.max_tokens = fields.max_tokens;
if (fields.supportsStrictToolCalling !== undefined) {
this.supportsStrictToolCalling = fields.supportsStrictToolCalling;
}
}
_llmType() {
return 'azure_openai';
}
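/**
* Create a (non-streaming) chat completion by delegating to the underlying
* SAP AI SDK Azure OpenAI client and mapping the response to a LangChain
* chat result.
* @param messages - The messages to send to the model.
* @param options - The call options.
* @param runManager - The callback manager for the run.
* @returns The chat result.
*/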
async _generate(messages, options, runManager) {
const res = await this.caller.callWithOptions({
signal: options.signal
}, () => this.openAiChatClient.run(mapLangChainToAiClient(this, messages, options), options.requestConfig));
const content = res.getContent();
// Streaming is not supported in `_generate`; the full content is reported as a single token.
await runManager?.handleLLMNewToken(typeof content === 'string' ? content : '');
return mapOutputToChatResult(res._data);
}
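/**
* Bind tools to the client, mapping each tool to the OpenAI tool format.
* @param tools - The tools to bind.
* @param kwargs - Additional call options, e.g. `strict`.
* @returns A runnable with the tools bound.
* @example
* // Sketch; `myWeatherTool` stands in for any LangChain structured tool.
* const modelWithTools = client.bindTools([myWeatherTool]);
* const result = await modelWithTools.invoke('What is the weather in Berlin?');
*/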
bindTools(tools, kwargs) {
let strict;
if (kwargs?.strict !== undefined) {
strict = kwargs.strict;
}
else if (this.supportsStrictToolCalling !== undefined) {
strict = this.supportsStrictToolCalling;
}
const newTools = tools.map(tool => mapToolToOpenAiTool(tool, strict));
return this.withConfig({
tools: newTools,
...kwargs
});
}
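/**
* Return a runnable that constrains the model output to the given schema,
* using one of three methods: 'jsonSchema' (the default), 'jsonMode', or
* function calling.
* @param outputSchema - A Zod schema or JSON schema describing the output.
* @param config - Optional settings such as `name`, `method`, `strict` and `includeRaw`.
* @returns A runnable that produces parsed structured output.
* @example
* // Sketch assuming `zod` is available; the schema is illustrative.
* import { z } from 'zod';
* const joke = z.object({ setup: z.string(), punchline: z.string() });
* const structured = client.withStructuredOutput(joke, { name: 'joke' });
* const result = await structured.invoke('Tell me a joke.');
*/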
withStructuredOutput(outputSchema, config) {
const schema = outputSchema;
const name = config?.name;
let method = config?.method;
const includeRaw = config?.includeRaw;
let llm;
let outputParser;
if (config?.strict !== undefined && method === 'jsonMode') {
throw new Error("Argument 'strict' is not supported for 'method' = 'jsonMode'.");
}
if (method === undefined) {
method = 'jsonSchema';
}
if (method === 'jsonMode') {
let outputFormatSchema;
if (isInteropZodSchema(schema)) {
outputParser = StructuredOutputParser.fromZodSchema(schema);
outputFormatSchema = toJsonSchema(schema);
}
else {
outputParser = new JsonOutputParser();
}
llm = this.withConfig({
response_format: { type: 'json_object' },
ls_structured_output_format: {
kwargs: { method: 'jsonMode' },
schema: outputFormatSchema
}
});
}
else if (method === 'jsonSchema') {
const asJsonSchema = toJsonSchema(schema);
llm = this.withConfig({
response_format: {
type: 'json_schema',
json_schema: {
name: name ?? 'extract',
description: getSchemaDescription(schema),
schema: asJsonSchema,
strict: config?.strict
}
},
ls_structured_output_format: {
kwargs: { method: 'functionCalling' },
schema: asJsonSchema
}
});
if (isInteropZodSchema(schema)) {
const altParser = StructuredOutputParser.fromZodSchema(schema);
outputParser = RunnableLambda.from((aiMessage) => {
if ('parsed' in aiMessage.additional_kwargs) {
return aiMessage.additional_kwargs.parsed;
}
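// No pre-parsed output is present, so fall back to the Zod-based parser;
// a Runnable returned from a RunnableLambda is invoked with the same input.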
return altParser;
});
}
else {
outputParser = new JsonOutputParser();
}
}
else {
let functionName = name ?? 'extract';
// Method is function calling: bind the schema as a forced tool call.
if (isInteropZodSchema(schema)) {
const asJsonSchema = toJsonSchema(schema);
llm = this.withConfig({
tools: [
{
type: 'function',
function: {
name: functionName,
description: asJsonSchema.description,
parameters: asJsonSchema
}
}
],
tool_choice: {
type: 'function',
function: {
name: functionName
}
},
ls_structured_output_format: {
kwargs: { method: 'functionCalling' },
schema: asJsonSchema
},
// Do not pass `strict` argument to OpenAI if `config.strict` is undefined
...(config?.strict !== undefined ? { strict: config.strict } : {})
});
outputParser = new JsonOutputKeyToolsParser({
returnSingle: true,
keyName: functionName,
zodSchema: schema
});
}
else {
let openAIFunctionDefinition;
if (typeof schema.name === 'string' &&
typeof schema.parameters === 'object' &&
schema.parameters != null) {
openAIFunctionDefinition = schema;
functionName = schema.name;
}
else {
functionName = schema.title ?? functionName;
openAIFunctionDefinition = {
name: functionName,
description: schema.description ?? '',
parameters: schema
};
}
llm = this.withConfig({
tools: [
{
type: 'function',
function: openAIFunctionDefinition
}
],
tool_choice: {
type: 'function',
function: {
name: functionName
}
},
ls_structured_output_format: {
kwargs: { method: 'functionCalling' },
schema: toJsonSchema(schema)
},
// Do not pass `strict` argument to OpenAI if `config.strict` is undefined
...(config?.strict !== undefined ? { strict: config.strict } : {})
});
outputParser = new JsonOutputKeyToolsParser({
returnSingle: true,
keyName: functionName
});
}
}
if (!includeRaw) {
return llm.pipe(outputParser);
}
const parserAssign = RunnablePassthrough.assign({
parsed: (input, parserConfig) => outputParser.invoke(input.raw, parserConfig)
});
const parserNone = RunnablePassthrough.assign({
parsed: () => null
});
const parsedWithFallback = parserAssign.withFallbacks({
fallbacks: [parserNone]
});
return RunnableSequence.from([{ raw: llm }, parsedWithFallback]);
}
/**
* Stream response chunks from the Azure OpenAI client.
* @param messages - The messages to send to the model.
* @param options - The call options.
* @param runManager - The callback manager for the run.
* @returns An async generator of chat generation chunks.
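* @example
* // Streaming is consumed via the public `stream()` API inherited from
* // BaseChatModel; `chunk.content` is a string for plain text responses.
* for await (const chunk of client.stream('Write a haiku about SAP BTP.')) {
*   process.stdout.write(chunk.content);
* }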
*/
async *_streamResponseChunks(messages, options, runManager) {
const response = await this.caller.callWithOptions({
signal: options.signal
}, () => this.openAiChatClient.stream(mapLangChainToAiClient(this, messages, options), options.signal, options.requestConfig));
for await (const chunk of response.stream) {
// A chunk contains at most one choice
const choice = chunk._data.choices[0];
// Map the chunk to a LangChain message chunk
const messageChunk = mapAzureOpenAiChunkToLangChainMessageChunk(chunk);
// Create initial generation info with token indices
const newTokenIndices = {
prompt: options.promptIndex ?? 0,
completion: choice?.index ?? 0
};
const generationInfo = { ...newTokenIndices };
// Process finish reason
if (choice?.finish_reason) {
generationInfo.finish_reason = choice.finish_reason;
// Only include system fingerprint in the last chunk for now to avoid concatenation issues
generationInfo.system_fingerprint = chunk._data.system_fingerprint;
generationInfo.model_name = chunk._data.model;
generationInfo.id = chunk._data.id;
generationInfo.created = chunk._data.created;
generationInfo.index = choice.index;
}
// Process token usage
const tokenUsage = chunk.getTokenUsage();
if (tokenUsage) {
generationInfo.token_usage = tokenUsage;
messageChunk.usage_metadata = {
input_tokens: tokenUsage.prompt_tokens,
output_tokens: tokenUsage.completion_tokens,
total_tokens: tokenUsage.total_tokens
};
}
const content = chunk.getDeltaContent() ?? '';
const generationChunk = new ChatGenerationChunk({
message: messageChunk,
text: content,
generationInfo
});
// Notify the run manager about the new token
// Some parameters (`_runId`, `_parentRunId`, `_tags`) are passed as undefined, as they are implicitly read from the context.
await runManager?.handleLLMNewToken(content, newTokenIndices, undefined, undefined, undefined, { chunk: generationChunk });
yield generationChunk;
}
}
}
//# sourceMappingURL=chat.js.map