postgrejs
Version:
Professional PostgreSQL client for Node.js
363 lines (362 loc) • 14.5 kB
JavaScript
import crypto from 'node:crypto';
import net from 'node:net';
import path from 'node:path';
import tls from 'node:tls';
import promisify from 'putil-promisify';
import { ConnectionState } from '../constants.js';
import { SafeEventEmitter } from '../safe-event-emitter.js';
import { Backend } from './backend.js';
import { DatabaseError } from './database-error.js';
import { Frontend } from './frontend.js';
import { Protocol } from './protocol.js';
import { SASL } from './sasl.js';
// Port used by connect() when `options.port` is unset (or otherwise falsy).
const DEFAULT_PORT_NUMBER = 5432;
// Splits a CommandComplete tag into a command name (no digits) optionally followed by
// one or two space-separated integers, e.g. "SELECT 5" or "INSERT 0 1".
// Group 1 is the command; groups 2 and 3 are the trailing numbers (see _handleCommandComplete).
const COMMAND_RESULT_PATTERN = /^([^\d]+)(?: (\d+)(?: (\d+))?)?$/;
/**
 * Low-level socket that speaks the PostgreSQL frontend/backend protocol.
 *
 * It owns the underlying `net.Socket` (optionally upgraded to TLS), serializes
 * outgoing messages through `Frontend`, parses incoming data through `Backend`,
 * and re-emits what it receives as events.
 *
 * Events emitted by this class:
 *  - 'connecting'   when connect() starts
 *  - 'authenticate' when an Authentication message arrives without a payload
 *  - 'ready'        on the first ReadyForQuery received while not yet READY
 *  - 'message'      (code, payload) for every backend message not handled internally
 *  - 'notice' / 'notification' for NoticeResponse / NotificationResponse
 *  - 'error'        for socket errors, handler exceptions and ErrorResponse messages
 *  - 'close'        when the underlying socket has closed
 *  - 'debug'        ({ location, ... }) before each send*Message call, only if listened to
 */
export class PgSocket extends SafeEventEmitter {
  /**
   * @param {object} options Connection options. Fields read by this class:
   *   `host`, `port`, `ssl`, `requireSSL`, `keepAlive`, `connectTimeoutMs`,
   *   `user`, `password` (a value or a function returning one, possibly async),
   *   `database`, `applicationName` and `buffer` (forwarded to `Frontend`).
   */
  constructor(options) {
    super();
    this.options = options;
    this._state = ConnectionState.CLOSED;
    this._backend = new Backend();
    this._sessionParameters = {};
    this._frontend = new Frontend({ buffer: options.buffer });
    this.setMaxListeners(99);
  }

  /**
   * Current connection state. Reading it also normalizes the stored state to
   * CLOSED when there is no socket or the socket has been destroyed.
   */
  get state() {
    if (!this._socket || this._socket.destroyed) {
      this._state = ConnectionState.CLOSED;
    }
    return this._state;
  }

  /** Process id received in the last BackendKeyData message, if any. */
  get processID() {
    return this._processID;
  }

  /** Secret key received in the last BackendKeyData message, if any. */
  get secretKey() {
    return this._secretKey;
  }

  /** Name/value map built from the ParameterStatus messages received so far. */
  get sessionParameters() {
    return this._sessionParameters;
  }

  /**
   * Opens the TCP (or unix-domain) connection, sends an SSLRequest, and upgrades
   * to TLS when the server answers 'S'. Does nothing if a socket already exists.
   * Failures during this phase are reported through the 'error' event.
   */
  connect() {
    if (this._socket) {
      return;
    }
    this._state = ConnectionState.CONNECTING;
    const options = this.options;
    const socket = (this._socket = new net.Socket());

    // Shared failure path for everything that can go wrong before the startup message.
    const errorHandler = (err) => {
      this._state = ConnectionState.CLOSED;
      this._removeListeners();
      this._reset();
      socket.destroy();
      this._socket = undefined;
      this.emit('error', err);
    };

    const connectHandler = () => {
      // The timeout only guards connection establishment; disable it once connected.
      socket.setTimeout(0);
      // Keep-alive is on unless explicitly set to a falsy, non-nullish value.
      if (this.options.keepAlive ?? true) {
        socket.setKeepAlive(true);
      }
      socket.write(this._frontend.getSSLRequestMessage());
      socket.once('data', (reply) => {
        this._removeListeners();
        const answer = reply.toString();
        if (answer === 'S') {
          // Server accepts SSL: wrap the existing socket in a TLS session.
          const tlsOptions = { ...options.ssl, socket };
          // Only host names (not IP literals) are passed as the TLS servername.
          if (options.host && net.isIP(options.host) === 0) {
            tlsOptions.servername = options.host;
          }
          const tlsSocket = (this._socket = tls.connect(tlsOptions));
          tlsSocket.once('error', errorHandler);
          tlsSocket.once('secureConnect', () => {
            this._removeListeners();
            this._handleConnect();
          });
          return;
        }
        if (answer === 'N') {
          if (options.requireSSL) {
            return errorHandler(new Error('Server does not support SSL connections'));
          }
          // Listeners were already cleared at the top of this handler.
          this._handleConnect();
          return;
        }
        return errorHandler(new Error('There was an error establishing an SSL connection'));
      });
    };

    socket.setNoDelay(true);
    socket.setTimeout(options.connectTimeoutMs || 30000, () =>
      errorHandler(new Error('Connection timed out')),
    );
    socket.once('error', errorHandler);
    socket.once('connect', connectHandler);
    this.emit('connecting');

    const port = options.port || DEFAULT_PORT_NUMBER;
    if (options.host && options.host.startsWith('/')) {
      // A host starting with '/' is treated as a unix-domain socket directory.
      socket.connect(path.join(options.host, '/.s.PGSQL.' + port));
    } else {
      socket.connect(port, options.host || 'localhost');
    }
  }

  /**
   * Destroys the underlying socket. 'close' is emitted once the socket reports
   * it has closed. If there is no live socket the internal state is reset
   * silently, and repeated calls while CLOSING are ignored.
   */
  close() {
    if (!this._socket || this._socket.destroyed) {
      this._state = ConnectionState.CLOSED;
      this._socket = undefined;
      this._reset();
      return;
    }
    if (this._state === ConnectionState.CLOSING) {
      return;
    }
    const socket = this._socket;
    this._state = ConnectionState.CLOSING;
    this._removeListeners();
    socket.once('close', () => this._handleClose());
    socket.destroy();
  }

  /**
   * Sends a Parse message.
   * @param {object} args Arguments forwarded to `Frontend.getParseMessage`.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendParseMessage(args, cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendParseMessage', args });
    }
    this._send(this._frontend.getParseMessage(args), cb);
  }

  /**
   * Sends a Bind message.
   * @param {object} args Arguments forwarded to `Frontend.getBindMessage`.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendBindMessage(args, cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendBindMessage', args });
    }
    this._send(this._frontend.getBindMessage(args), cb);
  }

  /**
   * Sends a Describe message.
   * @param {object} args Arguments forwarded to `Frontend.getDescribeMessage`.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendDescribeMessage(args, cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendDescribeMessage', args });
    }
    this._send(this._frontend.getDescribeMessage(args), cb);
  }

  /**
   * Sends an Execute message.
   * @param {object} args Arguments forwarded to `Frontend.getExecuteMessage`.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendExecuteMessage(args, cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendExecuteMessage', args });
    }
    this._send(this._frontend.getExecuteMessage(args), cb);
  }

  /**
   * Sends a Close message.
   * @param {object} args Arguments forwarded to `Frontend.getCloseMessage`.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendCloseMessage(args, cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendCloseMessage', args });
    }
    this._send(this._frontend.getCloseMessage(args), cb);
  }

  /**
   * Sends a Query message.
   * @param {string} sql Value forwarded to `Frontend.getQueryMessage`.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendQueryMessage(sql, cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendQueryMessage', sql });
    }
    this._send(this._frontend.getQueryMessage(sql), cb);
  }

  /**
   * Sends a Flush message.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendFlushMessage(cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendFlushMessage' });
    }
    this._send(this._frontend.getFlushMessage(), cb);
  }

  /**
   * Sends a Terminate message.
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  sendTerminateMessage(cb) {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendTerminateMessage' });
    }
    this._send(this._frontend.getTerminateMessage(), cb);
  }

  /** Sends a Sync message. */
  sendSyncMessage() {
    if (this.listenerCount('debug')) {
      this.emit('debug', { location: 'PgSocket.sendSyncMessage' });
    }
    this._send(this._frontend.getSyncMessage());
  }

  /**
   * Subscribes `callback` to 'message' events until it signals completion.
   *
   * @param {Function} callback Called as `callback(code, msg, done)` for every
   *   'message' event. Calling `done(err, result)` unsubscribes and settles the
   *   returned promise. If the callback returns a promise that rejects, `done`
   *   is called with that rejection.
   * @returns {Promise<*>} Resolves with the `result` given to `done`. Rejects
   *   immediately when the socket is CLOSING/CLOSED ('Connection closed') or not
   *   READY ('Connection is not ready'), and later on an 'error' or 'close' event.
   */
  capture(callback) {
    if (this._state === ConnectionState.CLOSING || this._state === ConnectionState.CLOSED) {
      return Promise.reject(new Error('Connection closed'));
    }
    if (this._state !== ConnectionState.READY) {
      return Promise.reject(new Error('Connection is not ready'));
    }
    return new Promise((resolve, reject) => {
      const done = (err, result) => {
        this.removeListener('close', closeHandler);
        this.removeListener('error', errorHandler);
        this.removeListener('message', msgHandler);
        if (err) {
          reject(err);
        } else {
          resolve(result);
        }
      };
      // 'error' and 'close' are registered with once(), so each handler only
      // needs to detach the other two listeners.
      const errorHandler = (err) => {
        this.removeListener('close', closeHandler);
        this.removeListener('message', msgHandler);
        reject(err);
      };
      const closeHandler = () => {
        this.removeListener('error', errorHandler);
        this.removeListener('message', msgHandler);
        reject(new Error('Connection closed'));
      };
      const msgHandler = (code, msg) => {
        const outcome = callback(code, msg, done);
        if (promisify.isPromise(outcome)) {
          outcome.catch((err) => done(err));
        }
      };
      this.once('close', closeHandler);
      this.once('error', errorHandler);
      this.on('message', msgHandler);
    });
  }

  /** Removes every 'error', 'connect', 'data' and 'close' listener from the current socket. */
  _removeListeners() {
    if (!this._socket) {
      return;
    }
    for (const eventName of ['error', 'connect', 'data', 'close']) {
      this._socket.removeAllListeners(eventName);
    }
  }

  /** Clears parser state and everything learned from the server during a session. */
  _reset() {
    this._backend.reset();
    this._sessionParameters = {};
    this._processID = undefined;
    this._secretKey = undefined;
    this._saslSession = undefined;
  }

  /**
   * Called once the transport (plain or TLS) is established: attaches the
   * long-lived socket listeners and sends the startup message.
   */
  _handleConnect() {
    const socket = this._socket;
    if (!socket) {
      return;
    }
    this._state = ConnectionState.AUTHORIZING;
    this._reset();
    socket.on('data', (data) => this._handleData(data));
    socket.on('error', (err) => this._handleError(err));
    socket.on('close', () => this._handleClose());
    this._send(
      this._frontend.getStartupMessage({
        user: this.options.user || 'postgres',
        database: this.options.database || '',
        application_name: this.options.applicationName || '',
      }),
    );
  }

  /** Socket 'close' handler: resets state, drops the socket and emits 'close'. */
  _handleClose() {
    this._reset();
    this._socket = undefined;
    this._state = ConnectionState.CLOSED;
    this.emit('close');
  }

  /**
   * Emits 'error'. If the connection has not reached READY yet, the socket is
   * ended first.
   * @param {Error} err
   */
  _handleError(err) {
    if (this._state !== ConnectionState.READY) {
      this._socket?.end();
    }
    this.emit('error', err);
  }

  /**
   * Feeds raw socket data to the backend parser and dispatches each parsed
   * message. Exceptions thrown while handling a message are routed to
   * `_handleError`.
   * @param {Buffer} data
   */
  _handleData(data) {
    this._backend.parse(data, (code, payload) => {
      try {
        switch (code) {
          case Protocol.BackendMessageCode.Authentication:
            this._handleAuthenticationMessage(payload);
            break;
          case Protocol.BackendMessageCode.ErrorResponse:
            this.emit('error', new DatabaseError(payload));
            break;
          case Protocol.BackendMessageCode.NoticeResponse:
            this.emit('notice', payload);
            break;
          case Protocol.BackendMessageCode.NotificationResponse:
            this.emit('notification', payload);
            break;
          case Protocol.BackendMessageCode.ParameterStatus:
            this._handleParameterStatus(payload);
            break;
          case Protocol.BackendMessageCode.BackendKeyData:
            this._handleBackendKeyData(payload);
            break;
          case Protocol.BackendMessageCode.ReadyForQuery:
            // The first ReadyForQuery marks the end of startup; later ones are
            // forwarded to listeners like any other message.
            if (this._state !== ConnectionState.READY) {
              this._state = ConnectionState.READY;
              this.emit('ready');
            } else {
              this.emit('message', code, payload);
            }
            break;
          case Protocol.BackendMessageCode.CommandComplete: {
            const msg = this._handleCommandComplete(payload);
            this.emit('message', code, msg);
            break;
          }
          default:
            this.emit('message', code, payload);
        }
      } catch (e) {
        this._handleError(e);
      }
    });
  }

  /**
   * Resolves `options.password` (calling and awaiting it when it is a function)
   * and passes the result, or '' when falsy, to `cb`. Any failure, including an
   * exception thrown by `cb`, is routed to `_handleError`.
   * @param {Function} cb Called with the resolved password string.
   */
  _resolvePassword(cb) {
    (async () => {
      const pass =
        typeof this.options.password === 'function'
          ? await this.options.password()
          : this.options.password;
      cb(pass || '');
    })().catch((err) => this._handleError(err));
  }

  /**
   * Responds to an Authentication message according to `msg.kind`.
   * A falsy `msg` emits 'authenticate' and sends nothing.
   * @param {object} [msg] Parsed Authentication payload.
   * @throws {Error} When SASL is requested without SCRAM-SHA-256, or a SASL
   *   continue/final message arrives before a session was started.
   */
  _handleAuthenticationMessage(msg) {
    if (!msg) {
      this.emit('authenticate');
      return;
    }
    switch (msg.kind) {
      case Protocol.AuthenticationMessageKind.CleartextPassword:
        this._resolvePassword((password) => {
          this._send(this._frontend.getPasswordMessage(password));
        });
        break;
      case Protocol.AuthenticationMessageKind.MD5Password:
        this._resolvePassword((password) => {
          const md5 = (x) => crypto.createHash('md5').update(x, 'utf8').digest('hex');
          // Hash with the same user name the startup message announced,
          // including its 'postgres' default, so both sides agree.
          const user = this.options.user || 'postgres';
          const inner = md5(password + user);
          const outer = md5(Buffer.concat([Buffer.from(inner), msg.salt]));
          this._send(this._frontend.getPasswordMessage('md5' + outer));
        });
        break;
      case Protocol.AuthenticationMessageKind.SASL: {
        if (!msg.mechanisms.includes('SCRAM-SHA-256')) {
          throw new Error('SASL: Only mechanism SCRAM-SHA-256 is currently supported');
        }
        const saslSession = (this._saslSession = SASL.createSession(
          this.options.user || '',
          'SCRAM-SHA-256',
        ));
        this._send(this._frontend.getSASLMessage(saslSession));
        break;
      }
      case Protocol.AuthenticationMessageKind.SASLContinue: {
        const saslSession = this._saslSession;
        if (!saslSession) {
          throw new Error('SASL: Session not started yet');
        }
        this._resolvePassword((password) => {
          SASL.continueSession(saslSession, password, msg.data);
          const buf = this._frontend.getSASLFinalMessage(saslSession);
          this._send(buf);
        });
        break;
      }
      case Protocol.AuthenticationMessageKind.SASLFinal: {
        const session = this._saslSession;
        if (!session) {
          throw new Error('SASL: Session not started yet');
        }
        SASL.finalizeSession(session, msg.data);
        this._saslSession = undefined;
        break;
      }
      default:
        break;
    }
  }

  /**
   * Records a ParameterStatus name/value pair in the session parameters.
   * @param {{name: string, value: *}} msg
   */
  _handleParameterStatus(msg) {
    this._sessionParameters[msg.name] = msg.value;
  }

  /**
   * Stores the process id and secret key from a BackendKeyData message.
   * @param {{processID: *, secretKey: *}} msg
   */
  _handleBackendKeyData(msg) {
    this._processID = msg.processID;
    this._secretKey = msg.secretKey;
  }

  /**
   * Parses a CommandComplete tag into `{ command, oid?, rowCount? }`.
   * With two trailing numbers they are read as oid and row count; with one,
   * as row count. A missing or unrecognized tag is returned unparsed as
   * `{ command: <original value> }` instead of throwing.
   * @param {{command?: string}} msg
   * @returns {{command: (string|undefined), oid?: number, rowCount?: number}}
   */
  _handleCommandComplete(msg) {
    const command = msg?.command;
    const m = typeof command === 'string' ? command.match(COMMAND_RESULT_PATTERN) : null;
    if (!m) {
      return { command };
    }
    const result = { command: m[1] };
    if (m[3] != null) {
      result.oid = parseInt(m[2], 10);
      result.rowCount = parseInt(m[3], 10);
    } else if (m[2]) {
      result.rowCount = parseInt(m[2], 10);
    }
    return result;
  }

  /**
   * Writes `data` to the socket if one exists and is writable; otherwise the
   * data is dropped and `cb` is not invoked.
   * @param {Buffer} data
   * @param {Function} [cb] Passed to `socket.write` as its completion callback.
   */
  _send(data, cb) {
    if (this._socket && this._socket.writable) {
      this._socket.write(data, cb);
    }
  }
}