UNPKG

pgwire

Version:

PostgreSQL client library for Node.js which exposes all features of wire protocol

588 lines (574 loc) 18.4 kB
const net = require('net'); const EventEmmiter = require('events'); const { Readable, PassThrough, Transform, pipeline, finished } = require('stream'); const { createHash, pbkdf2, createHmac, randomBytes } = require('crypto'); const { debuglog } = require('util'); const BackendDecoder = require('./backend.js'); const { FrontendEncoder, ...fe } = require('./frontend.js'); const { logicalReplication } = require('./replication.js'); const { datatypesById, datatypesByName } = require('./datatypes.js'); const logBeMsg = debuglog('pgwire-be'); module.exports = class Connection extends EventEmmiter { constructor({ idleTimeout, ...options }) { super(); this.parameters = {}; this._idleTimeoutMillis = Math.max(idleTimeout, 0); this._idleTimer = null; this._options = options; this._responses = []; this._rowDecoder = null; this._pgerror = null; this._closeError = null; this._socket = null; this._tx = new FrontendEncoder(); this._startuptx = new FrontendEncoder(); this._tx.write(new fe.SubProtocol(this._startuptx)); this._rx = new BackendDecoder(); this._rx.on('data', this._recvMessage.bind(this)); this._isCopyOutMode = false; } query(arg, ...args) { if (typeof arg == 'string') { return this._querySimple(arg, ...args); } return this._queryExtended([arg].concat(args)); } _querySimple(script, { stdin } = {}) { const outmsg = [new fe.Query(script)]; if (stdin) { outmsg.push(...[].concat(stdin).map( it => new fe.SubProtocol(copyin(it)), )); } let onReadyForQuery; if (/^\s*(CREATE_REPLICATION_SLOT|START_REPLICATION)\b/.test(script)) { // should wait for ReadyForQuery const blocktx = new PassThrough(); outmsg.push(new fe.SubProtocol(blocktx)); onReadyForQuery = _ => blocktx.end(); } else { outmsg.push(new fe.CopyFail('Missing copy upstream')); } const response = this._request(outmsg); if (onReadyForQuery) { response.on('readyForQuery', onReadyForQuery); } return response; } _queryExtended(extendedOps) { return this._request( extendedOps .map(script2msg) .reduce((flat, it) => (flat.push(...it), flat), []) .concat(new fe.Sync()), ); function script2msg(it) { switch (it.op) { case undefined: case null: return [ ...script2msg({ op: 'parse', statement: it.statement, paramTypes: it.params && it.params.map(({ type }) => type), }), ...script2msg({ op: 'bind', params: it.params, }), ...script2msg({ op: 'execute', limit: it.limit, stdin: it.stdin, }), ]; case 'parse': return [new fe.Parse({ ...it, paramTypes: it.paramTypes && it.paramTypes.map(normalizeDatatypeId), })]; case 'bind': return [new fe.Bind({ outFormats0t1b: [1], ...it, params: it.params && it.params.map(({ type, value }) => ( type ? findDatatype(type).encode(value) : value )), })]; case 'execute': return [ // TODO write test to explain why // we need unconditional DescribePortal // before Execute new fe.DescribePortal(it.portal), new fe.Execute(it), it.stdin ? new fe.SubProtocol(copyin(it.stdin)) // CopyFail message ignored by postgres // if there is no COPY FROM statement : new fe.CopyFail('Missing copy upstream'), ]; case 'describe-statement': return [new fe.DescribeStatement(it.statementName)]; case 'close-statement': return [new fe.CloseStatement(it.statementName)]; case 'describe-portal': return [new fe.DescribePortal(it.portal)]; case 'close-portal': return [new fe.ClosePortal(it.portal)]; default: throw Error(`Unknown op "${it.op}"`); } } } end() { if (!this._tx) { return; } if (this._responses.length) { // pgbouncer closes connection before return query result // if we send Terminate message asynchronously, // so we need to wait for all responses before terminate const [lastResponse] = this._responses.slice(-1); const blocktx = new PassThrough(); lastResponse.on('readyForQuery', _ => blocktx.end()); this._tx.write(new fe.SubProtocol(blocktx)); } this._tx.write(new fe.Terminate()); this._tx.end(); this._tx = null; } destroy(err) { this._rx.destroy(err); } get queueSize() { return this._responses.length; } async logicalReplication(options) { const replstream = await logicalReplication(this, options); // TODO: respect ?keepalives=0 this._socket.setKeepAlive(true); return replstream; } _connectIfNotConnected() { if (this._connectionInited) { return false; } this._connectionInited = true; const { unixSocket, hostname, port, password: _password, ...startupParameters } = this._options; const socket = this._socket = net.connect(unixSocket || { host: hostname, port }); // cannot use stream.pipeline https://github.com/exe-dealer/node-iss-33050 this._tx .on('error', err => socket.destroy(err)) .pipe(socket) .on('error', err => this._rx.destroy(err)) .pipe(this._rx) .on('error', _ => socket.destroy()); finished(this._rx, this._onFinished.bind(this)); this._startuptx.write(new fe.StartupMessage(startupParameters)); return true; } _recvMessage(msg) { logBeMsg('-> %j', msg); switch (msg.tag) { case 'DataRow': return this._recvDataRow(msg); case 'CopyData': return this._recvCopyData(msg); case 'CopyDone': return this._recvCopyDone(msg); case 'CopyOutResponse': return this._recvCopyOutResponse(msg); case 'CopyBothResponse': return this._recvCopyBothResponse(msg); case 'CommandComplete': return this._recvCommandComplete(msg); case 'EmptyQueryResponse': return this._recvEmptyQueryResponse(msg); case 'PortalSuspended': return this._recvPortalSuspended(msg); case 'NoData': return this._recvNoData(msg); case 'RowDescription': return this._recvRowDescription(msg); case 'ReadyForQuery': return this._recvReadyForQuery(msg); case 'ErrorResponse': return this._recvErrorResponse(msg); case 'NotificationResponse': return this._recvNotificationResponse(msg); case 'NoticeResponse': return this._recvNoticeResponse(msg); case 'ParameterStatus': return this._recvParameterStatus(msg); case 'BackendKeyData': return this._recvBackendKeyData(msg); case 'AuthenticationCleartextPassword': return this._recvAuthenticationCleartextPassword(msg); case 'AuthenticationMD5Password': return this._recvAuthenticationMD5Password(msg); case 'AuthenticationSASL': return this._recvAuthenticationSASL(msg); case 'AuthenticationSASLContinue': return this._recvAuthenticationSASLContinue(msg); case 'AuthenticationSASLFinal': return this._recvAuthenticationSASLFinal(msg); } } _passResponse(chunk) { const response = this._responses[0]; return response.push(chunk) || response.destroyed; } _passResponseBackpreasured(chunk) { if (!this._passResponse(chunk)) { this._rx.pause(); } } _passBoundary(boundary) { return this._passResponse(Object.assign( Buffer.alloc(0), { boundary }, )); } _recvReadyForQuery(msg) { if (this._startuptx) { this._startuptx.end(); this._startuptx = null; this.emit('ready'); } else { this._passBoundary(msg); const response = this._responses.shift(); response.emit('readyForQuery', msg); if (this._pgerror) { response.destroy(this._pgerror); this._pgerror = null; } else { response.push(null); } } if (this._idleTimeoutMillis > 0 && !this._responses.length) { this._idleTimer = setTimeout( this._onIdleTimeout.bind(this), this._idleTimeoutMillis, ); } } _onFinished(ioerr) { clearTimeout(this._idleTimer); this._tx = null; this._socket = null; // TODO what error should we report // if we have both ioerr and this._pgerror ? const err = this._closeError = ioerr || this._pgerror; this.emit('close', err); while (this._responses.length) { this._responses.shift().destroy(err || Error('Connection terminated')); } } _request(messages) { const response = new Response(this); if (!this._tx) { return response.destroy(Object.assign(Error(), { message: 'Query after connection end', closeError: this._closeError, })); } clearTimeout(this._idleTimer); this._connectIfNotConnected(); for (const m of messages) { this._tx.write(m); } this._responses.push(response); return response; } _recvErrorResponse({ tag: _, code, ...props }) { this._isCopyOutMode = false; this._pgerror = Object.assign(Error(), props, { name: 'PGError', code: 'PGERR_' + code, prevError: this._pgerror, }); } _recvRowDescription(msg) { this._rowDecoder = msg.fields.map(({ typeid, binary }) => { const { decodeBin = noop => noop, decodeText = noop => noop, } = datatypesById[typeid] || {}; if (binary) { return decodeBin; } return buf => decodeText(String(buf)); }); this._passBoundary(msg); } _recvNoData(msg) { this._passBoundary(msg); } _recvDataRow({ data }) { for (let i = 0; i < data.length; i++) { const val = data[i]; if (val != null) { data[i] = this._rowDecoder[i](val); } } this._passResponseBackpreasured(data); } _recvCopyData({ data }) { // postgres can send CopyData when not in copy out mode // steps to reproduce: // // StartupMessage({ database: 'postgres', user: 'postgres', replication: 'database' }) // // Query('CREATE_REPLICATION_SLOT a LOGICAL test_decoding') // wait for ReadyForQuery // // Query('START_REPLICATION SLOT a LOGICAL 0/0'); // // CopyDone() // before wal_sender_timeout expires // copydata-keepalive message can be received here after CopyDone // but before ReadyForQuery // // Query('CREATE_REPLICATION_SLOT b LOGICAL test_decoding') // sometimes copydata-keepalive message can be received here // before RowDescription message if (this._isCopyOutMode) { this._passResponseBackpreasured(data); } } _recvCopyDone(msg) { this._isCopyOutMode = false; this._passBoundary(msg); } _recvCopyOutResponse(msg) { this._isCopyOutMode = true; this._passBoundary(msg); } _recvCopyBothResponse(msg) { this._isCopyOutMode = true; this._passBoundary(msg); } _recvCommandComplete(msg) { this._passBoundary(msg); } _recvEmptyQueryResponse(msg) { this._passBoundary(msg); } _recvPortalSuspended(msg) { this._passBoundary(msg); } _recvNotificationResponse(msg) { this.emit('notification', msg); } _recvParameterStatus(msg) { this.parameters[msg.parameter] = msg.value; this.emit('parameter', msg.parameter, msg.value); } _recvBackendKeyData({ pid, secretkey }) { this.pid = pid; this.secretkey = secretkey; } _recvNoticeResponse(msg) { if (this._responses[0]) { if (!this._passBoundary(msg)) { this._rx.pause(); } return; } if (this.emit('notice', msg)) { return; } console.error(msg); } _recvAuthenticationCleartextPassword() { // if (!this._options.password) { // return this._startuptx.destroy(Error('Password was not specified')); // } this._startuptx.write(new fe.PasswordMessage(this._options.password)); } _recvAuthenticationMD5Password({ salt }) { const pwdhash = ( createHash('md5') .update(this._options.password, 'utf-8') .update(this._options.user, 'utf-8') .digest('hex') ); this._startuptx.write(new fe.PasswordMessage('md5' + ( createHash('md5') .update(pwdhash, 'utf-8') .update(salt) .digest('hex') ))); } _recvAuthenticationSASL(_msg) { randomBytes(24, (err, clientNonce) => { if (err) { return this._rx.destroy(err); } const clientNonceB64 = clientNonce.toString('base64'); this._saslClientFirstMessageBare = 'n=,r=' + clientNonceB64; this._startuptx.write(new fe.SASLInitialResponse({ mechanism: 'SCRAM-SHA-256', data: 'n,,' + this._saslClientFirstMessageBare, })); }); } _recvAuthenticationSASLContinue({ data }) { const firstServerMessage = data.toString(); const { i: iter, s: saltB64, r: nonceB64 } = ( firstServerMessage .split(',') .map(it => /^([^=]*)=(.*)$/.exec(it)) .reduce((acc, [, key, val]) => ({ ...acc, [key]: val }), {}) ); const finalMessageWithoutProof = 'c=biws,r=' + nonceB64; const salt = Buffer.from(saltB64, 'base64'); const password = this._options.password.normalize(); pbkdf2(password, salt, +iter, 32, 'sha256', (err, saltedPassword) => { if (err) { return this._rx.destroy(err); } const hmac = (key, inp) => createHmac('sha256', key).update(inp).digest(); const clientKey = hmac(saltedPassword, 'Client Key'); const storedKey = createHash('sha256').update(clientKey).digest(); const authMessage = ( this._saslClientFirstMessageBare + ',' + firstServerMessage + ',' + finalMessageWithoutProof ); const clientSignature = hmac(storedKey, authMessage); const clientProof = Buffer.from(clientKey.map((b, i) => b ^ clientSignature[i])); const serverKey = hmac(saltedPassword, 'Server Key'); this._saslServerSignatureB64 = hmac(serverKey, authMessage).toString('base64'); this._startuptx.write(new fe.SASLResponse( finalMessageWithoutProof + ',p=' + clientProof.toString('base64'), )); }); } _recvAuthenticationSASLFinal({ data }) { const receivedServerSignatureB64 = /(?<=^v=)[^,]*/.exec(data); if (this._saslServerSignatureB64 != receivedServerSignatureB64) { this._rx.destroy(Error('Invalid server SCRAM signature')); } } _onIdleTimeout() { // this.emit('idletimeout'); this.end(); } }; function copyin(upstream) { const feEncoder = new FrontendEncoder(); const msgstream = pipeline(upstream, new Transform({ readableObjectMode: true, transform: (chunk, _enc, done) => done(null, new fe.CopyData(chunk)), }), err => { if (err) { feEncoder.end(new fe.CopyFail(String(err))); } else { feEncoder.end(new fe.CopyDone()); } }); feEncoder.on('error', err => msgstream.destroy(err)); return msgstream.pipe(feEncoder, { end: false }); } class Response extends Readable { constructor(conn) { super({ objectMode: true }); this._conn = conn; this._promise = null; this._stack = Object.assign(Error(), { name: '\n ...' }).stack; } _read() { this._resumeConIfCurrentResponse(); } _destroy(...args) { this._resumeConIfCurrentResponse(); return super._destroy(...args); } _resumeConIfCurrentResponse() { if (this == this._conn._responses[0]) { this._conn._rx.resume(); } } _fetch() { if (!this._promise) { this._promise = fetchResponse(this).catch(err => { err.stack += this._stack; throw err; }); } return this._promise; } then(...args) { return this._fetch().then(...args); } catch(...args) { return this._fetch().catch(...args); } } function fetchResponse(stream) { return new Promise((resolve, reject) => { let rows = []; let notices = []; let inTransaction = null; const results = []; stream.on('error', err => reject(Object.assign(err, { inTransaction }))); stream.on('end', _ => { const lastResult = results[results.length - 1]; const lastRows = lastResult && lastResult.rows; let scalar = lastRows && lastRows[0]; if (Array.isArray(scalar)) { // DataRow or CopyData scalar = scalar[0]; } return resolve({ inTransaction, results, rows: lastRows, scalar, command: lastResult && lastResult.command, empty: Boolean(lastResult && lastResult.empty), suspended: Boolean(lastResult && lastResult.suspended), }); }); stream.on('data', chunk => { const { boundary } = chunk; if (!boundary) { return rows.push(chunk); } switch (boundary.tag) { case 'NoticeResponse': notices.push(boundary); break; case 'CommandComplete': results.push({ rows, notices, command: boundary.command }); rows = []; notices = []; break; case 'PortalSuspended': results.push({ rows, notices, suspended: true }); rows = []; notices = []; break; case 'EmptyQueryResponse': results.push({ rows, notices, empty: true }); rows = []; notices = []; break; case 'ReadyForQuery': inTransaction = boundary.transactionStatus != 'I'.charCodeAt(); break; } }); }); } function normalizeDatatypeId(typeIdOrName) { if (typeof typeIdOrName == 'number') { return typeIdOrName; } const result = datatypesByName[typeIdOrName]; if (result) { return result.id; } throw Error(`Unknown builtin datatype name ${JSON.stringify(typeIdOrName)}`); } function findDatatype(typeIdOrName) { if (typeof typeIdOrName == 'number') { const result = datatypesById[typeIdOrName]; if (result) { return result; } throw Error(`Unknown builtin datatype id ${typeIdOrName}`); } else { const result = datatypesByName[typeIdOrName]; if (result) { return result; } throw Error(`Unknown builtin datatype name ${JSON.stringify(typeIdOrName)}`); } }