@astreus-ai/astreus
Version:
AI Agent Framework with Chat Management
470 lines (414 loc) • 15 kB
text/typescript
import { Knex } from "knex";
import fs from "fs";
import path from "path";
import dotenv from "dotenv";
import {
DatabaseConfig,
DatabaseInstance,
DatabaseFactory,
TableOperations,
TableNamesConfig,
} from "./types";
import { createSqliteDatabase } from "./database/sqlite";
import { createPostgresqlDatabase } from "./database/postgresql";
import { v4 as uuidv4 } from "uuid";
import { logger } from "./utils";
import { validateRequiredParam, validateRequiredParams } from "./utils/validation";
import { DEFAULT_DB_PATH } from "./constants";
// Load environment variables as a module-level side effect (runs on import).
// createDatabase() below reads DATABASE_TYPE, DATABASE_PATH and DATABASE_URL
// from process.env when it is called without an explicit config.
dotenv.config();
// Re-export the per-backend configuration helpers so callers can import them
// from this module instead of reaching into ./database/*
export { createSqliteConfig } from "./database/sqlite";
export { createPostgresqlConfig } from "./database/postgresql";
// Re-export the user module's functions (implemented in ./database/modules/user)
export {
createUser,
getUserById,
getUserByUsername,
updateUser,
deleteUser
} from "./database/modules/user";
// Re-export the core database types
// NOTE(review): if these are type-only symbols, `export type { ... }` would be
// required under `isolatedModules` — confirm against ./types/database.
export { DatabaseInstance, DatabaseConfig } from "./types/database";
/**
 * Core database implementation that provides storage functionality
 * for the Astreus framework. Wraps a Knex instance created for either the
 * "sqlite" or the "postgresql" backend, selected by `config.type`.
 *
 * The instance tracks the configured (prefixed) table names, a registry of
 * custom logical-name -> table-name mappings, and whether
 * `initializeSchema()` has completed.
 */
class Database implements DatabaseInstance {
  public knex: Knex;
  public config: DatabaseConfig;
  private initialized = false;
  private readonly tableNames: Required<TableNamesConfig>;
  private readonly customTables: Map<string, string> = new Map();

  /**
   * Build the table-name configuration and create the Knex instance.
   * No connection is verified here; call `connect()` for that.
   * @param config Database configuration; `type` is required
   * @throws Error if `config.type` is neither "sqlite" nor "postgresql"
   */
  constructor(config: DatabaseConfig) {
    // Validate required parameters
    validateRequiredParam(config, "config", "Database constructor");
    validateRequiredParams(
      config,
      ["type"],
      "Database constructor"
    );
    this.config = config;

    // Apply table prefix if specified
    const prefix = config.tablePrefix || '';

    // Set table names with defaults and prefix.
    // `||` is intentional: an empty configured name falls back to the default.
    this.tableNames = {
      agents: prefix + (config.tableNames?.agents || 'agents'),
      users: prefix + (config.tableNames?.users || 'users'),
      tasks: prefix + (config.tableNames?.tasks || 'tasks'),
      memories: prefix + (config.tableNames?.memories || 'memories'),
      chats: prefix + (config.tableNames?.chats || 'chats'),
      // Stored as given (unprefixed); the prefixed names live in `customTables`
      custom: config.tableNames?.custom || {}
    };

    // Register custom tables under their prefixed physical names
    if (config.tableNames?.custom) {
      for (const [name, tableName] of Object.entries(config.tableNames.custom)) {
        this.customTables.set(name, prefix + tableName);
      }
    }

    // Initialize knex instance based on database type
    if (config.type === "sqlite") {
      this.knex = createSqliteDatabase(config);
    } else if (config.type === "postgresql") {
      this.knex = createPostgresqlDatabase(config);
    } else {
      throw new Error(`Unsupported database type: ${config.type}`);
    }
  }

  /**
   * Connect to the database and verify the connection by running `SELECT 1`.
   * @throws The underlying driver error if the probe query fails
   */
  async connect(): Promise<void> {
    try {
      // Test the connection
      await this.knex.raw("SELECT 1");
      logger.database("Connect", `Connected to ${this.config.type} database`);
    } catch (error) {
      logger.error(`Error connecting to ${this.config.type} database:`, error);
      throw error;
    }
  }

  /**
   * Gracefully disconnect from the database by destroying the Knex instance.
   */
  async disconnect(): Promise<void> {
    await this.knex.destroy();
    logger.database("Disconnect", `Disconnected from ${this.config.type} database`);
  }

  /**
   * Execute a raw SQL query against the database.
   *
   * NOTE(review): the resolved value is whatever `knex.raw` yields for the
   * active driver and is only cast to `T[]`; some drivers may resolve to a
   * result object rather than a plain array — confirm against callers.
   *
   * @param query The SQL query to execute
   * @param params Parameters to bind to the query
   * @returns Results of the query
   * @throws The underlying driver error (after logging it)
   */
  async executeQuery<T = any>(query: string, params: any[] = []): Promise<T[]> {
    // Validate required parameters
    validateRequiredParam(query, "query", "executeQuery");
    try {
      // `await` is required here: without it a rejected query bypasses the
      // catch block below and is never logged.
      const result = await this.knex.raw(query, params);
      return result as T[];
    } catch (error) {
      logger.error("Error executing query:", error);
      throw error;
    }
  }

  /**
   * Check if a table exists
   * @param tableName Name of the table to check
   * @returns Promise resolving to boolean indicating if table exists
   */
  async hasTable(tableName: string): Promise<boolean> {
    validateRequiredParam(tableName, "tableName", "hasTable");
    return await this.knex.schema.hasTable(tableName);
  }

  /**
   * Create a table with the given schema
   * @param tableName Name of the table to create
   * @param schema Function that defines the table schema
   * @throws The underlying error (after logging it) if creation fails
   */
  async createTable(tableName: string, schema: (table: Knex.TableBuilder) => void): Promise<void> {
    validateRequiredParam(tableName, "tableName", "createTable");
    validateRequiredParam(schema, "schema", "createTable");
    try {
      await this.knex.schema.createTable(tableName, schema);
      logger.database("CreateTable", `Created table: ${tableName}`);
    } catch (error) {
      logger.error(`Error creating table ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Drop a table if it exists
   * @param tableName Name of the table to drop
   * @throws The underlying error (after logging it) if the drop fails
   */
  async dropTable(tableName: string): Promise<void> {
    validateRequiredParam(tableName, "tableName", "dropTable");
    try {
      await this.knex.schema.dropTableIfExists(tableName);
      logger.database("DropTable", `Dropped table: ${tableName}`);
    } catch (error) {
      logger.error(`Error dropping table ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Ensure a table exists, create it if it doesn't
   * @param tableName Name of the table to ensure
   * @param schema Function that defines the table schema
   */
  async ensureTable(tableName: string, schema: (table: Knex.TableBuilder) => void): Promise<void> {
    validateRequiredParam(tableName, "tableName", "ensureTable");
    validateRequiredParam(schema, "schema", "ensureTable");
    const exists = await this.hasTable(tableName);
    if (!exists) {
      await this.createTable(tableName, schema);
    } else {
      logger.database("EnsureTable", `Table ${tableName} already exists`);
    }
  }

  /**
   * Register a custom table name mapping. The configured table prefix is
   * prepended to `tableName` before it is stored.
   * @param name Logical name for the table
   * @param tableName Actual (unprefixed) table name in database
   */
  registerCustomTable(name: string, tableName: string): void {
    validateRequiredParam(name, "name", "registerCustomTable");
    validateRequiredParam(tableName, "tableName", "registerCustomTable");
    const prefix = this.config.tablePrefix || '';
    this.customTables.set(name, prefix + tableName);
    logger.database("RegisterCustomTable", `Registered custom table: ${name} -> ${prefix + tableName}`);
  }

  /**
   * Get the actual (prefixed) table name for a custom table
   * @param name Logical name of the custom table
   * @returns Actual table name or undefined if not found
   */
  getCustomTableName(name: string): string | undefined {
    return this.customTables.get(name);
  }

  /**
   * Initialize database schema - only handles legacy migrations now.
   * Each module (createMemory, createChat, etc.) is responsible for creating its own tables.
   * Marks the instance as initialized on success.
   * @throws The underlying error (after logging it) if the migration step throws
   */
  async initializeSchema(): Promise<void> {
    try {
      // Only handle legacy migrations, no auto table creation
      await this.migrateLegacyTables();
      // Mark as initialized
      this.initialized = true;
      logger.database("InitializeSchema", "Database schema initialization complete (legacy migrations only)");
    } catch (error) {
      logger.error("Error initializing database schema:", error);
      throw error;
    }
  }

  /**
   * Migrate the deprecated `task_contexts` table into the memories table and
   * then drop it.
   *
   * The deprecated table is dropped only when it is empty or its rows were
   * copied successfully. If the memories table does not exist yet, or the
   * copy fails, the table is kept so that no data is lost and the migration
   * can be retried on a later start.
   *
   * @param memoriesTableName Optional table name for memories (defaults to configured name)
   */
  private async migrateLegacyTables(memoriesTableName?: string): Promise<void> {
    const memoryTable = memoriesTableName || this.tableNames.memories;

    // Nothing to do unless the deprecated table is present
    const hasTaskContextsTable = await this.knex.schema.hasTable("task_contexts");
    if (!hasTaskContextsTable) {
      return;
    }

    // Becomes true only when dropping the table cannot lose data
    let safeToDrop = false;
    try {
      const contextRecords = await this.knex("task_contexts").select("*");
      if (contextRecords.length === 0) {
        safeToDrop = true;
      } else {
        logger.database("InitializeSchema", `Migrating ${contextRecords.length} task contexts to memory system...`);
        // Check if memories table exists before migration
        const hasMemoriesTable = await this.knex.schema.hasTable(memoryTable);
        if (!hasMemoriesTable) {
          logger.warn(`Memories table '${memoryTable}' does not exist. Skipping task contexts migration.`);
        } else {
          // Batch insert to memories table
          const memoryRecords = contextRecords.map((record: Record<string, any>) => ({
            id: uuidv4(),
            agentId: "system",
            sessionId: record.sessionId,
            userId: "",
            role: "task_context",
            content: record.data,
            timestamp: record.updatedAt || new Date(),
            metadata: JSON.stringify({
              contextType: "task_execution_context",
              migratedFrom: "task_contexts",
            }),
          }));
          await this.knex(memoryTable).insert(memoryRecords);
          logger.database("InitializeSchema", "Task contexts migration completed successfully");
          safeToDrop = true;
        }
      }
    } catch (migrationError) {
      // Best effort: a failed migration must not abort schema initialization
      logger.error("Error migrating task contexts:", migrationError);
    }

    if (!safeToDrop) {
      logger.warn("Keeping deprecated task_contexts table because its data was not migrated");
      return;
    }

    // Drop the deprecated table
    await this.knex.schema.dropTable("task_contexts");
    logger.database("InitializeSchema", "Dropped deprecated task_contexts table");
  }

  /**
   * Get operations interface for a specific table
   * @param tableName Name of the table to operate on
   * @returns Table operations interface (insert, find, findOne, update, delete);
   *          each operation logs and rethrows the underlying error on failure
   */
  getTable(tableName: string): TableOperations {
    // Validate required parameters
    validateRequiredParam(tableName, "tableName", "getTable");
    // Captured so the returned object does not depend on `this`
    const knexInstance = this.knex;
    return {
      /**
       * Insert data into table.
       * Returns the first element of the driver's insert result.
       * NOTE(review): what that element is (or whether it is defined) depends
       * on the driver — confirm for PostgreSQL, where no `returning` is set.
       */
      async insert(data: Record<string, any>): Promise<number | string> {
        // Validate required parameters
        validateRequiredParam(data, "data", "insert");
        try {
          const result = await knexInstance(tableName).insert(data);
          return result[0];
        } catch (error) {
          logger.error(`Error inserting into ${tableName}:`, error);
          throw error;
        }
      },

      /**
       * Find records in table; without a filter, all rows are returned
       */
      async find(filter?: Record<string, any>): Promise<Record<string, any>[]> {
        try {
          let query = knexInstance(tableName);
          if (filter) {
            query = query.where(filter);
          }
          // `await` keeps query failures inside this try/catch so they are logged
          return await query.select("*");
        } catch (error) {
          logger.error(`Error finding in ${tableName}:`, error);
          throw error;
        }
      },

      /**
       * Find one record in table; resolves to null when nothing matches
       */
      async findOne(
        filter: Record<string, any>
      ): Promise<Record<string, any> | null> {
        // Validate required parameters
        validateRequiredParam(filter, "filter", "findOne");
        try {
          const result = await knexInstance(tableName).where(filter).first();
          return result || null;
        } catch (error) {
          logger.error(`Error finding one in ${tableName}:`, error);
          throw error;
        }
      },

      /**
       * Update records in table matching the filter
       */
      async update(
        filter: Record<string, any>,
        data: Record<string, any>
      ): Promise<number> {
        // Validate required parameters
        validateRequiredParam(filter, "filter", "update");
        validateRequiredParam(data, "data", "update");
        try {
          return await knexInstance(tableName).where(filter).update(data);
        } catch (error) {
          logger.error(`Error updating in ${tableName}:`, error);
          throw error;
        }
      },

      /**
       * Delete records from table matching the filter
       */
      async delete(filter: Record<string, any>): Promise<number> {
        // Validate required parameters
        validateRequiredParam(filter, "filter", "delete");
        try {
          return await knexInstance(tableName).where(filter).delete();
        } catch (error) {
          logger.error(`Error deleting from ${tableName}:`, error);
          throw error;
        }
      },
    };
  }

  /**
   * Check if the database has been initialized
   * (i.e. `initializeSchema()` has completed successfully)
   */
  isInitialized(): boolean {
    return this.initialized;
  }

  /**
   * Get configured table names
   * @returns Object containing all configured table names
   */
  getTableNames(): Required<TableNamesConfig> {
    return this.tableNames;
  }
}
/**
 * Database factory function: builds a `Database`, verifies the connection and
 * runs schema initialization (legacy migrations only).
 *
 * When `config` is omitted, one is derived from environment variables:
 * - `DATABASE_TYPE`: "sqlite" (default) or "postgresql"
 * - `DATABASE_PATH`: SQLite file path (defaults to `DEFAULT_DB_PATH`); the
 *   parent directory is created if it does not exist
 * - `DATABASE_URL`: PostgreSQL connection URL (required for "postgresql")
 *
 * @param config Optional explicit configuration; `type` is required when given
 * @returns The connected and initialized database instance
 * @throws Error for an unsupported type, a missing `DATABASE_URL`, or any
 *         connection / initialization failure (the connection is closed first)
 */
export const createDatabase: DatabaseFactory = async (
  config?: DatabaseConfig
) => {
  // If no config is provided, create a default one
  if (!config) {
    // Determine which database to use based on environment variables
    const dbType = process.env.DATABASE_TYPE || "sqlite";
    if (dbType === "sqlite") {
      // For SQLite, create a default file-based database
      const dbPath = process.env.DATABASE_PATH || DEFAULT_DB_PATH;
      // Create database directory if it doesn't exist (for file-based SQLite)
      const dir = path.dirname(dbPath);
      if (!fs.existsSync(dir)) {
        fs.mkdirSync(dir, { recursive: true });
        logger.database("CreateDatabase", `Created database directory: ${dir}`);
      }
      config = {
        type: "sqlite",
        connection: dbPath,
      };
      logger.database("CreateDatabase", `Using SQLite database at ${dbPath}`);
    } else if (dbType === "postgresql") {
      // For PostgreSQL, use connection URL
      const databaseUrl = process.env.DATABASE_URL;
      if (!databaseUrl) {
        throw new Error("PostgreSQL connection requires DATABASE_URL environment variable");
      }

      // URL components are exposed percent-encoded; decode them so that
      // credentials containing reserved characters (e.g. "@", ":", "/") reach
      // the driver in their real form. A malformed escape falls back to the
      // raw value instead of throwing.
      const decode = (value: string): string => {
        try {
          return decodeURIComponent(value);
        } catch {
          return value;
        }
      };

      // Parse connection string
      const url = new URL(databaseUrl);
      const host = url.hostname;
      const port = parseInt(url.port || "5432", 10);
      const user = decode(url.username);
      const password = decode(url.password);
      const database = decode(url.pathname.substring(1)); // Remove leading slash
      config = {
        type: "postgresql",
        connection: {
          host,
          port,
          user,
          password,
          database,
        },
      };
      logger.database("CreateDatabase", `Using PostgreSQL database from URL: ${host}:${port}/${database}`);
    } else {
      throw new Error(`Unsupported database type: ${dbType}`);
    }
  } else {
    // Validate the provided config
    validateRequiredParams(
      config,
      ["type"],
      "createDatabase"
    );
  }

  // Create a new database instance
  const db = new Database(config);
  try {
    // Connect to the database
    await db.connect();
    // Only run legacy migrations, no auto table creation
    await db.initializeSchema();
  } catch (error) {
    // Release the connection created by the constructor so a failed startup
    // does not leak it; the original failure is what the caller sees.
    try {
      await db.disconnect();
    } catch (cleanupError) {
      logger.error("Error closing database after failed startup:", cleanupError);
    }
    throw error;
  }
  return db;
};