UNPKG

sql-talk

Version:

SQL Talk - 自然言語をSQLに変換するMCPサーバー(安全性保護・SSHトンネル対応) / SQL Talk - an MCP server that converts natural language to SQL, with safety guards and SSH tunnel support

305 lines (253 loc) 9.73 kB
import { Pool } from 'pg';
import mysql from 'mysql2/promise';
import { DatabaseEngine, ConnectionConfig } from '@/types/index.js';
import { DatabaseError } from '@/core/errors.js';
import { logger } from '@/core/logger.js';
import { configManager } from '@/core/config.js';
import { sshTunnelManager } from './ssh-tunnel.js';

/**
 * Driver-agnostic query interface implemented by the PostgreSQL and MySQL backends.
 */
export interface DatabaseConnection {
  /** Runs `sql` with optional positional `params` and resolves to all result rows. */
  query<T = any>(sql: string, params?: any[]): Promise<T[]>;
  /** Runs `sql` and resolves to the first row, or `null` when no rows come back. */
  queryOne<T = any>(sql: string, params?: any[]): Promise<T | null>;
  /** Ends the underlying pool and, when one was opened, its SSH tunnel. */
  close(): Promise<void>;
  /**
   * Endpoint the pool connects to: the configured host/port, or `localhost` plus the
   * local tunnel port when `tunneled` is true.
   */
  getConnectionInfo(): { host: string; port: number; tunneled: boolean };
}

/** Where a driver pool should connect, as decided by {@link resolveEndpoint}. */
type ResolvedEndpoint =
  | { host: string; port: number; tunneled: false }
  | { host: string; port: number; tunneled: true; cleanup: () => Promise<void> };

/**
 * Resolves the endpoint for one pool. When `config.ssh_tunnel.enabled` is set, opens a
 * tunnel through `sshTunnelManager` and returns its local endpoint together with the
 * tunnel's cleanup callback; otherwise returns the configured host/port unchanged.
 *
 * @param config - connection settings for the pool
 * @param tunnelIdPrefix - prefix of the tunnel id (`postgres` / `mysql`)
 * @param engineLabel - engine name used in log messages (`PostgreSQL` / `MySQL`)
 * @throws DatabaseError when the tunnel cannot be created
 */
async function resolveEndpoint(
  config: ConnectionConfig,
  tunnelIdPrefix: string,
  engineLabel: string,
): Promise<ResolvedEndpoint> {
  if (!config.ssh_tunnel?.enabled) {
    return { host: config.host, port: config.port, tunneled: false };
  }
  try {
    const tunnelId = `${tunnelIdPrefix}-${config.host}-${config.port}-${Date.now()}`;
    const { localPort, cleanup } = await sshTunnelManager.createTunnel(tunnelId, config.ssh_tunnel);
    logger.info(
      `${engineLabel} connection via SSH tunnel: ${config.host}:${config.port} -> localhost:${localPort}`,
    );
    return { host: 'localhost', port: localPort, tunneled: true, cleanup };
  } catch (error) {
    logger.error(`Failed to create SSH tunnel for ${engineLabel}:`, error);
    throw new DatabaseError(`SSH tunnel creation failed: ${error}`);
  }
}

/**
 * PostgreSQL backend on a `pg` connection pool, optionally reached through an SSH tunnel.
 */
export class PostgreSQLConnection implements DatabaseConnection {
  private pool!: Pool;
  private connectionInfo: { host: string; port: number; tunneled: boolean };
  private sshTunnelCleanup?: () => Promise<void>;
  /**
   * Settles once the pool exists. Tunnel setup is asynchronous, so the pool is not yet
   * available when the constructor returns; `query()` and `close()` await this first.
   */
  private readonly ready: Promise<void>;

  constructor(config: ConnectionConfig) {
    this.connectionInfo = { host: config.host, port: config.port, tunneled: false };
    this.ready = this.initializeConnection(config);
    // Attach a handler so an initialization failure is not reported as an unhandled
    // rejection; the same error is still re-thrown to callers that await `ready`.
    void this.ready.catch(() => undefined);
  }

  private async initializeConnection(config: ConnectionConfig): Promise<void> {
    const endpoint = await resolveEndpoint(config, 'postgres', 'PostgreSQL');
    if (endpoint.tunneled) {
      this.sshTunnelCleanup = endpoint.cleanup;
      this.connectionInfo = { host: endpoint.host, port: endpoint.port, tunneled: true };
    }

    this.pool = new Pool({
      host: endpoint.host,
      port: endpoint.port,
      user: config.user,
      password: config.password,
      database: config.database,
      // NOTE(review): TLS is enabled without verifying the server certificate.
      ssl: config.ssl ? { rejectUnauthorized: false } : false,
      max: 10,
      idleTimeoutMillis: 30000,
      connectionTimeoutMillis: 5000,
    });

    // Log errors the pool emits instead of leaving its 'error' event unhandled.
    this.pool.on('error', (err) => {
      logger.error('PostgreSQL pool error:', err);
    });
  }

  /**
   * @throws DatabaseError if initialization (e.g. the SSH tunnel) failed, or if checking
   *   out a client or running the statement fails
   */
  async query<T = any>(sql: string, params?: any[]): Promise<T[]> {
    await this.ready;
    try {
      // pool.query checks out a client, runs the statement and releases the client, so
      // connection-acquisition failures are logged and wrapped here as well.
      const result = await this.pool.query(sql, params);
      return result.rows as T[];
    } catch (error) {
      logger.error('PostgreSQL query error:', { sql, params, error });
      throw new DatabaseError(`PostgreSQL query failed: ${error}`, { sql, params });
    }
  }

  async queryOne<T = any>(sql: string, params?: any[]): Promise<T | null> {
    const results = await this.query<T>(sql, params);
    return results.length > 0 ? results[0] : null;
  }

  async close(): Promise<void> {
    try {
      // If initialization failed there may be no pool to end.
      await this.ready.catch(() => undefined);
      if (this.pool) {
        await this.pool.end();
      }
    } finally {
      // Close the SSH tunnel even when ending the pool failed.
      if (this.sshTunnelCleanup) {
        await this.sshTunnelCleanup();
      }
    }
  }

  getConnectionInfo(): { host: string; port: number; tunneled: boolean } {
    return this.connectionInfo;
  }
}

/**
 * MySQL backend on a `mysql2/promise` pool, optionally reached through an SSH tunnel.
 */
export class MySQLConnection implements DatabaseConnection {
  private pool!: mysql.Pool;
  private connectionInfo: { host: string; port: number; tunneled: boolean };
  private sshTunnelCleanup?: () => Promise<void>;
  /**
   * Settles once the pool exists. Tunnel setup is asynchronous, so the pool is not yet
   * available when the constructor returns; `query()` and `close()` await this first.
   */
  private readonly ready: Promise<void>;

  constructor(config: ConnectionConfig) {
    this.connectionInfo = { host: config.host, port: config.port, tunneled: false };
    this.ready = this.initializeConnection(config);
    // Attach a handler so an initialization failure is not reported as an unhandled
    // rejection; the same error is still re-thrown to callers that await `ready`.
    void this.ready.catch(() => undefined);
  }

  private async initializeConnection(config: ConnectionConfig): Promise<void> {
    const endpoint = await resolveEndpoint(config, 'mysql', 'MySQL');
    if (endpoint.tunneled) {
      this.sshTunnelCleanup = endpoint.cleanup;
      this.connectionInfo = { host: endpoint.host, port: endpoint.port, tunneled: true };
    }

    const poolConfig: mysql.PoolOptions = {
      host: endpoint.host,
      port: endpoint.port,
      user: config.user,
      password: config.password,
      database: config.database,
      connectionLimit: 10,
    };
    if (config.ssl) {
      // NOTE(review): TLS is enabled without verifying the server certificate.
      poolConfig.ssl = { rejectUnauthorized: false };
    }
    this.pool = mysql.createPool(poolConfig);
  }

  /**
   * @throws DatabaseError if initialization (e.g. the SSH tunnel) failed, or if the
   *   statement fails
   */
  async query<T = any>(sql: string, params?: any[]): Promise<T[]> {
    await this.ready;
    try {
      const [rows] = await this.pool.execute(sql, params);
      return rows as T[];
    } catch (error) {
      logger.error('MySQL query error:', { sql, params, error });
      throw new DatabaseError(`MySQL query failed: ${error}`, { sql, params });
    }
  }

  async queryOne<T = any>(sql: string, params?: any[]): Promise<T | null> {
    const results = await this.query<T>(sql, params);
    return results.length > 0 ? results[0] : null;
  }

  async close(): Promise<void> {
    try {
      // If initialization failed there may be no pool to end.
      await this.ready.catch(() => undefined);
      if (this.pool) {
        await this.pool.end();
      }
    } finally {
      // Close the SSH tunnel even when ending the pool failed.
      if (this.sshTunnelCleanup) {
        await this.sshTunnelCleanup();
      }
    }
  }

  getConnectionInfo(): { host: string; port: number; tunneled: boolean } {
    return this.connectionInfo;
  }
}

/**
 * Singleton owning the read-only connection (required) and the DDL-comment
 * connection (optional) described by `configManager`.
 */
export class ConnectionManager {
  private static instance: ConnectionManager;
  private readOnlyConnection: DatabaseConnection | null = null;
  private ddlCommentConnection: DatabaseConnection | null = null;
  private engine: DatabaseEngine | null = null;

  private constructor() {}

  public static getInstance(): ConnectionManager {
    if (!ConnectionManager.instance) {
      ConnectionManager.instance = new ConnectionManager();
    }
    return ConnectionManager.instance;
  }

  /**
   * Creates and smoke-tests the configured connections.
   *
   * @throws DatabaseError if the read-only configuration is invalid, the engine is
   *   unsupported, or a connection test fails
   */
  public async initialize(): Promise<void> {
    const config = configManager.getConfig();
    this.engine = config.engine;
    logger.info(`Initializing database connections for ${this.engine}`);

    // Read-only connection (required)
    if (!configManager.validateConnection('read_only')) {
      throw new DatabaseError('Invalid read-only connection configuration');
    }
    this.readOnlyConnection = await this.openConnection(config.connections.read_only, 'read_only');

    // DDL comment connection (optional)
    if (configManager.validateConnection('ddl_comment')) {
      this.ddlCommentConnection = await this.openConnection(
        config.connections.ddl_comment,
        'ddl_comment',
      );
    } else {
      logger.warn('DDL comment connection not configured - comment application will be disabled');
    }

    logger.info('Database connections initialized successfully');
  }

  /**
   * Creates a connection and verifies it. A connection that fails its test is closed
   * (releasing its pool and any SSH tunnel) before the test error is re-thrown.
   */
  private async openConnection(config: ConnectionConfig, name: string): Promise<DatabaseConnection> {
    const connection = this.createConnection(config);
    try {
      await this.testConnection(connection, name);
    } catch (error) {
      await connection.close().catch((closeError) => {
        logger.error(`${name} connection cleanup failed:`, closeError);
      });
      throw error;
    }
    return connection;
  }

  private createConnection(config: ConnectionConfig): DatabaseConnection {
    switch (this.engine) {
      case 'postgres':
        return new PostgreSQLConnection(config);
      case 'mysql':
        return new MySQLConnection(config);
      default:
        throw new DatabaseError(`Unsupported database engine: ${this.engine}`);
    }
  }

  /**
   * Runs `SELECT 1 as test` and expects `test === 1`.
   *
   * @throws DatabaseError wrapping whatever made the probe fail
   */
  private async testConnection(connection: DatabaseConnection, name: string): Promise<void> {
    try {
      // The same probe works on both supported engines.
      const result = await connection.queryOne<{ test: number }>('SELECT 1 as test');
      if (result?.test !== 1) {
        throw new Error('Connection test returned unexpected result');
      }
      const info = connection.getConnectionInfo();
      const via = info.tunneled ? ' via SSH tunnel' : '';
      logger.info(`${name} connection test successful (${info.host}:${info.port}${via})`);
    } catch (error) {
      logger.error(`${name} connection test failed:`, error);
      throw new DatabaseError(`Failed to connect to database (${name}): ${error}`);
    }
  }

  public getReadOnlyConnection(): DatabaseConnection {
    if (!this.readOnlyConnection) {
      throw new DatabaseError('Read-only connection not initialized');
    }
    return this.readOnlyConnection;
  }

  public getDDLCommentConnection(): DatabaseConnection {
    if (!this.ddlCommentConnection) {
      throw new DatabaseError('DDL comment connection not initialized');
    }
    return this.ddlCommentConnection;
  }

  public getEngine(): DatabaseEngine {
    if (!this.engine) {
      throw new DatabaseError('Database engine not initialized');
    }
    return this.engine;
  }

  /**
   * Closes all open connections, then every SSH tunnel, and resets the manager's state.
   * Tunnels are closed and state is reset even if a connection fails to close; the first
   * such failure is re-thrown afterwards.
   */
  public async close(): Promise<void> {
    const open = [this.readOnlyConnection, this.ddlCommentConnection].filter(
      (connection): connection is DatabaseConnection => connection !== null,
    );

    // allSettled (not all): a single failure must not leave the other close still in
    // flight while the tunnels are torn down below.
    const results = await Promise.allSettled(open.map((connection) => connection.close()));

    try {
      // Close every remaining SSH tunnel as well.
      await sshTunnelManager.closeAllTunnels();
    } finally {
      this.readOnlyConnection = null;
      this.ddlCommentConnection = null;
      this.engine = null;
    }

    const failure = results.find(
      (result): result is PromiseRejectedResult => result.status === 'rejected',
    );
    if (failure) {
      throw failure.reason;
    }
    logger.info('Database connections closed');
  }
}

export const connectionManager = ConnectionManager.getInstance();