sql-talk
Version:
SQL Talk - 自然言語をSQLに変換するMCPサーバー(安全性保護・SSHトンネル対応) / SQL Talk - MCP Server for Natural Language to SQL conversion with safety guards and SSH tunnel support
229 lines (196 loc) • 7.17 kB
text/typescript
import { readFileSync } from 'fs';
import { createServer, Server, Socket } from 'net';
import { Client, ConnectConfig } from 'ssh2';
import { SshTunnelConfig } from '@/types/index.js';
import { logger } from '@/core/logger.js';
import { DatabaseError } from '@/core/errors.js';
/** Resources owned by one registered tunnel. */
interface TunnelEntry {
  client: Client;
  server: Server;
  localPort: number;
  /** Local connections currently accepted by `server`; destroyed when the tunnel closes. */
  sockets: Set<Socket>;
}

/**
 * Process-wide registry of SSH local port-forwarding tunnels, keyed by a caller-chosen id.
 *
 * Each tunnel holds one SSH connection plus a TCP server on 127.0.0.1:<localPort>; every
 * connection accepted by that server is forwarded to `remote_host:remote_port` over SSH.
 * Obtain the instance through `SshTunnelManager.getInstance()`.
 */
export class SshTunnelManager {
  /** How long createTunnel waits for the SSH handshake and local listener (ms). */
  private static readonly CONNECT_TIMEOUT_MS = 10000;

  private static instance: SshTunnelManager;

  private tunnels: Map<string, TunnelEntry> = new Map();

  /** Next port handed out when a config does not specify `local_port`. */
  private nextLocalPort = 10000;

  private constructor() {}

  /** Returns the shared manager, creating it on first use. */
  public static getInstance(): SshTunnelManager {
    if (!SshTunnelManager.instance) {
      SshTunnelManager.instance = new SshTunnelManager();
    }
    return SshTunnelManager.instance;
  }

  /**
   * Creates an SSH tunnel (or reuses the one already registered under `tunnelId`) and
   * returns the local port to connect to.
   *
   * @param tunnelId - Registry key; a second call with the same id reuses the live tunnel.
   * @param config - SSH host, credentials and forwarding target.
   * @returns The local port and a `cleanup` callback that closes the tunnel. Note that the
   *   callback closes the tunnel for every caller sharing `tunnelId`.
   * @throws DatabaseError if the tunnel is disabled, no usable credentials are configured,
   *   the SSH connection or local listener fails, or setup exceeds the connect timeout.
   */
  public async createTunnel(
    tunnelId: string,
    config: SshTunnelConfig
  ): Promise<{ localPort: number; cleanup: () => Promise<void> }> {
    if (!config.enabled) {
      throw new DatabaseError('SSH tunnel is not enabled');
    }

    // Reuse an existing tunnel if one is registered under this id.
    const existing = this.tunnels.get(tunnelId);
    if (existing) {
      logger.info(`Reusing existing SSH tunnel: ${tunnelId} (port: ${existing.localPort})`);
      return {
        localPort: existing.localPort,
        cleanup: () => this.closeTunnel(tunnelId)
      };
    }

    const client = new Client();
    const localPort = config.local_port || this.getNextLocalPort();
    // Throws DatabaseError (rejecting this call) when credentials are missing/unreadable.
    const connectConfig = this.buildConnectConfig(config);

    return new Promise((resolve, reject) => {
      const sockets = new Set<Socket>();
      let server: Server | undefined;
      let settled = false;

      // Rejects at most once and releases whatever was created before the failure.
      const fail = (error: Error): void => {
        if (settled) return;
        settled = true;
        clearTimeout(timer);
        for (const socket of sockets) socket.destroy();
        server?.close();
        client.end();
        reject(error);
      };

      const timer = setTimeout(() => {
        fail(new DatabaseError(`SSH tunnel connection timeout for ${tunnelId}`));
      }, SshTunnelManager.CONNECT_TIMEOUT_MS);

      // After setup: unregister and release this tunnel, but only if it is still the one
      // registered under tunnelId (it may already have been closed or replaced).
      const dropIfRegistered = (reason: string): void => {
        const current = this.tunnels.get(tunnelId);
        if (!current || current.client !== client) return;
        this.tunnels.delete(tunnelId);
        logger.warn(`SSH tunnel ${tunnelId} removed: ${reason}`);
        void this.releaseTunnel(current);
      };

      client.on('error', (err) => {
        logger.error(`SSH tunnel error for ${tunnelId}:`, err);
        if (settled) {
          dropIfRegistered(err.message);
        } else {
          fail(new DatabaseError(`SSH tunnel connection failed: ${err.message}`));
        }
      });

      // Without this, a dropped SSH connection would leave a dead tunnel registered for reuse.
      client.on('close', () => {
        if (settled) {
          dropIfRegistered('SSH connection closed');
        } else {
          fail(new DatabaseError(`SSH connection closed before tunnel was ready: ${tunnelId}`));
        }
      });

      client.on('ready', () => {
        if (settled) return;
        logger.info(`SSH connection established for tunnel: ${tunnelId}`);

        // Local listener: each accepted connection is forwarded through the SSH client.
        const localServer = createServer((localConnection) => {
          this.forwardConnection(tunnelId, client, config, localConnection, sockets);
        });
        server = localServer;

        localServer.on('error', (err) => {
          logger.error(`Local server error for ${tunnelId}:`, err);
          if (settled) {
            dropIfRegistered(`local server error: ${err.message}`);
          } else {
            fail(new DatabaseError(`Local server error: ${err.message}`));
          }
        });

        localServer.listen(localPort, '127.0.0.1', () => {
          if (settled) return;
          settled = true;
          clearTimeout(timer);
          logger.info(`SSH tunnel established: localhost:${localPort} -> ${config.remote_host}:${config.remote_port} (via ${config.host}:${config.port})`);
          this.tunnels.set(tunnelId, { client, server: localServer, localPort, sockets });
          resolve({
            localPort,
            cleanup: () => this.closeTunnel(tunnelId)
          });
        });
      });

      try {
        client.connect(connectConfig);
      } catch (error) {
        // ssh2 may throw synchronously on invalid options; release the timer and client too.
        fail(error instanceof Error ? error : new Error(String(error)));
      }
    });
  }

  /**
   * Closes the tunnel registered under `tunnelId`: stops the local listener, destroys its
   * open connections and ends the SSH session. Logs a warning if no such tunnel exists.
   *
   * @throws DatabaseError if releasing the tunnel throws.
   */
  public async closeTunnel(tunnelId: string): Promise<void> {
    const tunnel = this.tunnels.get(tunnelId);
    if (!tunnel) {
      logger.warn(`SSH tunnel not found: ${tunnelId}`);
      return;
    }
    // Unregister first so concurrent callers neither reuse nor double-close this tunnel.
    this.tunnels.delete(tunnelId);
    try {
      await this.releaseTunnel(tunnel);
      logger.info(`SSH tunnel closed: ${tunnelId} (port: ${tunnel.localPort})`);
    } catch (error) {
      logger.error(`Error closing SSH tunnel ${tunnelId}:`, error);
      throw new DatabaseError(`Failed to close SSH tunnel: ${error}`);
    }
  }

  /** Closes every registered tunnel concurrently. */
  public async closeAllTunnels(): Promise<void> {
    // Snapshot the ids: closeTunnel removes entries from the map while we iterate.
    const tunnelIds = [...this.tunnels.keys()];
    await Promise.all(tunnelIds.map((tunnelId) => this.closeTunnel(tunnelId)));
    logger.info('All SSH tunnels closed');
  }

  /** Reports whether a tunnel is registered under `tunnelId` and, if so, its local port. */
  public getTunnelStatus(tunnelId: string): { exists: boolean; localPort?: number } {
    const tunnel = this.tunnels.get(tunnelId);
    return {
      exists: !!tunnel,
      localPort: tunnel?.localPort
    };
  }

  /** Number of currently registered tunnels. */
  public getActiveTunnelCount(): number {
    return this.tunnels.size;
  }

  /**
   * Hands out sequential local ports starting at 10000. Availability is not checked; a port
   * already in use surfaces as a "Local server error" from createTunnel.
   */
  private getNextLocalPort(): number {
    return this.nextLocalPort++;
  }

  /**
   * Builds ssh2 connection options from the tunnel config.
   * Authentication precedence: password, then private_key_path, then inline private_key.
   *
   * @throws DatabaseError if the key file cannot be read or no auth method is configured.
   */
  private buildConnectConfig(config: SshTunnelConfig): ConnectConfig {
    const connectConfig: ConnectConfig = {
      host: config.host,
      port: config.port,
      username: config.username,
      // With keep_alive: send a keepalive every 20s and disconnect after 3 unanswered ones.
      keepaliveInterval: config.keep_alive ? 20000 : 0,
      keepaliveCountMax: config.keep_alive ? 3 : 0,
    };

    if (config.password) {
      connectConfig.password = config.password;
      return connectConfig;
    }

    if (config.private_key_path) {
      try {
        connectConfig.privateKey = readFileSync(config.private_key_path, 'utf8');
      } catch {
        throw new DatabaseError(`Failed to read private key file: ${config.private_key_path}`);
      }
    } else if (config.private_key) {
      connectConfig.privateKey = config.private_key;
    } else {
      throw new DatabaseError('SSH authentication method not specified (password, private_key, or private_key_path required)');
    }

    if (config.passphrase) {
      connectConfig.passphrase = config.passphrase;
    }
    return connectConfig;
  }

  /**
   * Forwards one accepted local connection to `remote_host:remote_port` over `client`,
   * piping data in both directions until either side closes.
   */
  private forwardConnection(
    tunnelId: string,
    client: Client,
    config: SshTunnelConfig,
    localConnection: Socket,
    sockets: Set<Socket>
  ): void {
    sockets.add(localConnection);
    localConnection.on('close', () => {
      sockets.delete(localConnection);
    });
    // Attached immediately: an 'error' with no listener (e.g. ECONNRESET while forwardOut
    // is still pending) would otherwise crash the process. 'close' follows 'error'.
    localConnection.on('error', (err) => {
      logger.debug(`Local connection error for ${tunnelId}:`, err);
    });

    client.forwardOut(
      localConnection.remoteAddress || '127.0.0.1',
      localConnection.remotePort || 0,
      config.remote_host,
      config.remote_port,
      (err, stream) => {
        if (err) {
          logger.error(`SSH forwardOut error for ${tunnelId}:`, err);
          localConnection.destroy();
          return;
        }
        // The local side may have gone away while the channel was being opened.
        if (localConnection.destroyed) {
          stream.end();
          return;
        }

        // Relay data in both directions.
        localConnection.pipe(stream);
        stream.pipe(localConnection);
        localConnection.on('close', () => {
          stream.end();
        });
        stream.on('close', () => {
          localConnection.end();
        });
        stream.on('error', (streamErr: Error) => {
          logger.debug(`Stream error for ${tunnelId}:`, streamErr);
          localConnection.end();
        });
      }
    );
  }

  /**
   * Stops the local listener, destroys its accepted connections and ends the SSH session.
   * Resolves once the listener has closed; never rejects.
   */
  private releaseTunnel(tunnel: TunnelEntry): Promise<void> {
    return new Promise<void>((resolve) => {
      // server.close() only calls back once every accepted connection has ended, so the
      // connections are destroyed explicitly rather than waiting on the remote peers.
      tunnel.server.close(() => resolve());
      for (const socket of tunnel.sockets) socket.destroy();
      tunnel.sockets.clear();
      tunnel.client.end();
    });
  }
}
/** Shared process-wide SSH tunnel manager (the `SshTunnelManager` singleton). */
const sharedTunnelManager = SshTunnelManager.getInstance();
export { sharedTunnelManager as sshTunnelManager };