UNPKG

@sap-ai-sdk/langchain

Version:

SAP Cloud SDK for AI is the official Software Development Kit (SDK) for **SAP AI Core**, **SAP Generative AI Hub**, and **Orchestration Service**.

273 lines 11.7 kB
import { AzureOpenAiChatClient as AzureOpenAiChatClientBase } from '@sap-ai-sdk/foundation-models';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatGenerationChunk } from '@langchain/core/outputs';
import { RunnableLambda, RunnablePassthrough, RunnableSequence } from '@langchain/core/runnables';
import { getSchemaDescription, isInteropZodSchema } from '@langchain/core/utils/types';
import { JsonOutputParser, StructuredOutputParser } from '@langchain/core/output_parsers';
import { toJsonSchema } from '@langchain/core/utils/json_schema';
import { JsonOutputKeyToolsParser } from '@langchain/core/output_parsers/openai_tools';
import {
  mapAzureOpenAiChunkToLangChainMessageChunk,
  mapLangChainToAiClient,
  mapOutputToChatResult,
  mapToolToOpenAiTool
} from './util.js';

/**
 * LangChain chat client for Azure OpenAI consumption on SAP BTP.
 *
 * Adapts the `AzureOpenAiChatClient` from `@sap-ai-sdk/foundation-models` to
 * LangChain's `BaseChatModel` interface: plain generation, streaming, tool
 * binding and structured output.
 */
export class AzureOpenAiChatClient extends BaseChatModel {
  // Model parameters copied verbatim from the constructor `fields`.
  // NOTE(review): presumably consumed by `mapLangChainToAiClient(this, ...)`
  // when building the request — confirm in util.js.
  temperature;
  top_p;
  logit_bias;
  user;
  presence_penalty;
  frequency_penalty;
  stop;
  max_tokens;
  // Default for the `strict` flag in `bindTools`; only assigned when provided.
  supportsStrictToolCalling;
  modelName;
  // Foundation-models client that `_generate` and `_streamResponseChunks` delegate to.
  openAiChatClient;

  /**
   * @param fields - Model parameters and base chat model options. Passed to
   *   `BaseChatModel` and to the foundation-models client.
   * @param destination - Forwarded unchanged to the foundation-models client.
   */
  constructor(fields, destination) {
    super(fields);
    this.openAiChatClient = new AzureOpenAiChatClientBase(fields, destination);
    this.modelName = fields.modelName;
    this.temperature = fields.temperature;
    this.top_p = fields.top_p;
    this.logit_bias = fields.logit_bias;
    this.user = fields.user;
    this.stop = fields.stop;
    this.presence_penalty = fields.presence_penalty;
    this.frequency_penalty = fields.frequency_penalty;
    this.max_tokens = fields.max_tokens;
    if (fields.supportsStrictToolCalling !== undefined) {
      this.supportsStrictToolCalling = fields.supportsStrictToolCalling;
    }
  }

  /**
   * @returns The LangChain model type identifier.
   */
  _llmType() {
    return 'azure_openai';
  }

  /**
   * Run a single (non-streaming) chat completion.
   * @param messages - LangChain messages to send.
   * @param options - Call options; `signal` aborts the call, `requestConfig` is
   *   forwarded to the foundation-models client.
   * @param runManager - Optional callback manager for the run.
   * @returns The chat result mapped from the raw response data.
   */
  async _generate(messages, options, runManager) {
    const res = await this.caller.callWithOptions({ signal: options.signal }, () =>
      this.openAiChatClient.run(
        mapLangChainToAiClient(this, messages, options),
        options.requestConfig
      )
    );
    const content = res.getContent();
    // This non-streaming path reports the whole response content to callbacks
    // as a single token (non-string content is reported as '').
    await runManager?.handleLLMNewToken(typeof content === 'string' ? content : '');
    return mapOutputToChatResult(res._data);
  }

  /**
   * Bind tools to the model.
   *
   * `strict` is resolved as: `kwargs.strict` if given, otherwise
   * `this.supportsStrictToolCalling` if set, otherwise left undefined.
   * @param tools - Tools to convert with `mapToolToOpenAiTool`.
   * @param kwargs - Extra call options merged into the bound config.
   * @returns A new runnable with the tools bound.
   */
  bindTools(tools, kwargs) {
    let strict;
    if (kwargs?.strict !== undefined) {
      strict = kwargs.strict;
    } else if (this.supportsStrictToolCalling !== undefined) {
      strict = this.supportsStrictToolCalling;
    }
    const newTools = tools.map(tool => mapToolToOpenAiTool(tool, strict));
    return this.withConfig({ tools: newTools, ...kwargs });
  }

  /**
   * Return a runnable whose output conforms to `outputSchema`.
   *
   * Supported `config.method` values:
   * - `'jsonSchema'` (default): sends `response_format: { type: 'json_schema' }`.
   * - `'jsonMode'`: sends `response_format: { type: 'json_object' }`.
   * - any other value: forces a single function call named after the schema.
   *
   * @param outputSchema - A Zod (interop) schema, a JSON schema, or an OpenAI
   *   function definition (an object with string `name` and object `parameters`).
   * @param config - Optional `name`, `method`, `strict` and `includeRaw`.
   * @returns A runnable yielding the parsed output, or `{ raw, parsed }` when
   *   `includeRaw` is set (`parsed` is `null` if parsing fails).
   * @throws Error if `strict` is combined with `method: 'jsonMode'`.
   */
  withStructuredOutput(outputSchema, config) {
    const schema = outputSchema;
    const name = config?.name;
    let method = config?.method;
    const includeRaw = config?.includeRaw;
    let llm;
    let outputParser;
    if (config?.strict !== undefined && method === 'jsonMode') {
      throw new Error("Argument 'strict' is not supported for 'method' = 'jsonMode'.");
    }
    if (method === undefined) {
      method = 'jsonSchema';
    }
    if (method === 'jsonMode') {
      let outputFormatSchema;
      if (isInteropZodSchema(schema)) {
        outputParser = StructuredOutputParser.fromZodSchema(schema);
        outputFormatSchema = toJsonSchema(schema);
      } else {
        outputParser = new JsonOutputParser();
      }
      llm = this.withConfig({
        response_format: { type: 'json_object' },
        ls_structured_output_format: {
          kwargs: { method: 'jsonMode' },
          schema: outputFormatSchema
        }
      });
    } else if (method === 'jsonSchema') {
      const asJsonSchema = toJsonSchema(schema);
      llm = this.withConfig({
        response_format: {
          type: 'json_schema',
          json_schema: {
            name: name ?? 'extract',
            description: getSchemaDescription(schema),
            schema: asJsonSchema,
            strict: config?.strict
          }
        },
        // Tracing metadata: label this as the JSON-schema method (it was
        // previously mislabelled as 'functionCalling').
        ls_structured_output_format: {
          kwargs: { method: 'jsonSchema' },
          schema: asJsonSchema
        }
      });
      if (isInteropZodSchema(schema)) {
        const altParser = StructuredOutputParser.fromZodSchema(schema);
        // Prefer a pre-parsed value in `additional_kwargs.parsed`; otherwise
        // return the Zod parser, which RunnableLambda then invokes on the message.
        outputParser = RunnableLambda.from((aiMessage) => {
          if ('parsed' in aiMessage.additional_kwargs) {
            return aiMessage.additional_kwargs.parsed;
          }
          return altParser;
        });
      } else {
        outputParser = new JsonOutputParser();
      }
    } else {
      let functionName = name ?? 'extract';
      // Is function calling
      if (isInteropZodSchema(schema)) {
        const asJsonSchema = toJsonSchema(schema);
        llm = this.withConfig({
          tools: [
            {
              type: 'function',
              function: {
                name: functionName,
                description: asJsonSchema.description,
                parameters: asJsonSchema
              }
            }
          ],
          tool_choice: { type: 'function', function: { name: functionName } },
          ls_structured_output_format: {
            kwargs: { method: 'functionCalling' },
            schema: asJsonSchema
          },
          // Do not pass `strict` argument to OpenAI if `config.strict` is undefined
          ...(config?.strict !== undefined ? { strict: config.strict } : {})
        });
        outputParser = new JsonOutputKeyToolsParser({
          returnSingle: true,
          keyName: functionName,
          zodSchema: schema
        });
      } else {
        let openAIFunctionDefinition;
        if (
          typeof schema.name === 'string' &&
          typeof schema.parameters === 'object' &&
          schema.parameters != null
        ) {
          // Already an OpenAI function definition: use it as-is.
          openAIFunctionDefinition = schema;
          functionName = schema.name;
        } else {
          // Plain JSON schema: wrap it in a function definition.
          functionName = schema.title ?? functionName;
          openAIFunctionDefinition = {
            name: functionName,
            description: schema.description ?? '',
            parameters: schema
          };
        }
        llm = this.withConfig({
          tools: [
            {
              type: 'function',
              function: openAIFunctionDefinition
            }
          ],
          tool_choice: { type: 'function', function: { name: functionName } },
          ls_structured_output_format: {
            kwargs: { method: 'functionCalling' },
            schema: toJsonSchema(schema)
          },
          // Do not pass `strict` argument to OpenAI if `config.strict` is undefined
          ...(config?.strict !== undefined ? { strict: config.strict } : {})
        });
        outputParser = new JsonOutputKeyToolsParser({
          returnSingle: true,
          keyName: functionName
        });
      }
    }
    if (!includeRaw) {
      return llm.pipe(outputParser);
    }
    // With `includeRaw`, return `{ raw, parsed }`; a parser failure falls back
    // to `parsed: null` instead of throwing.
    const parserAssign = RunnablePassthrough.assign({
      parsed: (input, parserConfig) => outputParser.invoke(input.raw, parserConfig)
    });
    const parserNone = RunnablePassthrough.assign({
      parsed: () => null
    });
    const parsedWithFallback = parserAssign.withFallbacks({
      fallbacks: [parserNone]
    });
    return RunnableSequence.from([{ raw: llm }, parsedWithFallback]);
  }

  /**
   * Stream response chunks from the Azure OpenAI client.
   * @param messages - The messages to send to the model.
   * @param options - The call options.
   * @param runManager - The callback manager for the run.
   * @returns An async generator of chat generation chunks.
   */
  async *_streamResponseChunks(messages, options, runManager) {
    const response = await this.caller.callWithOptions({ signal: options.signal }, () =>
      this.openAiChatClient.stream(
        mapLangChainToAiClient(this, messages, options),
        options.signal,
        options.requestConfig
      )
    );
    for await (const chunk of response.stream) {
      // There can be only none or one choice inside a chunk
      const choice = chunk._data.choices[0];
      // Map the chunk to a LangChain message chunk
      const messageChunk = mapAzureOpenAiChunkToLangChainMessageChunk(chunk);
      // Create initial generation info with token indices
      const newTokenIndices = {
        prompt: options.promptIndex ?? 0,
        completion: choice?.index ?? 0
      };
      const generationInfo = { ...newTokenIndices };
      // Process finish reason
      if (choice?.finish_reason) {
        generationInfo.finish_reason = choice.finish_reason;
        // Only include system fingerprint in the last chunk for now to avoid concatenation issues
        generationInfo.system_fingerprint = chunk._data.system_fingerprint;
        generationInfo.model_name = chunk._data.model;
        generationInfo.id = chunk._data.id;
        generationInfo.created = chunk._data.created;
        generationInfo.index = choice.index;
      }
      // Process token usage
      const tokenUsage = chunk.getTokenUsage();
      if (tokenUsage) {
        generationInfo.token_usage = tokenUsage;
        messageChunk.usage_metadata = {
          input_tokens: tokenUsage.prompt_tokens,
          output_tokens: tokenUsage.completion_tokens,
          total_tokens: tokenUsage.total_tokens
        };
      }
      const content = chunk.getDeltaContent() ?? '';
      const generationChunk = new ChatGenerationChunk({
        message: messageChunk,
        text: content,
        generationInfo
      });
      // Notify the run manager about the new token
      // Some parameters(`_runId`, `_parentRunId`, `_tags`) are set as undefined as they are implicitly read from the context.
      await runManager?.handleLLMNewToken(
        content,
        newTokenIndices,
        undefined,
        undefined,
        undefined,
        { chunk: generationChunk }
      );
      yield generationChunk;
    }
  }
}
//# sourceMappingURL=chat.js.map