UNPKG

@langchain/community

Version:
271 lines (270 loc) 10.3 kB
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { AIMessage, AIMessageChunk, ChatMessage, } from "@langchain/core/messages";
import { ChatGenerationChunk, } from "@langchain/core/outputs";

/**
 * Represents a chat message in the Google Vertex AI chat model.
 *
 * Instances carry the three fields copied from the constructor argument:
 * `author`, `content` and `name`.
 */
export class GoogleVertexAIChatMessage {
  /**
   * @param fields Object providing `author`, `content` and (optionally) `name`.
   */
  constructor(fields) {
    Object.defineProperty(this, "author", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    Object.defineProperty(this, "content", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    Object.defineProperty(this, "name", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    this.author = fields.author;
    this.content = fields.content;
    this.name = fields.name;
  }

  /**
   * Extracts the role of a generic message and maps it to a Google Vertex
   * AI chat author.
   *
   * The role is returned unchanged; a warning is logged (but nothing is
   * thrown) when it is not one of "system", "bot", "user" or "context".
   *
   * @param message The chat message to extract the role from.
   * @returns The role of the message mapped to a Google Vertex AI chat author.
   */
  static extractGenericMessageCustomRole(message) {
    if (
      message.role !== "system" &&
      message.role !== "bot" &&
      message.role !== "user" &&
      message.role !== "context"
    ) {
      console.warn(`Unknown message role: ${message.role}`);
    }
    return message.role;
  }

  /**
   * Maps a message type to a Google Vertex AI chat author.
   *
   * @param message The message to map.
   * @param model The model name; models whose name starts with "codechat-"
   *   use the "system" author for AI messages instead of "bot".
   * @returns The message type mapped to a Google Vertex AI chat author.
   * @throws {Error} For system messages (they are only accepted as the first
   *   message and are handled by the caller), for generic messages that are
   *   not `ChatMessage` instances, and for any other message type.
   */
  static mapMessageTypeToVertexChatAuthor(message, model) {
    const type = message._getType();
    switch (type) {
      case "ai":
        return model.startsWith("codechat-") ? "system" : "bot";
      case "human":
        return "user";
      case "system":
        throw new Error(`System messages are only supported as the first passed message for Google Vertex AI.`);
      case "generic": {
        if (!ChatMessage.isInstance(message)) {
          throw new Error("Invalid generic chat message");
        }
        return GoogleVertexAIChatMessage.extractGenericMessageCustomRole(message);
      }
      default:
        // Report the offending type; interpolating the message object itself
        // would only render as "[object Object]".
        throw new Error(`Unknown / unsupported message type: ${type}`);
    }
  }

  /**
   * Creates a new Google Vertex AI chat message from a base message.
   *
   * @param message The base message to convert.
   * @param model The model to use for conversion.
   * @returns A new Google Vertex AI chat message.
   * @throws {Error} When the message content is not a string.
   */
  static fromChatMessage(message, model) {
    if (typeof message.content !== "string") {
      throw new Error("ChatGoogleVertexAI does not support non-string message content.");
    }
    return new GoogleVertexAIChatMessage({
      author: GoogleVertexAIChatMessage.mapMessageTypeToVertexChatAuthor(message, model),
      content: message.content,
    });
  }
}

/**
 * Base class for Google Vertex AI chat models.
 * Implemented subclasses must provide a GoogleVertexAILLMConnection
 * with appropriate auth client.
 *
 * This class only declares `connection` and `streamedConnection`; it never
 * assigns them, so subclasses are expected to set both before
 * `_generate` / `_streamResponseChunks` are called.
 */
export class BaseChatGoogleVertexAI extends BaseChatModel {
  get lc_aliases() {
    return {
      model: "model_name",
    };
  }

  /**
   * @param fields Optional overrides for `model`, `temperature`,
   *   `maxOutputTokens`, `topP`, `topK` and `examples`. Any field left
   *   nullish keeps the default defined below.
   */
  constructor(fields) {
    super(fields ?? {});
    Object.defineProperty(this, "lc_serializable", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: true,
    });
    Object.defineProperty(this, "model", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: "chat-bison",
    });
    Object.defineProperty(this, "temperature", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: 0.2,
    });
    Object.defineProperty(this, "maxOutputTokens", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: 1024,
    });
    Object.defineProperty(this, "topP", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: 0.8,
    });
    Object.defineProperty(this, "topK", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: 40,
    });
    Object.defineProperty(this, "examples", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: [],
    });
    Object.defineProperty(this, "connection", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    Object.defineProperty(this, "streamedConnection", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens;
    this.topP = fields?.topP ?? this.topP;
    this.topK = fields?.topK ?? this.topK;
    this.examples = fields?.examples ?? this.examples;
  }

  _combineLLMOutput() {
    // TODO: Combine the safetyAttributes
    return [];
  }

  /**
   * Streams the model response as generation chunks.
   *
   * @param _messages The conversation to send.
   * @param _options Call options forwarded to the streamed connection.
   * @param _runManager Unused by this implementation.
   * @yields One `ChatGenerationChunk` per chunk read from the stream; a
   *   `null` chunk is turned into an empty chunk with `finishReason: "stop"`.
   */
  async *_streamResponseChunks(_messages, _options, _runManager) {
    // Make the call as a streaming request
    const instance = this.createInstance(_messages);
    const parameters = this.formatParameters();
    const result = await this.streamedConnection.request([instance], parameters, _options);
    // Get the streaming parser of the response
    const stream = result.data;
    // Loop until the end of the stream
    // During the loop, yield each time we get a chunk from the streaming parser
    // that is either available or added to the queue
    while (!stream.streamDone) {
      const output = await stream.nextChunk();
      const chunk =
        output !== null
          ? BaseChatGoogleVertexAI.convertPredictionChunk(output)
          : new ChatGenerationChunk({
              text: "",
              message: new AIMessageChunk(""),
              generationInfo: { finishReason: "stop" },
            });
      yield chunk;
    }
  }

  /**
   * Sends the conversation through `this.connection` and converts every
   * returned prediction into a chat generation.
   *
   * @param messages The conversation to send.
   * @param options Call options forwarded to the connection.
   * @returns An object whose `generations` array is empty when the response
   *   carries no predictions.
   */
  async _generate(messages, options) {
    const instance = this.createInstance(messages);
    const parameters = this.formatParameters();
    const result = await this.connection.request([instance], parameters, options);
    const generations =
      result?.data?.predictions?.map((prediction) =>
        BaseChatGoogleVertexAI.convertPrediction(prediction)
      ) ?? [];
    return {
      generations,
    };
  }

  _llmType() {
    return "vertexai";
  }

  /**
   * Creates an instance of the Google Vertex AI chat model.
   *
   * A leading system message, if present, becomes the instance `context`
   * and is removed from the conversation. The remaining messages must be odd
   * in number and no two consecutive messages may map to the same author.
   *
   * @param messages The messages for the model instance.
   * @returns A new instance of the Google Vertex AI chat model, shaped as
   *   `{ context, examples, messages }`.
   * @throws {Error} When a message has non-string content, when the
   *   conversation length is even (including empty), or when two adjacent
   *   messages share an author.
   */
  createInstance(messages) {
    let context = "";
    let conversationMessages = messages;
    if (messages[0]?._getType() === "system") {
      if (typeof messages[0].content !== "string") {
        throw new Error("ChatGoogleVertexAI does not support non-string message content.");
      }
      context = messages[0].content;
      conversationMessages = messages.slice(1);
    }
    // https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
    if (conversationMessages.length % 2 === 0) {
      throw new Error(`Google Vertex AI requires an odd number of messages to generate a response.`);
    }
    // Convert each message exactly once and compare against the author of the
    // message converted just before it, instead of re-converting the previous
    // message on every step.
    const vertexChatMessages = [];
    let previousAuthor;
    for (const [index, baseMessage] of conversationMessages.entries()) {
      const currMessage = GoogleVertexAIChatMessage.fromChatMessage(baseMessage, this.model);
      // https://cloud.google.com/vertex-ai/docs/generative-ai/chat/chat-prompts#messages
      if (index > 0 && currMessage.author === previousAuthor) {
        throw new Error(`Google Vertex AI requires AI and human messages to alternate.`);
      }
      previousAuthor = currMessage.author;
      vertexChatMessages.push(currMessage);
    }
    const examples = this.examples.map((example) => ({
      input: GoogleVertexAIChatMessage.fromChatMessage(example.input, this.model),
      output: GoogleVertexAIChatMessage.fromChatMessage(example.output, this.model),
    }));
    return {
      context,
      examples,
      messages: vertexChatMessages,
    };
  }

  /**
   * Collects the sampling settings of this model into the parameter object
   * sent alongside each request.
   *
   * @returns `{ temperature, topK, topP, maxOutputTokens }`.
   */
  formatParameters() {
    return {
      temperature: this.temperature,
      topK: this.topK,
      topP: this.topP,
      maxOutputTokens: this.maxOutputTokens,
    };
  }

  /**
   * Converts a prediction from the Google Vertex AI chat model to a chat
   * generation.
   *
   * Only the first candidate is used. A prediction that is nullish or has no
   * candidates produces an empty-text generation rather than a TypeError.
   *
   * @param prediction The prediction to convert.
   * @returns The converted chat generation; `generationInfo` is the
   *   prediction itself.
   */
  static convertPrediction(prediction) {
    const content = prediction?.candidates?.[0]?.content ?? "";
    return {
      text: content,
      message: new AIMessage(content),
      generationInfo: prediction,
    };
  }

  /**
   * Converts one streamed chunk into a chat generation chunk.
   *
   * @param output The streamed chunk; its first entry in `outputs` is
   *   converted with `convertPrediction`.
   * @returns The converted chat generation chunk.
   */
  static convertPredictionChunk(output) {
    const generation = BaseChatGoogleVertexAI.convertPrediction(output?.outputs?.[0]);
    return new ChatGenerationChunk({
      text: generation.text,
      message: new AIMessageChunk(generation.message),
      generationInfo: generation.generationInfo,
    });
  }
}