@mastra/core
The core foundation of the Mastra framework, providing essential components and interfaces for building AI-powered applications.
var chunk7JRVDC7F_cjs = require('./chunk-7JRVDC7F.cjs');
var chunkD63P5O4Q_cjs = require('./chunk-D63P5O4Q.cjs');
var chunkO7IW545H_cjs = require('./chunk-O7IW545H.cjs');
var ai = require('ai');
var zod = require('zod');
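// The chunk-* requires are bundler-generated internal modules (MastraBase,
// RegisteredLogger, the delay helper); `ai` is the Vercel AI SDK and `zod`
// is the Zod schema library.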
// src/llm/model/base.ts
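// MastraLLMBase: shared foundation for Mastra's LLM wrappers. It holds the
// underlying AI SDK model, wires up logging/telemetry primitives, and
// normalizes message input into message objects.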
var MastraLLMBase = class extends chunkD63P5O4Q_cjs.MastraBase {
  // Underlying model and (optionally) the registered Mastra instance.
  #mastra;
  #model;
constructor({ name, model }) {
super({
component: chunkO7IW545H_cjs.RegisteredLogger.LLM,
name
});
this.#model = model;
}
getProvider() {
return this.#model.provider;
}
getModelId() {
return this.#model.modelId;
}
getModel() {
return this.#model;
}
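  // Normalizes accepted inputs into CoreMessage-style objects. For example:
  //   convertToMessages("hi") -> [{ role: "user", content: "hi" }]
  //   convertToMessages(["hi", { role: "assistant", content: "ok" }])
  //     -> [{ role: "user", content: "hi" }, { role: "assistant", content: "ok" }]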
convertToMessages(messages) {
if (Array.isArray(messages)) {
return messages.map((m) => {
if (typeof m === "string") {
return {
role: "user",
content: m
};
}
return m;
});
}
return [
{
role: "user",
content: messages
}
];
}
__registerPrimitives(p) {
if (p.telemetry) {
this.__setTelemetry(p.telemetry);
}
if (p.logger) {
this.__setLogger(p.logger);
}
}
__registerMastra(p) {
this.#mastra = p;
}
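  // The methods below are intentionally unimplemented placeholders; concrete
  // subclasses (such as MastraLLM below) must override them.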
async __text(input) {
this.logger.debug(`[LLMs:${this.name}] Generating text.`, { input });
throw new Error("Method not implemented.");
}
async __textObject(input) {
this.logger.debug(`[LLMs:${this.name}] Generating object.`, { input });
throw new Error("Method not implemented.");
}
async generate(messages, options) {
this.logger.debug(`[LLMs:${this.name}] Generating text.`, { messages, options });
throw new Error("Method not implemented.");
}
async __stream(input) {
this.logger.debug(`[LLMs:${this.name}] Streaming text.`, { input });
throw new Error("Method not implemented.");
}
async __streamObject(input) {
this.logger.debug(`[LLMs:${this.name}] Streaming object.`, { input });
throw new Error("Method not implemented.");
}
async stream(messages, options) {
this.logger.debug(`[LLMs:${this.name}] Streaming text.`, { messages, options });
throw new Error("Method not implemented.");
}
};
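// MastraLLMBase on its own is a template; every generation method throws until
// overridden. A hypothetical minimal subclass (sketch, not part of this package):
//
//   class EchoLLM extends MastraLLMBase {
//     async generate(messages) {
//       const msgs = this.convertToMessages(messages);
//       return { text: msgs.map((m) => m.content).join("\n") };
//     }
//   }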
// src/llm/model/model.ts
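// MastraLLM: concrete wrapper around a Vercel AI SDK language model. It
// delegates to the `ai` package's generateText/generateObject/streamText/
// streamObject functions.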
var MastraLLM = class extends MastraLLMBase {
#model;
#mastra;
constructor({ model, mastra }) {
super({ name: "aisdk", model });
this.#model = model;
if (mastra) {
this.#mastra = mastra;
if (mastra.getLogger()) {
this.__setLogger(mastra.getLogger());
}
}
}
__registerPrimitives(p) {
if (p.telemetry) {
this.__setTelemetry(p.telemetry);
}
if (p.logger) {
this.__setLogger(p.logger);
}
}
__registerMastra(p) {
this.#mastra = p;
}
getProvider() {
return this.#model.provider;
}
getModelId() {
return this.#model.modelId;
}
getModel() {
return this.#model;
}
async __text({
runId,
messages,
maxSteps = 5,
tools = {},
temperature,
toolChoice = "auto",
onStepFinish,
experimental_output,
telemetry,
threadId,
resourceId,
memory,
runtimeContext,
...rest
}) {
const model = this.#model;
this.logger.debug(`[LLM] - Generating text`, {
runId,
messages,
maxSteps,
threadId,
resourceId,
tools: Object.keys(tools)
});
const argsForExecute = {
model,
temperature,
tools: {
...tools
},
toolChoice,
maxSteps,
onStepFinish: async (props) => {
void onStepFinish?.(props);
this.logger.debug("[LLM] - Step Change:", {
text: props?.text,
toolCalls: props?.toolCalls,
toolResults: props?.toolResults,
finishReason: props?.finishReason,
usage: props?.usage,
runId
});
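        // Soft backoff: if the provider's rate-limit headers report fewer than
        // 2000 remaining tokens, pause for 10 seconds before the next step.
        // (The same check recurs in the other generation methods below.)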
        const remainingTokens = props?.response?.headers?.["x-ratelimit-remaining-tokens"];
        if (remainingTokens && parseInt(remainingTokens, 10) < 2000) {
          this.logger.warn("Rate limit approaching, waiting 10 seconds", { runId });
          await chunk7JRVDC7F_cjs.delay(10 * 1000);
        }
},
...rest
};
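    // `experimental_output` may be a Zod schema (detected via its `parse`
    // method) or a plain JSON schema object. A top-level ZodArray is unwrapped
    // to its element type (via the Zod internal `_def.type`) before being
    // passed to ai.Output.object.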
let schema;
if (experimental_output) {
this.logger.debug("[LLM] - Using experimental output", {
runId
});
if (typeof experimental_output.parse === "function") {
schema = experimental_output;
if (schema instanceof zod.z.ZodArray) {
schema = schema._def.type;
}
} else {
schema = ai.jsonSchema(experimental_output);
}
}
return await ai.generateText({
messages,
...argsForExecute,
experimental_telemetry: {
...this.experimental_telemetry,
...telemetry
},
experimental_output: schema ? ai.Output.object({
schema
}) : void 0
});
}
async __textObject({
messages,
onStepFinish,
maxSteps = 5,
tools = {},
structuredOutput,
runId,
temperature,
toolChoice = "auto",
telemetry,
threadId,
resourceId,
memory,
runtimeContext,
...rest
}) {
const model = this.#model;
this.logger.debug(`[LLM] - Generating a text object`, { runId });
const argsForExecute = {
model,
temperature,
tools: {
...tools
},
maxSteps,
toolChoice,
onStepFinish: async (props) => {
void onStepFinish?.(props);
this.logger.debug("[LLM] - Step Change:", {
text: props?.text,
toolCalls: props?.toolCalls,
toolResults: props?.toolResults,
finishReason: props?.finishReason,
usage: props?.usage,
runId
});
        const remainingTokens = props?.response?.headers?.["x-ratelimit-remaining-tokens"];
        if (remainingTokens && parseInt(remainingTokens, 10) < 2000) {
          this.logger.warn("Rate limit approaching, waiting 10 seconds", { runId });
          await chunk7JRVDC7F_cjs.delay(10 * 1000);
        }
},
...rest
};
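    // As in __text: accept a Zod schema or a plain JSON schema. A top-level
    // ZodArray additionally flips generateObject into "array" output mode.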
let schema;
let output = "object";
if (typeof structuredOutput.parse === "function") {
schema = structuredOutput;
if (schema instanceof zod.z.ZodArray) {
output = "array";
schema = schema._def.type;
}
} else {
schema = ai.jsonSchema(structuredOutput);
}
return await ai.generateObject({
messages,
...argsForExecute,
output,
schema,
experimental_telemetry: {
...this.experimental_telemetry,
...telemetry
}
});
}
async __stream({
messages,
onStepFinish,
onFinish,
maxSteps = 5,
tools = {},
runId,
temperature,
toolChoice = "auto",
experimental_output,
telemetry,
threadId,
resourceId,
memory,
runtimeContext,
...rest
}) {
const model = this.#model;
this.logger.debug(`[LLM] - Streaming text`, {
runId,
threadId,
resourceId,
messages,
maxSteps,
tools: Object.keys(tools || {})
});
const argsForExecute = {
model,
temperature,
tools: {
...tools
},
maxSteps,
toolChoice,
onStepFinish: async (props) => {
void onStepFinish?.(props);
this.logger.debug("[LLM] - Stream Step Change:", {
text: props?.text,
toolCalls: props?.toolCalls,
toolResults: props?.toolResults,
finishReason: props?.finishReason,
usage: props?.usage,
runId
});
        const remainingTokens = props?.response?.headers?.["x-ratelimit-remaining-tokens"];
        if (remainingTokens && parseInt(remainingTokens, 10) < 2000) {
          this.logger.warn("Rate limit approaching, waiting 10 seconds", { runId });
          await chunk7JRVDC7F_cjs.delay(10 * 1000);
        }
},
onFinish: async (props) => {
void onFinish?.(props);
this.logger.debug("[LLM] - Stream Finished:", {
text: props?.text,
toolCalls: props?.toolCalls,
toolResults: props?.toolResults,
finishReason: props?.finishReason,
usage: props?.usage,
runId,
threadId,
resourceId
});
},
...rest
};
let schema;
if (experimental_output) {
this.logger.debug("[LLM] - Using experimental output", {
runId
});
if (typeof experimental_output.parse === "function") {
schema = experimental_output;
if (schema instanceof zod.z.ZodArray) {
schema = schema._def.type;
}
} else {
schema = ai.jsonSchema(experimental_output);
}
}
return await ai.streamText({
messages,
...argsForExecute,
experimental_telemetry: {
...this.experimental_telemetry,
...telemetry
},
experimental_output: schema ? ai.Output.object({
schema
}) : void 0
});
}
async __streamObject({
messages,
runId,
tools = {},
maxSteps = 5,
toolChoice = "auto",
runtimeContext,
threadId,
resourceId,
memory,
temperature,
onStepFinish,
onFinish,
structuredOutput,
telemetry,
...rest
}) {
const model = this.#model;
this.logger.debug(`[LLM] - Streaming structured output`, {
runId,
messages,
maxSteps,
tools: Object.keys(tools || {})
});
const finalTools = tools;
const argsForExecute = {
model,
temperature,
tools: {
...finalTools
},
maxSteps,
toolChoice,
onStepFinish: async (props) => {
void onStepFinish?.(props);
this.logger.debug("[LLM] - Stream Step Change:", {
text: props?.text,
toolCalls: props?.toolCalls,
toolResults: props?.toolResults,
finishReason: props?.finishReason,
usage: props?.usage,
runId,
threadId,
resourceId
});
        const remainingTokens = props?.response?.headers?.["x-ratelimit-remaining-tokens"];
        if (remainingTokens && parseInt(remainingTokens, 10) < 2000) {
          this.logger.warn("Rate limit approaching, waiting 10 seconds", { runId });
          await chunk7JRVDC7F_cjs.delay(10 * 1000);
        }
},
onFinish: async (props) => {
void onFinish?.(props);
this.logger.debug("[LLM] - Stream Finished:", {
text: props?.text,
toolCalls: props?.toolCalls,
toolResults: props?.toolResults,
finishReason: props?.finishReason,
usage: props?.usage,
runId,
threadId,
resourceId
});
},
...rest
};
let schema;
let output = "object";
if (typeof structuredOutput.parse === "function") {
schema = structuredOutput;
if (schema instanceof zod.z.ZodArray) {
output = "array";
schema = schema._def.type;
}
} else {
schema = ai.jsonSchema(structuredOutput);
}
return ai.streamObject({
messages,
...argsForExecute,
output,
schema,
experimental_telemetry: {
...this.experimental_telemetry,
...telemetry
}
});
}
  async generate(messages, { maxSteps = 5, output, ...rest } = {}) {
const msgs = this.convertToMessages(messages);
if (!output) {
return await this.__text({
messages: msgs,
maxSteps,
...rest
});
}
return await this.__textObject({
messages: msgs,
structuredOutput: output,
maxSteps,
...rest
});
}
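  // Usage sketch for generate() (assumes an AI SDK provider such as `openai`
  // from @ai-sdk/openai, which this file does not import):
  //   const llm = new MastraLLM({ model: openai("gpt-4o") });
  //   const { text } = await llm.generate("Hello!");
  //   const { object } = await llm.generate("List three colors", {
  //     output: zod.z.object({ colors: zod.z.array(zod.z.string()) }),
  //   });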
  async stream(messages, { maxSteps = 5, output, ...rest } = {}) {
const msgs = this.convertToMessages(messages);
if (!output) {
return await this.__stream({
messages: msgs,
maxSteps,
...rest
});
}
return await this.__streamObject({
messages: msgs,
structuredOutput: output,
maxSteps,
...rest
});
}
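  // Usage sketch for stream(): without `output`, this resolves to an
  // ai.streamText result whose textStream is async-iterable:
  //   const res = await llm.stream("Tell me a story");
  //   for await (const chunk of res.textStream) process.stdout.write(chunk);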
convertToUIMessages(messages) {
function addToolMessageToChat({
toolMessage,
messages: messages2,
toolResultContents
}) {
const chatMessages2 = messages2.map((message) => {
if (message.toolInvocations) {
return {
...message,
toolInvocations: message.toolInvocations.map((toolInvocation) => {
const toolResult = toolMessage.content.find((tool) => tool.toolCallId === toolInvocation.toolCallId);
if (toolResult) {
return {
...toolInvocation,
state: "result",
result: toolResult.result
};
}
return toolInvocation;
})
};
}
return message;
});
const resultContents = [...toolResultContents, ...toolMessage.content];
return { chatMessages: chatMessages2, toolResultContents: resultContents };
}
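    // Fold raw model messages into UI-shaped chat messages, threading each tool
    // result back onto the tool invocation that produced it.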
const { chatMessages } = messages.reduce(
(obj, message) => {
if (message.role === "tool") {
return addToolMessageToChat({
toolMessage: message,
messages: obj.chatMessages,
toolResultContents: obj.toolResultContents
});
}
let textContent = "";
let toolInvocations = [];
if (typeof message.content === "string") {
textContent = message.content;
} else if (typeof message.content === "number") {
textContent = String(message.content);
} else if (Array.isArray(message.content)) {
for (const content of message.content) {
if (content.type === "text") {
textContent += content.text;
} else if (content.type === "tool-call") {
const toolResult = obj.toolResultContents.find((tool) => tool.toolCallId === content.toolCallId);
toolInvocations.push({
state: toolResult ? "result" : "call",
toolCallId: content.toolCallId,
toolName: content.toolName,
args: content.args,
result: toolResult?.result
});
}
}
}
obj.chatMessages.push({
id: message.id,
role: message.role,
content: textContent,
toolInvocations
});
return obj;
},
{ chatMessages: [], toolResultContents: [] }
);
return chatMessages;
}
};
exports.MastraLLM = MastraLLM;