UNPKG

sql-talk

Version:

SQL Talk - 自然言語をSQLに変換するMCPサーバー(安全性保護・SSHトンネル対応) / SQL Talk - MCP Server for Natural Language to SQL conversion with safety guards and SSH tunnel support

229 lines (196 loc) 7.17 kB
import { Client, ConnectConfig } from 'ssh2';
import { SshTunnelConfig } from '@/types/index.js';
import { logger } from '@/core/logger.js';
import { DatabaseError } from '@/core/errors.js';
import { readFileSync } from 'fs';
import { createServer, Server, Socket } from 'net';
import { Duplex } from 'stream';

/** How long to wait for the SSH handshake and the local listener before giving up. */
const CONNECT_TIMEOUT_MS = 10000;

/** Value handed back to callers of `createTunnel`. */
type TunnelHandle = { localPort: number; cleanup: () => Promise<void> };

/** Bookkeeping for one established tunnel. */
interface ActiveTunnel {
  client: Client;
  server: Server;
  localPort: number;
  /** Local sockets currently being forwarded; destroyed on teardown so `server.close()` can finish. */
  sockets: Set<Socket>;
}

/**
 * Singleton that owns SSH port-forwarding tunnels, keyed by a caller-chosen tunnel id.
 *
 * Each tunnel is a local TCP listener on 127.0.0.1 whose connections are forwarded
 * through an SSH session to `remote_host:remote_port`.
 *
 * Tunnels are shared: every caller that asks for the same id gets the same tunnel,
 * and any caller's `cleanup()` closes it for all of them.
 */
export class SshTunnelManager {
  private static instance: SshTunnelManager;
  private tunnels: Map<string, ActiveTunnel> = new Map();
  /** Tunnels that are still connecting, so concurrent requests for one id share a single attempt. */
  private pending: Map<string, Promise<TunnelHandle>> = new Map();
  private nextLocalPort = 10000; // first auto-assigned local port

  private constructor() {}

  /** Returns the process-wide manager, creating it on first use. */
  public static getInstance(): SshTunnelManager {
    if (!SshTunnelManager.instance) {
      SshTunnelManager.instance = new SshTunnelManager();
    }
    return SshTunnelManager.instance;
  }

  /**
   * Creates an SSH tunnel (or reuses the one registered under `tunnelId`) and returns its local port.
   *
   * @param tunnelId - Key under which the tunnel is registered and later looked up.
   * @param config - SSH endpoint, credentials and forwarding target.
   * @returns The local port to connect to and a `cleanup` function that closes the tunnel.
   * @throws DatabaseError if tunnelling is disabled, no authentication method is configured,
   *   the private key file cannot be read, the SSH connection or local listener fails,
   *   or nothing is established within the connect timeout.
   */
  public async createTunnel(
    tunnelId: string,
    config: SshTunnelConfig
  ): Promise<{ localPort: number; cleanup: () => Promise<void> }> {
    if (!config.enabled) {
      throw new DatabaseError('SSH tunnel is not enabled');
    }

    // Reuse an already established tunnel.
    const existing = this.tunnels.get(tunnelId);
    if (existing) {
      logger.info(`Reusing existing SSH tunnel: ${tunnelId} (port: ${existing.localPort})`);
      return {
        localPort: existing.localPort,
        cleanup: () => this.closeTunnel(tunnelId),
      };
    }

    // Join an attempt that is already in flight instead of opening a second tunnel
    // that would overwrite (and orphan) the first one in the registry.
    const inFlight = this.pending.get(tunnelId);
    if (inFlight) {
      return inFlight;
    }

    const connectConfig = this.buildConnectConfig(config);
    const attempt = this.openTunnel(tunnelId, config, connectConfig);
    this.pending.set(tunnelId, attempt);
    try {
      return await attempt;
    } finally {
      this.pending.delete(tunnelId);
    }
  }

  /**
   * Closes the tunnel registered under `tunnelId`; logs a warning and returns if there is none.
   *
   * @throws DatabaseError if tearing the tunnel down throws.
   */
  public async closeTunnel(tunnelId: string): Promise<void> {
    const tunnel = this.tunnels.get(tunnelId);
    if (!tunnel) {
      logger.warn(`SSH tunnel not found: ${tunnelId}`);
      return;
    }

    // Unregister first so a concurrent createTunnel cannot reuse a tunnel that is going away.
    this.tunnels.delete(tunnelId);

    try {
      // server.close() only calls back once every connection has ended, so drop the
      // forwarded sockets first; otherwise an open DB connection would block this forever.
      for (const socket of tunnel.sockets) {
        socket.destroy();
      }
      tunnel.sockets.clear();

      await new Promise<void>((resolve) => {
        tunnel.server.close(() => {
          resolve();
        });
      });

      // Close the SSH connection.
      tunnel.client.end();

      logger.info(`SSH tunnel closed: ${tunnelId} (port: ${tunnel.localPort})`);
    } catch (error) {
      logger.error(`Error closing SSH tunnel ${tunnelId}:`, error);
      throw new DatabaseError(`Failed to close SSH tunnel: ${error}`);
    }
  }

  /** Closes every registered tunnel. */
  public async closeAllTunnels(): Promise<void> {
    // Snapshot the ids: closeTunnel removes entries from the map while we iterate.
    const tunnelIds = [...this.tunnels.keys()];
    await Promise.all(tunnelIds.map((tunnelId) => this.closeTunnel(tunnelId)));
    logger.info('All SSH tunnels closed');
  }

  /** Reports whether a tunnel is registered under `tunnelId` and, if so, its local port. */
  public getTunnelStatus(tunnelId: string): { exists: boolean; localPort?: number } {
    const tunnel = this.tunnels.get(tunnelId);
    return {
      exists: !!tunnel,
      localPort: tunnel?.localPort,
    };
  }

  /** Returns the number of registered tunnels. */
  public getActiveTunnelCount(): number {
    return this.tunnels.size;
  }

  /** Hands out the next auto-assigned local port. */
  private getNextLocalPort(): number {
    return this.nextLocalPort++;
  }

  /** Translates the tunnel config into ssh2 connect options, picking the authentication method. */
  private buildConnectConfig(config: SshTunnelConfig): ConnectConfig {
    const connectConfig: ConnectConfig = {
      host: config.host,
      port: config.port,
      username: config.username,
      keepaliveInterval: config.keep_alive ? 20000 : 0,
      keepaliveCountMax: config.keep_alive ? 3 : 0,
    };

    // Authentication precedence: password, then key file, then inline key.
    if (config.password) {
      connectConfig.password = config.password;
    } else if (config.private_key_path) {
      try {
        connectConfig.privateKey = readFileSync(config.private_key_path, 'utf8');
      } catch {
        throw new DatabaseError(`Failed to read private key file: ${config.private_key_path}`);
      }
      if (config.passphrase) {
        connectConfig.passphrase = config.passphrase;
      }
    } else if (config.private_key) {
      connectConfig.privateKey = config.private_key;
      if (config.passphrase) {
        connectConfig.passphrase = config.passphrase;
      }
    } else {
      throw new DatabaseError(
        'SSH authentication method not specified (password, private_key, or private_key_path required)'
      );
    }

    return connectConfig;
  }

  /** Opens the SSH session and the local listener, registering the tunnel once both are up. */
  private openTunnel(
    tunnelId: string,
    config: SshTunnelConfig,
    connectConfig: ConnectConfig
  ): Promise<TunnelHandle> {
    const client = new Client();
    // `||` is deliberate: a configured port of 0 falls back to an auto-assigned port,
    // because the port recorded here must be the one clients actually connect to.
    const localPort = config.local_port || this.getNextLocalPort();
    const sockets = new Set<Socket>();

    return new Promise<TunnelHandle>((resolve, reject) => {
      let settled = false;
      let server: Server | undefined;

      // Frees everything this attempt owns; safe to call more than once.
      const releaseResources = (): void => {
        for (const socket of sockets) {
          socket.destroy();
        }
        sockets.clear();
        if (server?.listening) {
          server.close();
        }
        client.end();
      };

      // Rejects the attempt exactly once and releases its resources.
      const fail = (error: DatabaseError): void => {
        if (settled) {
          return;
        }
        settled = true;
        clearTimeout(timer);
        releaseResources();
        reject(error);
      };

      // Removes this tunnel from the registry if (and only if) the registered entry is ours.
      const dropEstablished = (): boolean => {
        const entry = this.tunnels.get(tunnelId);
        if (entry?.client !== client) {
          return false;
        }
        this.tunnels.delete(tunnelId);
        releaseResources();
        return true;
      };

      // Connection timeout; cleared as soon as the attempt settles.
      const timer = setTimeout(() => {
        fail(new DatabaseError(`SSH tunnel connection timeout for ${tunnelId}`));
      }, CONNECT_TIMEOUT_MS);

      client.on('error', (err) => {
        logger.error(`SSH tunnel error for ${tunnelId}:`, err);
        dropEstablished();
        fail(new DatabaseError(`SSH tunnel connection failed: ${err.message}`));
      });

      // A session that ends on its own must not leave a dead tunnel behind for reuse.
      client.on('close', () => {
        if (dropEstablished()) {
          logger.warn(`SSH connection closed, tunnel removed: ${tunnelId}`);
        }
        fail(new DatabaseError(`SSH connection closed before tunnel ${tunnelId} was established`));
      });

      client.on('ready', () => {
        if (settled) {
          return;
        }
        logger.info(`SSH connection established for tunnel: ${tunnelId}`);

        // Local listener: every incoming connection is forwarded over SSH to the remote target.
        const listener = createServer((localConnection) => {
          let upstream: Duplex | undefined;
          sockets.add(localConnection);

          // Attached before forwardOut completes: a socket error without a listener
          // would otherwise be thrown as an unhandled 'error' event.
          localConnection.on('error', (err: Error) => {
            logger.debug(`Local connection error for ${tunnelId}:`, err);
            upstream?.end();
          });
          localConnection.on('close', () => {
            sockets.delete(localConnection);
            upstream?.end();
          });

          const onChannel = (err: Error | undefined, stream: Duplex): void => {
            if (err) {
              logger.error(`SSH forwardOut error for ${tunnelId}:`, err);
              localConnection.destroy();
              return;
            }
            if (localConnection.destroyed) {
              // The local side went away while the channel was opening.
              stream.end();
              return;
            }
            upstream = stream;

            // Relay data in both directions.
            localConnection.pipe(stream);
            stream.pipe(localConnection);

            stream.on('close', () => {
              localConnection.end();
            });
            stream.on('error', (streamErr: Error) => {
              logger.debug(`Stream error for ${tunnelId}:`, streamErr);
              localConnection.end();
            });
          };

          try {
            client.forwardOut(
              localConnection.remoteAddress || '127.0.0.1',
              localConnection.remotePort || 0,
              config.remote_host,
              config.remote_port,
              onChannel
            );
          } catch (err) {
            // forwardOut throws synchronously when the SSH session is no longer connected.
            logger.error(`SSH forwardOut error for ${tunnelId}:`, err);
            localConnection.destroy();
          }
        });
        server = listener;

        listener.on('error', (err) => {
          logger.error(`Local server error for ${tunnelId}:`, err);
          dropEstablished();
          fail(new DatabaseError(`Local server error: ${err.message}`));
        });

        listener.listen(localPort, '127.0.0.1', () => {
          if (settled) {
            // The attempt failed (e.g. timed out) while the listener was starting.
            listener.close();
            return;
          }
          settled = true;
          clearTimeout(timer);

          logger.info(
            `SSH tunnel established: localhost:${localPort} -> ${config.remote_host}:${config.remote_port} (via ${config.host}:${config.port})`
          );

          // Register the tunnel.
          this.tunnels.set(tunnelId, { client, server: listener, localPort, sockets });

          resolve({
            localPort,
            cleanup: () => this.closeTunnel(tunnelId),
          });
        });
      });

      // Start the SSH connection.
      client.connect(connectConfig);
    });
  }
}

export const sshTunnelManager = SshTunnelManager.getInstance();