UNPKG

@mastra/core

Version:

The core foundation of the Mastra framework, providing essential components and interfaces for building AI-powered applications.

544 lines (539 loc) 14.5 kB
'use strict';

var chunk7JRVDC7F_cjs = require('./chunk-7JRVDC7F.cjs');
var chunkD63P5O4Q_cjs = require('./chunk-D63P5O4Q.cjs');
var chunkO7IW545H_cjs = require('./chunk-O7IW545H.cjs');
var ai = require('ai');
var zod = require('zod');

// Response header inspected after every step to decide whether to back off.
const RATE_LIMIT_HEADER = "x-ratelimit-remaining-tokens";
// Back off once the remaining-token header drops below this value.
const RATE_LIMIT_TOKEN_THRESHOLD = 2e3;
// How long to pause when the threshold is crossed (matches the "10 seconds" log text).
const RATE_LIMIT_BACKOFF_MS = 10 * 1e3;

/**
 * Picks the fields of a step/finish event that are written to the debug log.
 *
 * @param {object} [props] - Event payload handed to `onStepFinish` / `onFinish`.
 * @returns {object} The loggable subset; every field is `undefined` when `props` is nullish.
 */
function summarizeStep(props) {
  return {
    text: props?.text,
    toolCalls: props?.toolCalls,
    toolResults: props?.toolResults,
    finishReason: props?.finishReason,
    usage: props?.usage
  };
}

/**
 * Builds the `onStepFinish` callback shared by all four execution methods.
 *
 * The callback notifies the caller's own `onStepFinish`, logs the step, and
 * pauses when the rate-limit header reports fewer than
 * `RATE_LIMIT_TOKEN_THRESHOLD` tokens remaining.
 *
 * @param {object} llm - Instance whose `logger` is used; read at call time so a
 *   logger swapped in after the handler was built is still honoured.
 * @param {object} options
 * @param {string} options.label - Debug log message for the step.
 * @param {Function} [options.onStepFinish] - Caller-supplied callback.
 * @param {*} options.runId - Run identifier attached to the rate-limit warning.
 * @param {object} options.logContext - Extra fields appended to the step log entry.
 * @returns {Function} Async step callback.
 */
function createStepFinishHandler(llm, { label, onStepFinish, runId, logContext }) {
  return async (props) => {
    // Fire-and-forget on purpose: the caller's callback is not awaited.
    void onStepFinish?.(props);
    llm.logger.debug(label, { ...summarizeStep(props), ...logContext });
    const remainingTokens = props?.response?.headers?.[RATE_LIMIT_HEADER];
    if (remainingTokens && Number.parseInt(remainingTokens, 10) < RATE_LIMIT_TOKEN_THRESHOLD) {
      llm.logger.warn("Rate limit approaching, waiting 10 seconds", { runId });
      await chunk7JRVDC7F_cjs.delay(RATE_LIMIT_BACKOFF_MS);
    }
  };
}

/**
 * Builds the `onFinish` callback shared by the two streaming methods.
 *
 * @param {object} llm - Instance whose `logger` is used (read at call time).
 * @param {object} options
 * @param {Function} [options.onFinish] - Caller-supplied callback.
 * @param {object} options.logContext - Extra fields appended to the log entry.
 * @returns {Function} Async finish callback.
 */
function createFinishHandler(llm, { onFinish, logContext }) {
  return async (props) => {
    // Fire-and-forget on purpose: the caller's callback is not awaited.
    void onFinish?.(props);
    llm.logger.debug("[LLM] - Stream Finished:", { ...summarizeStep(props), ...logContext });
  };
}

/**
 * Normalises a caller-supplied output definition into a schema plus output mode.
 *
 * Anything exposing a `parse` function is used as-is, except that a `ZodArray`
 * is unwrapped to its element schema and reported as `"array"`. Everything else
 * is wrapped with `ai.jsonSchema`.
 *
 * @param {object} definition - Schema-like object or JSON-schema definition.
 * @returns {{ schema: *, output: "object" | "array" }}
 */
function resolveOutputSchema(definition) {
  if (typeof definition.parse !== "function") {
    return { schema: ai.jsonSchema(definition), output: "object" };
  }
  if (definition instanceof zod.z.ZodArray) {
    return { schema: definition._def.type, output: "array" };
  }
  return { schema: definition, output: "object" };
}

// src/llm/model/base.ts
/**
 * Abstract base for LLM wrappers: stores the model, exposes its identity and
 * normalises message input. All execution methods throw
 * "Method not implemented." and are meant to be overridden.
 */
var MastraLLMBase = class extends chunkD63P5O4Q_cjs.MastraBase {
  // @ts-ignore
  #mastra;
  #model;

  /**
   * @param {object} options
   * @param {string} options.name - Component name passed to the base class.
   * @param {object} options.model - Model object; `provider` and `modelId` are read from it.
   */
  constructor({ name, model }) {
    super({
      component: chunkO7IW545H_cjs.RegisteredLogger.LLM,
      name
    });
    this.#model = model;
  }

  /** @returns {*} The wrapped model's `provider`. */
  getProvider() {
    return this.#model.provider;
  }

  /** @returns {*} The wrapped model's `modelId`. */
  getModelId() {
    return this.#model.modelId;
  }

  /** @returns {object} The wrapped model. */
  getModel() {
    return this.#model;
  }

  /**
   * Normalises message input into an array of message objects.
   *
   * @param {*} messages - An array (string entries become user messages, other
   *   entries pass through untouched) or a single value that becomes the
   *   content of one user message.
   * @returns {Array<object>} Message list.
   */
  convertToMessages(messages) {
    if (!Array.isArray(messages)) {
      return [{ role: "user", content: messages }];
    }
    return messages.map((entry) => typeof entry === "string" ? { role: "user", content: entry } : entry);
  }

  /**
   * Wires optional telemetry and logger primitives into this instance.
   *
   * @param {{ telemetry?: *, logger?: * }} p
   */
  __registerPrimitives(p) {
    if (p.telemetry) {
      this.__setTelemetry(p.telemetry);
    }
    if (p.logger) {
      this.__setLogger(p.logger);
    }
  }

  /** @param {*} p - Mastra instance to remember. */
  __registerMastra(p) {
    this.#mastra = p;
  }

  /** @throws {Error} Always - must be overridden. */
  async __text(input) {
    this.logger.debug(`[LLMs:${this.name}] Generating text.`, { input });
    throw new Error("Method not implemented.");
  }

  /** @throws {Error} Always - must be overridden. */
  async __textObject(input) {
    this.logger.debug(`[LLMs:${this.name}] Generating object.`, { input });
    throw new Error("Method not implemented.");
  }

  /** @throws {Error} Always - must be overridden. */
  async generate(messages, options) {
    this.logger.debug(`[LLMs:${this.name}] Generating text.`, { messages, options });
    throw new Error("Method not implemented.");
  }

  /** @throws {Error} Always - must be overridden. */
  async __stream(input) {
    this.logger.debug(`[LLMs:${this.name}] Streaming text.`, { input });
    throw new Error("Method not implemented.");
  }

  /** @throws {Error} Always - must be overridden. */
  async __streamObject(input) {
    this.logger.debug(`[LLMs:${this.name}] Streaming object.`, { input });
    throw new Error("Method not implemented.");
  }

  /** @throws {Error} Always - must be overridden. */
  async stream(messages, options) {
    this.logger.debug(`[LLMs:${this.name}] Streaming text.`, { messages, options });
    throw new Error("Method not implemented.");
  }
};

// src/llm/model/model.ts
/**
 * LLM wrapper that executes requests through the `ai` package
 * (`generateText`, `generateObject`, `streamText`, `streamObject`).
 */
var MastraLLM = class extends MastraLLMBase {
  #model;
  #mastra;

  /**
   * @param {object} options
   * @param {object} options.model - Model handed to every `ai` call.
   * @param {object} [options.mastra] - Optional Mastra instance; when it
   *   provides a logger, that logger is adopted.
   */
  constructor({ model, mastra }) {
    super({ name: "aisdk", model });
    this.#model = model;
    if (mastra) {
      this.#mastra = mastra;
      const mastraLogger = mastra.getLogger();
      if (mastraLogger) {
        this.__setLogger(mastraLogger);
      }
    }
  }

  /**
   * Wires optional telemetry and logger primitives into this instance.
   *
   * @param {{ telemetry?: *, logger?: * }} p
   */
  __registerPrimitives(p) {
    if (p.telemetry) {
      this.__setTelemetry(p.telemetry);
    }
    if (p.logger) {
      this.__setLogger(p.logger);
    }
  }

  /** @param {*} p - Mastra instance to remember. */
  __registerMastra(p) {
    this.#mastra = p;
  }

  /** @returns {*} The wrapped model's `provider`. */
  getProvider() {
    return this.#model.provider;
  }

  /** @returns {*} The wrapped model's `modelId`. */
  getModelId() {
    return this.#model.modelId;
  }

  /** @returns {object} The wrapped model. */
  getModel() {
    return this.#model;
  }

  /**
   * Generates text via `ai.generateText`.
   *
   * `threadId`, `resourceId`, `memory` and `runtimeContext` are destructured
   * so they are NOT forwarded to the `ai` call through `rest`.
   *
   * @param {object} input - Execution options; unknown keys are forwarded as-is.
   * @returns {Promise<*>} Whatever `ai.generateText` resolves to.
   */
  async __text({
    runId,
    messages,
    maxSteps = 5,
    tools = {},
    temperature,
    toolChoice = "auto",
    onStepFinish,
    experimental_output,
    telemetry,
    threadId,
    resourceId,
    memory,
    runtimeContext,
    ...rest
  }) {
    const model = this.#model;
    this.logger.debug(`[LLM] - Generating text`, {
      runId,
      messages,
      maxSteps,
      threadId,
      resourceId,
      // `tools || {}` keeps an explicit `tools: null` from throwing here,
      // consistent with the streaming methods.
      tools: Object.keys(tools || {})
    });
    const argsForExecute = {
      model,
      temperature,
      tools: { ...tools },
      toolChoice,
      maxSteps,
      onStepFinish: createStepFinishHandler(this, {
        label: "[LLM] - Step Change:",
        onStepFinish,
        runId,
        logContext: { runId }
      }),
      ...rest
    };
    let schema;
    if (experimental_output) {
      this.logger.debug("[LLM] - Using experimental output", { runId });
      // NOTE(review): an array schema is unwrapped to its element schema and
      // then passed to Output.object - confirm this is intended.
      ({ schema } = resolveOutputSchema(experimental_output));
    }
    return await ai.generateText({
      messages,
      ...argsForExecute,
      experimental_telemetry: {
        ...this.experimental_telemetry,
        ...telemetry
      },
      experimental_output: schema ? ai.Output.object({ schema }) : void 0
    });
  }

  /**
   * Generates a structured object via `ai.generateObject`.
   *
   * @param {object} input - Execution options; `structuredOutput` is required
   *   and is resolved to a schema plus `"object"`/`"array"` output mode.
   * @returns {Promise<*>} Whatever `ai.generateObject` resolves to.
   */
  async __textObject({
    messages,
    onStepFinish,
    maxSteps = 5,
    tools = {},
    structuredOutput,
    runId,
    temperature,
    toolChoice = "auto",
    telemetry,
    threadId,
    resourceId,
    memory,
    runtimeContext,
    ...rest
  }) {
    const model = this.#model;
    this.logger.debug(`[LLM] - Generating a text object`, { runId });
    const argsForExecute = {
      model,
      temperature,
      tools: { ...tools },
      maxSteps,
      toolChoice,
      onStepFinish: createStepFinishHandler(this, {
        label: "[LLM] - Step Change:",
        onStepFinish,
        runId,
        logContext: { runId }
      }),
      ...rest
    };
    const { schema, output } = resolveOutputSchema(structuredOutput);
    return await ai.generateObject({
      messages,
      ...argsForExecute,
      output,
      schema,
      experimental_telemetry: {
        ...this.experimental_telemetry,
        ...telemetry
      }
    });
  }

  /**
   * Streams text via `ai.streamText`.
   *
   * @param {object} input - Execution options; unknown keys are forwarded as-is.
   * @returns {Promise<*>} Whatever `ai.streamText` returns.
   */
  async __stream({
    messages,
    onStepFinish,
    onFinish,
    maxSteps = 5,
    tools = {},
    runId,
    temperature,
    toolChoice = "auto",
    experimental_output,
    telemetry,
    threadId,
    resourceId,
    memory,
    runtimeContext,
    ...rest
  }) {
    const model = this.#model;
    this.logger.debug(`[LLM] - Streaming text`, {
      runId,
      threadId,
      resourceId,
      messages,
      maxSteps,
      tools: Object.keys(tools || {})
    });
    const argsForExecute = {
      model,
      temperature,
      tools: { ...tools },
      maxSteps,
      toolChoice,
      onStepFinish: createStepFinishHandler(this, {
        label: "[LLM] - Stream Step Change:",
        onStepFinish,
        runId,
        logContext: { runId }
      }),
      onFinish: createFinishHandler(this, {
        onFinish,
        logContext: { runId, threadId, resourceId }
      }),
      ...rest
    };
    let schema;
    if (experimental_output) {
      this.logger.debug("[LLM] - Using experimental output", { runId });
      ({ schema } = resolveOutputSchema(experimental_output));
    }
    return await ai.streamText({
      messages,
      ...argsForExecute,
      experimental_telemetry: {
        ...this.experimental_telemetry,
        ...telemetry
      },
      experimental_output: schema ? ai.Output.object({ schema }) : void 0
    });
  }

  /**
   * Streams a structured object via `ai.streamObject`.
   *
   * @param {object} input - Execution options; `structuredOutput` is required.
   * @returns {Promise<*>} Whatever `ai.streamObject` returns.
   */
  async __streamObject({
    messages,
    runId,
    tools = {},
    maxSteps = 5,
    toolChoice = "auto",
    runtimeContext,
    threadId,
    resourceId,
    memory,
    temperature,
    onStepFinish,
    onFinish,
    structuredOutput,
    telemetry,
    ...rest
  }) {
    const model = this.#model;
    this.logger.debug(`[LLM] - Streaming structured output`, {
      runId,
      messages,
      maxSteps,
      tools: Object.keys(tools || {})
    });
    const logContext = { runId, threadId, resourceId };
    const argsForExecute = {
      model,
      temperature,
      tools: { ...tools },
      maxSteps,
      toolChoice,
      onStepFinish: createStepFinishHandler(this, {
        label: "[LLM] - Stream Step Change:",
        onStepFinish,
        runId,
        logContext
      }),
      onFinish: createFinishHandler(this, { onFinish, logContext }),
      ...rest
    };
    const { schema, output } = resolveOutputSchema(structuredOutput);
    return ai.streamObject({
      messages,
      ...argsForExecute,
      output,
      schema,
      experimental_telemetry: {
        ...this.experimental_telemetry,
        ...telemetry
      }
    });
  }

  /**
   * Generates a response, choosing text or structured output.
   *
   * @param {*} messages - Anything accepted by `convertToMessages`.
   * @param {object} [options] - Execution options. `output`, when truthy,
   *   selects structured generation; `maxSteps` defaults to 5. May be omitted.
   * @returns {Promise<*>} Result of `__text` or `__textObject`.
   */
  async generate(messages, { maxSteps = 5, output, ...rest } = {}) {
    const msgs = this.convertToMessages(messages);
    if (!output) {
      return await this.__text({ messages: msgs, maxSteps, ...rest });
    }
    return await this.__textObject({
      messages: msgs,
      structuredOutput: output,
      maxSteps,
      ...rest
    });
  }

  /**
   * Streams a response, choosing text or structured output.
   *
   * @param {*} messages - Anything accepted by `convertToMessages`.
   * @param {object} [options] - Execution options. `output`, when truthy,
   *   selects structured streaming; `maxSteps` defaults to 5. May be omitted.
   * @returns {Promise<*>} Result of `__stream` or `__streamObject`.
   */
  async stream(messages, { maxSteps = 5, output, ...rest } = {}) {
    const msgs = this.convertToMessages(messages);
    if (!output) {
      return await this.__stream({ messages: msgs, maxSteps, ...rest });
    }
    return await this.__streamObject({
      messages: msgs,
      structuredOutput: output,
      maxSteps,
      ...rest
    });
  }

  /**
   * Flattens a message list into UI chat messages.
   *
   * Non-tool messages become `{ id, role, content, toolInvocations }`, where
   * `content` is the concatenated text and `toolInvocations` lists the
   * message's tool calls. Tool messages are not emitted themselves; their
   * results are merged into matching invocations (by `toolCallId`) of the
   * messages collected so far, and remembered for later tool calls.
   *
   * @param {Array<object>} messages - Messages with `role` and `content`.
   * @returns {Array<object>} UI chat messages.
   */
  convertToUIMessages(messages) {
    // Applies one tool message's results to the chat messages gathered so far.
    function addToolMessageToChat({ toolMessage, messages: messages2, toolResultContents }) {
      const chatMessages2 = messages2.map((message) => {
        if (!message.toolInvocations) {
          return message;
        }
        return {
          ...message,
          toolInvocations: message.toolInvocations.map((toolInvocation) => {
            const toolResult = toolMessage.content.find((tool) => tool.toolCallId === toolInvocation.toolCallId);
            if (toolResult) {
              return {
                ...toolInvocation,
                state: "result",
                result: toolResult.result
              };
            }
            return toolInvocation;
          })
        };
      });
      const resultContents = [...toolResultContents, ...toolMessage.content];
      return { chatMessages: chatMessages2, toolResultContents: resultContents };
    }

    const { chatMessages } = messages.reduce(
      (obj, message) => {
        if (message.role === "tool") {
          return addToolMessageToChat({
            toolMessage: message,
            messages: obj.chatMessages,
            toolResultContents: obj.toolResultContents
          });
        }
        let textContent = "";
        const toolInvocations = [];
        if (typeof message.content === "string") {
          textContent = message.content;
        } else if (typeof message.content === "number") {
          textContent = String(message.content);
        } else if (Array.isArray(message.content)) {
          for (const content of message.content) {
            if (content.type === "text") {
              textContent += content.text;
            } else if (content.type === "tool-call") {
              // A result may already be known if its tool message came first.
              const toolResult = obj.toolResultContents.find((tool) => tool.toolCallId === content.toolCallId);
              toolInvocations.push({
                state: toolResult ? "result" : "call",
                toolCallId: content.toolCallId,
                toolName: content.toolName,
                args: content.args,
                result: toolResult?.result
              });
            }
          }
        }
        obj.chatMessages.push({
          id: message.id,
          role: message.role,
          content: textContent,
          toolInvocations
        });
        return obj;
      },
      { chatMessages: [], toolResultContents: [] }
    );
    return chatMessages;
  }
};

exports.MastraLLM = MastraLLM;