sql-talk
Version:
SQL Talk - 自然言語をSQLに変換するMCPサーバー(安全性保護・SSHトンネル対応) / SQL Talk - MCP Server for Natural Language to SQL conversion with safety guards and SSH tunnel support
305 lines (253 loc) • 9.73 kB
text/typescript
import { Pool } from 'pg';
import mysql from 'mysql2/promise';
import { DatabaseEngine, ConnectionConfig } from '@/types/index.js';
import { DatabaseError } from '@/core/errors.js';
import { logger } from '@/core/logger.js';
import { configManager } from '@/core/config.js';
import { sshTunnelManager } from './ssh-tunnel.js';
/**
 * Driver-agnostic contract implemented by the PostgreSQL and MySQL connection wrappers.
 */
export interface DatabaseConnection {
  /** Executes `sql` with optional positional `params` and resolves to every returned row. */
  query<T = any>(sql: string, params?: any[]): Promise<T[]>;

  /** Executes `sql` and resolves to the first returned row, or `null` when there is none. */
  queryOne<T = any>(sql: string, params?: any[]): Promise<T | null>;

  /** Releases the resources held by this connection. */
  close(): Promise<void>;

  /** Reports the endpoint in use and whether it is reached through an SSH tunnel. */
  getConnectionInfo(): {
    host: string;
    port: number;
    tunneled: boolean;
  };
}
/**
 * PostgreSQL implementation of `DatabaseConnection` backed by a `pg` connection pool,
 * optionally reached through an SSH tunnel.
 *
 * Setting up the tunnel is asynchronous, so the constructor keeps the in-flight
 * initialization promise in `ready`. Every operation awaits it before touching the pool.
 * This prevents `query()` from running against an unassigned pool.
 */
export class PostgreSQLConnection implements DatabaseConnection {
  private pool?: Pool;
  private connectionInfo: { host: string; port: number; tunneled: boolean };
  private sshTunnelCleanup?: () => Promise<void>;
  /** Settles when `initializeConnection` finishes; rejects if it failed. */
  private readonly ready: Promise<void>;

  constructor(config: ConnectionConfig) {
    this.connectionInfo = {
      host: config.host,
      port: config.port,
      tunneled: false,
    };
    this.ready = this.initializeConnection(config);
    // Attach a handler now so an init failure is not reported as an unhandled rejection.
    // The same error is re-thrown to callers when they await `ready` in `getPool()`.
    void this.ready.catch(() => undefined);
  }

  /**
   * Opens the SSH tunnel (when enabled) and creates the pool.
   * @throws DatabaseError when the SSH tunnel cannot be created.
   */
  private async initializeConnection(config: ConnectionConfig): Promise<void> {
    let actualHost = config.host;
    let actualPort = config.port;

    // When an SSH tunnel is configured, connect to its local end instead of the remote host.
    if (config.ssh_tunnel?.enabled) {
      try {
        const tunnelId = `postgres-${config.host}-${config.port}-${Date.now()}`;
        const { localPort, cleanup } = await sshTunnelManager.createTunnel(tunnelId, config.ssh_tunnel);
        actualHost = 'localhost';
        actualPort = localPort;
        this.sshTunnelCleanup = cleanup;
        this.connectionInfo = {
          host: actualHost,
          port: actualPort,
          tunneled: true,
        };
        logger.info(`PostgreSQL connection via SSH tunnel: ${config.host}:${config.port} -> localhost:${localPort}`);
      } catch (error) {
        logger.error('Failed to create SSH tunnel for PostgreSQL:', error);
        throw new DatabaseError(`SSH tunnel creation failed: ${error}`);
      }
    }

    this.pool = new Pool({
      host: actualHost,
      port: actualPort,
      user: config.user,
      password: config.password,
      database: config.database,
      // NOTE(review): rejectUnauthorized=false accepts any server certificate, so TLS here
      // encrypts traffic but does not authenticate the server.
      ssl: config.ssl ? { rejectUnauthorized: false } : false,
      max: 10,
      idleTimeoutMillis: 30000,
      connectionTimeoutMillis: 5000,
    });
    // Without an 'error' listener, errors on idle clients would be emitted unhandled.
    this.pool.on('error', (err) => {
      logger.error('PostgreSQL pool error:', err);
    });
  }

  /** Waits for initialization to finish and returns the ready pool. */
  private async getPool(): Promise<Pool> {
    await this.ready;
    if (!this.pool) {
      throw new DatabaseError('PostgreSQL connection pool is not initialized');
    }
    return this.pool;
  }

  /**
   * Runs `sql` on a pooled client and returns all rows.
   * @throws DatabaseError when the query fails.
   */
  async query<T = any>(sql: string, params?: any[]): Promise<T[]> {
    const pool = await this.getPool();
    const client = await pool.connect();
    try {
      const result = await client.query(sql, params);
      return result.rows as T[];
    } catch (error) {
      logger.error('PostgreSQL query error:', { sql, params, error });
      throw new DatabaseError(`PostgreSQL query failed: ${error}`, { sql, params });
    } finally {
      // Always return the client to the pool, even when the query throws.
      client.release();
    }
  }

  /** Runs `sql` and returns the first row, or `null` when no rows are returned. */
  async queryOne<T = any>(sql: string, params?: any[]): Promise<T | null> {
    const results = await this.query<T>(sql, params);
    return results.length > 0 ? results[0] : null;
  }

  /**
   * Ends the pool and then closes the SSH tunnel, if one was opened.
   * The tunnel is closed even if ending the pool fails.
   */
  async close(): Promise<void> {
    try {
      // Do not race an in-flight tunnel/pool setup.
      await this.ready;
    } catch {
      // Initialization failed; still release whatever was created.
    }
    try {
      if (this.pool) {
        await this.pool.end();
      }
    } finally {
      if (this.sshTunnelCleanup) {
        await this.sshTunnelCleanup();
      }
    }
  }

  /** Returns a copy of the endpoint in use (the tunnel's local end when tunneled). */
  getConnectionInfo(): { host: string; port: number; tunneled: boolean } {
    return { ...this.connectionInfo };
  }
}
/**
 * MySQL implementation of `DatabaseConnection` backed by a `mysql2/promise` pool,
 * optionally reached through an SSH tunnel.
 *
 * Setting up the tunnel is asynchronous, so the constructor keeps the in-flight
 * initialization promise in `ready`. Every operation awaits it before touching the pool.
 * This prevents `query()` from running against an unassigned pool.
 */
export class MySQLConnection implements DatabaseConnection {
  private pool?: mysql.Pool;
  private connectionInfo: { host: string; port: number; tunneled: boolean };
  private sshTunnelCleanup?: () => Promise<void>;
  /** Settles when `initializeConnection` finishes; rejects if it failed. */
  private readonly ready: Promise<void>;

  constructor(config: ConnectionConfig) {
    this.connectionInfo = {
      host: config.host,
      port: config.port,
      tunneled: false,
    };
    this.ready = this.initializeConnection(config);
    // Attach a handler now so an init failure is not reported as an unhandled rejection.
    // The same error is re-thrown to callers when they await `ready` in `getPool()`.
    void this.ready.catch(() => undefined);
  }

  /**
   * Opens the SSH tunnel (when enabled) and creates the pool.
   * @throws DatabaseError when the SSH tunnel cannot be created.
   */
  private async initializeConnection(config: ConnectionConfig): Promise<void> {
    let actualHost = config.host;
    let actualPort = config.port;

    // When an SSH tunnel is configured, connect to its local end instead of the remote host.
    if (config.ssh_tunnel?.enabled) {
      try {
        const tunnelId = `mysql-${config.host}-${config.port}-${Date.now()}`;
        const { localPort, cleanup } = await sshTunnelManager.createTunnel(tunnelId, config.ssh_tunnel);
        actualHost = 'localhost';
        actualPort = localPort;
        this.sshTunnelCleanup = cleanup;
        this.connectionInfo = {
          host: actualHost,
          port: actualPort,
          tunneled: true,
        };
        logger.info(`MySQL connection via SSH tunnel: ${config.host}:${config.port} -> localhost:${localPort}`);
      } catch (error) {
        logger.error('Failed to create SSH tunnel for MySQL:', error);
        throw new DatabaseError(`SSH tunnel creation failed: ${error}`);
      }
    }

    const poolConfig: mysql.PoolOptions = {
      host: actualHost,
      port: actualPort,
      user: config.user,
      password: config.password,
      database: config.database,
      connectionLimit: 10,
    };
    if (config.ssl) {
      // NOTE(review): rejectUnauthorized=false accepts any server certificate, so TLS here
      // encrypts traffic but does not authenticate the server.
      poolConfig.ssl = { rejectUnauthorized: false };
    }
    this.pool = mysql.createPool(poolConfig);
  }

  /** Waits for initialization to finish and returns the ready pool. */
  private async getPool(): Promise<mysql.Pool> {
    await this.ready;
    if (!this.pool) {
      throw new DatabaseError('MySQL connection pool is not initialized');
    }
    return this.pool;
  }

  /**
   * Runs `sql` as a prepared statement (`pool.execute`) and returns the rows.
   * @throws DatabaseError when the query fails.
   */
  async query<T = any>(sql: string, params?: any[]): Promise<T[]> {
    const pool = await this.getPool();
    try {
      const [rows] = await pool.execute(sql, params);
      return rows as T[];
    } catch (error) {
      logger.error('MySQL query error:', { sql, params, error });
      throw new DatabaseError(`MySQL query failed: ${error}`, { sql, params });
    }
  }

  /** Runs `sql` and returns the first row, or `null` when no rows are returned. */
  async queryOne<T = any>(sql: string, params?: any[]): Promise<T | null> {
    const results = await this.query<T>(sql, params);
    return results.length > 0 ? results[0] : null;
  }

  /**
   * Ends the pool and then closes the SSH tunnel, if one was opened.
   * The tunnel is closed even if ending the pool fails.
   */
  async close(): Promise<void> {
    try {
      // Do not race an in-flight tunnel/pool setup.
      await this.ready;
    } catch {
      // Initialization failed; still release whatever was created.
    }
    try {
      if (this.pool) {
        await this.pool.end();
      }
    } finally {
      if (this.sshTunnelCleanup) {
        await this.sshTunnelCleanup();
      }
    }
  }

  /** Returns a copy of the endpoint in use (the tunnel's local end when tunneled). */
  getConnectionInfo(): { host: string; port: number; tunneled: boolean } {
    return { ...this.connectionInfo };
  }
}
/**
 * Singleton owning the application's database connections.
 *
 * It holds a mandatory read-only connection and an optional connection used for applying
 * DDL comments. Both are created from `configManager` for the configured engine.
 */
export class ConnectionManager {
  private static instance: ConnectionManager;
  private readOnlyConnection: DatabaseConnection | null = null;
  private ddlCommentConnection: DatabaseConnection | null = null;
  private engine: DatabaseEngine | null = null;

  private constructor() {}

  /** Returns the shared instance, creating it on first use. */
  public static getInstance(): ConnectionManager {
    if (!ConnectionManager.instance) {
      ConnectionManager.instance = new ConnectionManager();
    }
    return ConnectionManager.instance;
  }

  /**
   * Creates and verifies the configured connections.
   * @throws DatabaseError when the read-only configuration is invalid, or when any
   *   configured connection fails its connectivity test.
   */
  public async initialize(): Promise<void> {
    const config = configManager.getConfig();
    this.engine = config.engine;
    logger.info(`Initializing database connections for ${this.engine}`);

    // The read-only connection is mandatory.
    if (!configManager.validateConnection('read_only')) {
      throw new DatabaseError('Invalid read-only connection configuration');
    }
    this.readOnlyConnection = await this.openVerifiedConnection(config.connections.read_only, 'read_only');

    // The DDL comment connection is optional.
    if (configManager.validateConnection('ddl_comment')) {
      this.ddlCommentConnection = await this.openVerifiedConnection(config.connections.ddl_comment, 'ddl_comment');
    } else {
      logger.warn('DDL comment connection not configured - comment application will be disabled');
    }

    logger.info('Database connections initialized successfully');
  }

  /**
   * Creates a connection and runs the connectivity test.
   * If the test fails, the connection is closed before rethrowing, so its pool and any
   * SSH tunnel are not leaked and a broken connection is never stored on the manager.
   */
  private async openVerifiedConnection(config: ConnectionConfig, name: string): Promise<DatabaseConnection> {
    const connection = this.createConnection(config);
    try {
      await this.testConnection(connection, name);
    } catch (error) {
      await connection.close().catch((closeError: unknown) => {
        logger.warn(`Failed to close ${name} connection after failed test: ${closeError}`);
      });
      throw error;
    }
    return connection;
  }

  /** Instantiates the engine-specific connection wrapper. */
  private createConnection(config: ConnectionConfig): DatabaseConnection {
    switch (this.engine) {
      case 'postgres':
        return new PostgreSQLConnection(config);
      case 'mysql':
        return new MySQLConnection(config);
      default:
        throw new DatabaseError(`Unsupported database engine: ${this.engine}`);
    }
  }

  /**
   * Verifies connectivity by running a trivial query.
   * @throws DatabaseError when the query fails or returns an unexpected value.
   */
  private async testConnection(connection: DatabaseConnection, name: string): Promise<void> {
    try {
      // `SELECT 1 as test` is valid on both PostgreSQL and MySQL; no per-engine variant is needed.
      const result = await connection.queryOne<{ test: unknown }>('SELECT 1 as test');
      // Compare numerically in case a driver returns the integer as a string or bigint.
      if (result && Number(result.test) === 1) {
        const info = connection.getConnectionInfo();
        logger.info(`${name} connection test successful (${info.host}:${info.port}${info.tunneled ? ' via SSH tunnel' : ''})`);
      } else {
        throw new Error('Connection test returned unexpected result');
      }
    } catch (error) {
      logger.error(`${name} connection test failed:`, error);
      throw new DatabaseError(`Failed to connect to database (${name}): ${error}`);
    }
  }

  /** @throws DatabaseError when `initialize()` has not created the read-only connection. */
  public getReadOnlyConnection(): DatabaseConnection {
    if (!this.readOnlyConnection) {
      throw new DatabaseError('Read-only connection not initialized');
    }
    return this.readOnlyConnection;
  }

  /** @throws DatabaseError when the DDL comment connection is not configured or not initialized. */
  public getDDLCommentConnection(): DatabaseConnection {
    if (!this.ddlCommentConnection) {
      throw new DatabaseError('DDL comment connection not initialized');
    }
    return this.ddlCommentConnection;
  }

  /** @throws DatabaseError when `initialize()` has not been called. */
  public getEngine(): DatabaseEngine {
    if (!this.engine) {
      throw new DatabaseError('Database engine not initialized');
    }
    return this.engine;
  }

  /**
   * Closes every open connection and all SSH tunnels, then resets the manager's state.
   * Each close is attempted even if another one fails. After cleanup, the first close
   * failure (if any) is rethrown.
   */
  public async close(): Promise<void> {
    const open = [this.readOnlyConnection, this.ddlCommentConnection].filter(
      (conn): conn is DatabaseConnection => conn !== null,
    );
    const failures: unknown[] = [];
    await Promise.all(
      open.map((conn) =>
        conn.close().catch((error: unknown) => {
          failures.push(error);
        }),
      ),
    );

    try {
      // Also close any tunnels not owned by a connection above.
      await sshTunnelManager.closeAllTunnels();
    } finally {
      this.readOnlyConnection = null;
      this.ddlCommentConnection = null;
      this.engine = null;
    }

    if (failures.length > 0) {
      throw failures[0];
    }
    logger.info('Database connections closed');
  }
}
/** Module-level handle to the singleton returned by `ConnectionManager.getInstance()`. */
export const connectionManager: ConnectionManager = ConnectionManager.getInstance();