mastra-browser-core
Version:
The core foundation of the Mastra framework, providing essential components and interfaces for building AI-powered applications.
528 lines (526 loc) • 17.7 kB
JavaScript
import { MastraStorage } from '../../chunk-UUMEB542.js';
import { TABLE_WORKFLOW_SNAPSHOT, TABLE_TRACES, TABLE_MESSAGES, TABLE_THREADS, TABLE_EVALS } from '../../chunk-FAAZLQT5.js';
import { __name, __publicField } from '../../chunk-WH5OY6PO.js';
import { isAbsolute, join, resolve } from 'path';
import { PGlite, MemoryFS } from '@electric-sql/pglite';
/**
 * Best-effort JSON decoding for column values that may or may not be serialized.
 *
 * Values that are already non-null objects (e.g. JSONB columns the driver has
 * parsed) are returned unchanged; stringifying them would yield
 * "[object Object]" and lose the data. Everything else goes through
 * `JSON.parse`, with `{}` returned when parsing fails.
 *
 * @param {unknown} jsonString - a JSON string, or an already-decoded value
 * @returns {unknown} the decoded value, the input object itself, or `{}` on parse failure
 */
function safelyParseJSON(jsonString) {
  if (jsonString !== null && typeof jsonString === "object") {
    return jsonString;
  }
  try {
    return JSON.parse(jsonString);
  } catch {
    return {};
  }
}
// NOTE(review): bundler-emitted call to the shared `__name` helper (imported from a chunk);
// presumably pins the function's `.name` to "safelyParseJSON" after bundling — confirm in that chunk.
__name(safelyParseJSON, "safelyParseJSON");
/**
 * MastraStorage implementation backed by an embedded PGlite database.
 *
 * A PGlite client is created eagerly in the constructor and every data-access
 * method obtains it through `getClient()`. Row values are always bound as
 * positional (`$n`) parameters; table names come from the imported TABLE_*
 * constants or from the caller.
 */
var _PGliteStore = class _PGliteStore extends MastraStorage {
  /**
   * @param {object} options
   * @param {{ url: string }} options.config - database URL; `":memory:"` and
   *   `"file::memory:…"` URLs set `shouldCacheInit = false` on the instance.
   */
  constructor({ config }) {
    super({ name: `PGliteStore` });
    __publicField(this, "client", null);
    __publicField(this, "clientPromise", null);
    if (config.url === ":memory:" || config.url.startsWith("file::memory:")) {
      // NOTE(review): presumably tells the MastraStorage base class not to cache its
      // init step for in-memory databases — confirm against the base class.
      this.shouldCacheInit = false;
    }
    this.clientPromise = this.initClient(config);
    // initClient already logs failures. Attaching a handler here stops an early failure from
    // surfacing as an unhandled rejection; getClient() still awaits the promise and rethrows.
    this.clientPromise.catch(() => {});
  }

  /**
   * Create the PGlite client for `config.url` (after `rewriteDbUrl`) and cache it on `this.client`.
   * @param {{ url: string }} config
   * @returns {Promise<PGlite>} the created client
   * @throws rethrows, after logging, any error raised by `PGlite.create`
   */
  async initClient(config) {
    const url = this.rewriteDbUrl(config.url);
    this.logger.debug(`Initializing PGlite with URL: ${url}`);
    try {
      // NOTE(review): `fs` is always a MemoryFS, so file: URLs are presumably not
      // persisted to disk — confirm this is intended for the browser build.
      const client = await PGlite.create(url, {
        fs: new MemoryFS()
      });
      this.client = client;
      return client;
    } catch (error) {
      this.logger.error(`Error initializing PGlite client: ${error}`);
      throw error;
    }
  }

  /**
   * Return the initialized client, waiting for the constructor's pending initialization if needed.
   * @returns {Promise<PGlite>}
   * @throws if initialization failed or never started
   */
  async getClient() {
    if (!this.client && this.clientPromise) {
      this.client = await this.clientPromise;
    }
    if (!this.client) {
      throw new Error("PGlite client not initialized");
    }
    return this.client;
  }

  /**
   * Rewrite the DB URL to match the LibSQLStore logic for consistent file paths.
   *
   * Relative `file:` URLs are resolved two levels above the current directory when
   * the process runs inside a `.mastra/.../output` directory, so the database lands
   * outside the build output. All other URLs are returned unchanged.
   * @param {string} url
   * @returns {string}
   */
  rewriteDbUrl(url) {
    if (url.startsWith("file:") && url !== "file::memory:") {
      const pathPart = url.slice("file:".length);
      if (isAbsolute(pathPart)) {
        return url;
      }
      const cwd = process.cwd();
      if (cwd.includes(".mastra") && (cwd.endsWith(`output`) || cwd.endsWith(`output/`) || cwd.endsWith(`output\\`))) {
        const baseDir = join(cwd, `..`, `..`);
        const fullPath = resolve(baseDir, pathPart);
        this.logger.debug(
          `Initializing PGlite db with url ${url} with relative file path from inside .mastra/output directory. Rewriting relative file url to "file:${fullPath}". This ensures it's outside the .mastra/output directory.`
        );
        return `file:${fullPath}`;
      }
    }
    return url;
  }

  /**
   * Build a `CREATE TABLE IF NOT EXISTS` statement from a column schema.
   * @param {string} tableName
   * @param {Record<string, { type: string, nullable?: boolean, primaryKey?: boolean }>} schema
   * @returns {string}
   */
  getCreateTableSQL(tableName, schema) {
    const columns = Object.entries(schema).map(([name, col]) => {
      const declaredType = col.type.toUpperCase();
      // Timestamps are stored as ISO-8601 text (see prepareParams); other types pass through unchanged.
      const type = declaredType === "TIMESTAMP" ? "TEXT" : declaredType;
      const nullable = col.nullable ? "" : "NOT NULL";
      const primaryKey = col.primaryKey ? "PRIMARY KEY" : "";
      return `"${name}" ${type} ${nullable} ${primaryKey}`.trim();
    });
    if (tableName === TABLE_WORKFLOW_SNAPSHOT) {
      // Snapshots are keyed by the (workflow_name, run_id) pair, matching getPrimaryKeys.
      return `CREATE TABLE IF NOT EXISTS ${tableName} (
        ${columns.join(",\n")},
        PRIMARY KEY (workflow_name, run_id)
      )`;
    }
    return `CREATE TABLE IF NOT EXISTS ${tableName} (${columns.join(", ")})`;
  }

  /**
   * Create `tableName` from `schema` if it does not already exist.
   * @throws rethrows, after logging, any error from the database
   */
  async createTable({
    tableName,
    schema
  }) {
    try {
      this.logger.debug(`Creating database table`, { tableName, operation: "schema init" });
      const sql = this.getCreateTableSQL(tableName, schema);
      const client = await this.getClient();
      await client.exec(sql);
    } catch (error) {
      this.logger.error(`Error creating table ${tableName}: ${error}`);
      throw error;
    }
  }

  /**
   * Delete every row from `tableName`. Best-effort: failures are logged, not thrown.
   */
  async clearTable({ tableName }) {
    try {
      const client = await this.getClient();
      await client.exec(`DELETE FROM ${tableName}`);
    } catch (e) {
      this.logger.error(e instanceof Error ? e.message : String(e));
    }
  }

  /**
   * Normalize record values for binding.
   *
   * - `undefined` and `null` become SQL NULL. A `null` used to reach the object branch
   *   and be stored as the string "null".
   * - `Date` values become ISO-8601 strings.
   * - Other objects and arrays become JSON strings.
   * - Primitives are passed through.
   * @param {Record<string, unknown>} record
   * @returns {Record<string, unknown>} a new object with the same keys in the same order
   */
  prepareParams(record) {
    return Object.fromEntries(
      Object.entries(record).map(([k, v]) => {
        if (v === undefined || v === null) {
          return [k, null];
        }
        if (v instanceof Date) {
          return [k, v.toISOString()];
        }
        if (typeof v === "object") {
          return [k, JSON.stringify(v)];
        }
        return [k, v];
      })
    );
  }

  /**
   * Build a parameterized upsert for one record.
   *
   * Columns come from the record's own keys (double-quoted). Values are normalized by
   * `prepareParams` and bound as $1..$n. On a primary-key conflict every supplied
   * column is overwritten.
   * @param {string} tableName
   * @param {Record<string, unknown>} record
   * @returns {{ sql: string, values: unknown[] }}
   */
  buildUpsertQuery(tableName, record) {
    const entries = Object.entries(this.prepareParams(record));
    const quotedColumns = entries.map(([col]) => `"${col}"`);
    const placeholders = entries.map((_, i) => `$${i + 1}`);
    const assignments = quotedColumns.map((col, i) => `${col} = $${i + 1}`);
    const sql = `INSERT INTO ${tableName} (${quotedColumns.join(", ")}) VALUES (${placeholders.join(", ")})
      ON CONFLICT (${this.getPrimaryKeys(tableName)}) DO UPDATE SET
      ${assignments.join(", ")}`;
    return { sql, values: entries.map(([, value]) => value) };
  }

  /**
   * Upsert a single record into `tableName`.
   * @throws rethrows, after logging, any error from the database
   */
  async insert({ tableName, record }) {
    try {
      const client = await this.getClient();
      const { sql, values } = this.buildUpsertQuery(tableName, record);
      await client.query(sql, values);
    } catch (error) {
      this.logger.error(`Error upserting into table ${tableName}: ${error}`);
      throw error;
    }
  }

  /**
   * Helper to get the conflict-target column list for upserts.
   * @param {string} tableName
   * @returns {string} quoted column list; `"id"` for any table not listed explicitly
   */
  getPrimaryKeys(tableName) {
    switch (tableName) {
      case TABLE_THREADS:
      case TABLE_MESSAGES:
      case TABLE_TRACES:
        return '"id"';
      case TABLE_WORKFLOW_SNAPSHOT:
        return '"workflow_name", "run_id"';
      default:
        // NOTE(review): tables without an "id" column (and no matching unique constraint)
        // will fail ON CONFLICT here — confirm each table's schema.
        return '"id"';
    }
  }

  /**
   * Upsert many records in a single transaction. Does nothing for an empty array.
   * @throws rethrows, after logging, any error. The transaction is not committed in that case.
   */
  async batchInsert({ tableName, records }) {
    if (records.length === 0) return;
    try {
      const client = await this.getClient();
      await client.transaction(async (tx) => {
        for (const record of records) {
          const { sql, values } = this.buildUpsertQuery(tableName, record);
          await tx.query(sql, values);
        }
      });
    } catch (error) {
      this.logger.error(`Error batch upserting into table ${tableName}: ${error}`);
      throw error;
    }
  }

  /**
   * Load the newest row matching all `keys` (equality, AND-ed).
   *
   * String values that start with "{" or "[" are JSON-decoded when possible.
   * NOTE(review): assumes the table has a "createdAt" column for ordering.
   * @returns {Promise<object|null>} the row, or null when nothing matches
   * @throws rethrows, after logging, query errors
   */
  async load({ tableName, keys }) {
    const keyEntries = Object.entries(keys);
    const conditions = keyEntries.map(([key], i) => `"${key}" = $${i + 1}`).join(" AND ");
    const values = keyEntries.map(([, value]) => value);
    const client = await this.getClient();
    try {
      const result = await client.query(
        `SELECT * FROM ${tableName} WHERE ${conditions} ORDER BY "createdAt" DESC LIMIT 1`,
        values
      );
      if (!result.rows || result.rows.length === 0) {
        return null;
      }
      const row = result.rows[0];
      return Object.fromEntries(
        Object.entries(row || {}).map(([k, v]) => {
          const looksLikeJSON = typeof v === "string" && (v.startsWith("{") || v.startsWith("["));
          if (!looksLikeJSON) {
            return [k, v];
          }
          try {
            return [k, JSON.parse(v)];
          } catch {
            return [k, v];
          }
        })
      );
    } catch (error) {
      this.logger.error(`Error querying table ${tableName}: ${error}`);
      throw error;
    }
  }

  /**
   * Fetch one thread by id.
   * String `createdAt`/`updatedAt` become Date objects and string `metadata` is JSON-parsed.
   * @returns {Promise<object|null>}
   */
  async getThreadById({ threadId }) {
    const result = await this.load({
      tableName: TABLE_THREADS,
      keys: { id: threadId }
    });
    if (!result) {
      return null;
    }
    return {
      ...result,
      createdAt: typeof result.createdAt === "string" ? new Date(result.createdAt) : result.createdAt,
      updatedAt: typeof result.updatedAt === "string" ? new Date(result.updatedAt) : result.updatedAt,
      metadata: typeof result.metadata === "string" ? JSON.parse(result.metadata) : result.metadata
    };
  }

  /**
   * List all threads owned by `resourceId`, applying the same field conversions as getThreadById.
   * @returns {Promise<object[]>}
   */
  async getThreadsByResourceId({ resourceId }) {
    const client = await this.getClient();
    const result = await client.query(
      `SELECT * FROM ${TABLE_THREADS} WHERE "resourceId" = $1`,
      [resourceId]
    );
    return (result.rows ?? []).map((thread) => ({
      id: thread.id,
      resourceId: thread.resourceId,
      title: thread.title,
      createdAt: typeof thread.createdAt === "string" ? new Date(thread.createdAt) : thread.createdAt,
      updatedAt: typeof thread.updatedAt === "string" ? new Date(thread.updatedAt) : thread.updatedAt,
      metadata: typeof thread.metadata === "string" ? JSON.parse(thread.metadata) : thread.metadata
    }));
  }

  /**
   * Upsert a thread, storing `metadata` as a JSON string.
   * @returns {Promise<object>} the thread argument, unchanged
   */
  async saveThread({ thread }) {
    await this.insert({
      tableName: TABLE_THREADS,
      record: {
        ...thread,
        metadata: JSON.stringify(thread.metadata)
      }
    });
    return thread;
  }

  /**
   * Set a thread's title and shallow-merge `metadata` into its existing metadata.
   * @returns {Promise<object>} the updated thread
   * @throws Error when the thread does not exist
   */
  async updateThread({
    id,
    title,
    metadata
  }) {
    const thread = await this.getThreadById({ threadId: id });
    if (!thread) {
      throw new Error(`Thread ${id} not found`);
    }
    const updatedThread = {
      ...thread,
      title,
      metadata: {
        ...thread.metadata,
        ...metadata
      }
    };
    const client = await this.getClient();
    await client.query(
      `UPDATE ${TABLE_THREADS} SET title = $1, metadata = $2 WHERE id = $3`,
      [title, JSON.stringify(updatedThread.metadata), id]
    );
    return updatedThread;
  }

  /**
   * Delete a thread and all of its messages.
   * Both deletes run in one transaction: either both are applied or neither is.
   */
  async deleteThread({ threadId }) {
    const client = await this.getClient();
    await client.transaction(async (tx) => {
      await tx.query(
        `DELETE FROM ${TABLE_THREADS} WHERE id = $1`,
        [threadId]
      );
      await tx.query(
        `DELETE FROM ${TABLE_MESSAGES} WHERE thread_id = $1`,
        [threadId]
      );
    });
  }

  /**
   * Map a messages-table row to a message object.
   * `content` is JSON-decoded when it parses and kept as-is otherwise.
   */
  parseRow(row) {
    let content = row.content;
    try {
      content = JSON.parse(row.content);
    } catch {
    }
    return {
      id: row.id,
      content,
      role: row.role,
      type: row.type,
      createdAt: new Date(row.createdAt),
      threadId: row.thread_id
    };
  }

  /**
   * Fetch messages for a thread, sorted by `createdAt` ascending.
   *
   * - When `selectBy.include` is given, each listed message is fetched together with up to
   *   the largest `withPreviousMessages` / `withNextMessages` neighbours (by createdAt order).
   * - The newest `selectBy.last` other messages are then added (default 40 when `last`
   *   is not a number).
   * @throws rethrows, after logging, any query error
   */
  async getMessages({ threadId, selectBy }) {
    try {
      const client = await this.getClient();
      const messages = [];
      const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
      if (selectBy?.include?.length) {
        const includeIds = selectBy.include.map((i) => i.id);
        const maxPrev = Math.max(...selectBy.include.map((i) => i.withPreviousMessages || 0));
        const maxNext = Math.max(...selectBy.include.map((i) => i.withNextMessages || 0));
        const includeResult = await client.query(
          `
          WITH numbered_messages AS (
            SELECT
              id,
              content,
              role,
              type,
              "createdAt",
              thread_id,
              ROW_NUMBER() OVER (ORDER BY "createdAt" ASC) as row_num
            FROM "${TABLE_MESSAGES}"
            WHERE thread_id = $1
          ),
          target_positions AS (
            SELECT row_num as target_pos
            FROM numbered_messages
            WHERE id IN (${includeIds.map((_, i) => `$${i + 2}`).join(", ")})
          )
          SELECT DISTINCT m.*
          FROM numbered_messages m
          CROSS JOIN target_positions t
          WHERE m.row_num BETWEEN (t.target_pos - $${includeIds.length + 2}) AND (t.target_pos + $${includeIds.length + 3})
          ORDER BY m."createdAt" ASC
          `,
          [threadId, ...includeIds, maxPrev, maxNext]
        );
        if (includeResult.rows && includeResult.rows.length > 0) {
          messages.push(...includeResult.rows.map((row) => this.parseRow(row)));
        }
      }
      // Fetch the most recent remaining messages, skipping any already included above.
      const excludeIds = messages.map((m) => m.id);
      const excludeClause = excludeIds.length ? `AND id NOT IN (${excludeIds.map((_, i) => `$${i + 2}`).join(", ")})` : "";
      const remainingResult = await client.query(
        `
        SELECT
          id,
          content,
          role,
          type,
          "createdAt",
          thread_id
        FROM "${TABLE_MESSAGES}"
        WHERE thread_id = $1
        ${excludeClause}
        ORDER BY "createdAt" DESC
        LIMIT $${excludeIds.length + 2}
        `,
        [threadId, ...excludeIds, limit]
      );
      if (remainingResult.rows && remainingResult.rows.length > 0) {
        messages.push(...remainingResult.rows.map((row) => this.parseRow(row)));
      }
      messages.sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());
      return messages;
    } catch (error) {
      this.logger.error("Error getting messages:", error);
      throw error;
    }
  }

  /**
   * Upsert messages in one transaction.
   * Every message is stored under the first message's `threadId`; `createdAt` defaults to now.
   * @returns {Promise<object[]>} the messages argument
   * @throws Error when the first message has no threadId; rethrows database errors after logging
   */
  async saveMessages({ messages }) {
    if (messages.length === 0) return messages;
    const client = await this.getClient();
    try {
      const threadId = messages[0]?.threadId;
      if (!threadId) {
        throw new Error("Thread ID is required");
      }
      await client.transaction(async (tx) => {
        for (const message of messages) {
          const time = message.createdAt || new Date();
          await tx.query(
            `INSERT INTO ${TABLE_MESSAGES} (id, thread_id, content, role, type, "createdAt")
            VALUES ($1, $2, $3, $4, $5, $6)
            ON CONFLICT (id) DO UPDATE SET
            content = $3, role = $4, type = $5, "createdAt" = $6`,
            [
              message.id,
              threadId,
              typeof message.content === "object" ? JSON.stringify(message.content) : message.content,
              message.role,
              message.type,
              time instanceof Date ? time.toISOString() : time
            ]
          );
        }
      });
      return messages;
    } catch (error) {
      this.logger.error("Failed to save messages in database: " + error?.message);
      throw error;
    }
  }

  /**
   * Map an evals-table row to an eval object, decoding `result` and `test_info` when stored as strings.
   * @throws Error when `result` is not an object containing `score`
   */
  transformEvalRow(row) {
    const resultValue = typeof row.result === "string" ? JSON.parse(row.result) : row.result;
    let testInfoValue;
    if (row.test_info) {
      testInfoValue = typeof row.test_info === "string" ? JSON.parse(row.test_info) : row.test_info;
    }
    if (!resultValue || typeof resultValue !== "object" || !("score" in resultValue)) {
      throw new Error(`Invalid MetricResult format: ${JSON.stringify(resultValue)}`);
    }
    return {
      input: row.input,
      output: row.output,
      result: resultValue,
      agentName: row.agent_name,
      metricName: row.metric_name,
      instructions: row.instructions,
      testInfo: testInfoValue,
      globalRunId: row.global_run_id,
      runId: row.run_id,
      createdAt: row.created_at
    };
  }

  /**
   * List evals for an agent, newest first.
   * @param {string} agentName
   * @param {"test"|"live"} [type] - "test": rows with `test_info.testPath`; "live": rows without it
   * @returns {Promise<object[]>} an empty array when the evals table does not exist
   */
  async getEvalsByAgentName(agentName, type) {
    try {
      const client = await this.getClient();
      const baseQuery = `SELECT * FROM ${TABLE_EVALS} WHERE agent_name = $1`;
      let typeCondition = "";
      if (type === "test") {
        typeCondition = " AND test_info IS NOT NULL AND test_info->>'testPath' IS NOT NULL";
      } else if (type === "live") {
        typeCondition = " AND (test_info IS NULL OR test_info->>'testPath' IS NULL)";
      }
      const result = await client.query(
        `${baseQuery}${typeCondition} ORDER BY created_at DESC`,
        [agentName]
      );
      return result.rows?.map((row) => this.transformEvalRow(row)) ?? [];
    } catch (error) {
      const message = error instanceof Error ? error.message : "";
      // SQLite reports "no such table"; Postgres/PGlite reports `relation "…" does not exist`.
      if (message.includes("no such table") || /relation .* does not exist/.test(message)) {
        return [];
      }
      this.logger.error("Failed to get evals for the specified agent: " + error?.message);
      throw error;
    }
  }

  /**
   * Page through trace rows, newest first (ORDER BY "startTime" DESC).
   * @param {object} [options]
   * @param {string} [options.name] - prefix match on `name`
   * @param {string} [options.scope] - exact match on `scope`
   * @param {number} [options.page=0] - zero-based page index
   * @param {number} [options.perPage=100] - page size
   * @param {Record<string, unknown>} [options.attributes] - exact matches on `attributes->>key`
   * @returns {Promise<object[]>}
   */
  async getTraces({
    name,
    scope,
    page = 0,
    perPage = 100,
    attributes
  } = {}) {
    const args = [];
    const conditions = [];
    // Push the value before referencing it, so each condition gets its own placeholder.
    const bind = (value) => {
      args.push(value);
      return `$${args.length}`;
    };
    if (name) {
      conditions.push(`name LIKE (${bind(name)} || '%')`);
    }
    if (scope) {
      conditions.push(`scope = ${bind(scope)}`);
    }
    if (attributes) {
      for (const [key, value] of Object.entries(attributes)) {
        // The key is bound too: interpolating it into the SQL would allow injection via attribute names.
        conditions.push(`attributes->>(${bind(key)}::text) = ${bind(value)}`);
      }
    }
    const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(" AND ")}` : "";
    const limitParam = bind(perPage);
    const offsetParam = bind(page * perPage);
    const client = await this.getClient();
    const result = await client.query(
      `SELECT * FROM ${TABLE_TRACES} ${whereClause} ORDER BY "startTime" DESC LIMIT ${limitParam} OFFSET ${offsetParam}`,
      args
    );
    if (!result.rows) {
      return [];
    }
    return result.rows.map((row) => ({
      id: row.id,
      parentSpanId: row.parentSpanId,
      traceId: row.traceId,
      name: row.name,
      scope: row.scope,
      kind: row.kind,
      status: safelyParseJSON(row.status),
      events: safelyParseJSON(row.events),
      links: safelyParseJSON(row.links),
      attributes: safelyParseJSON(row.attributes),
      startTime: row.startTime,
      endTime: row.endTime,
      other: safelyParseJSON(row.other),
      createdAt: row.createdAt
    }));
  }
};
// NOTE(review): bundler-emitted call to the shared `__name` helper; presumably pins the
// class's `.name` to "PGliteStore" after bundling — confirm in the helper's chunk.
__name(_PGliteStore, "PGliteStore");
var PGliteStore = _PGliteStore;
// The store is exported under its own name and also as `DefaultStorage`.
export { PGliteStore as DefaultStorage, PGliteStore };
//# sourceMappingURL=index.js.map
//# sourceMappingURL=index.js.map