/*
 * pg-transactional-tests
 * Version: (not specified)
 * Wraps each test in a transaction for the `pg` package.
 * 201 lines (200 loc) • 7.85 kB
 * JavaScript
 */
;
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    // TypeScript's async/await downlevel helper: runs `generator` as if it were an async
    // function body, awaiting every yielded value and feeding the outcome back in, until
    // the generator returns (resolve) or throws (reject).
    const PromiseCtor = P || Promise;
    // Yielded values may be plain values or promises; normalise them to a promise.
    const toPromise = (value) => (value instanceof PromiseCtor
        ? value
        : new PromiseCtor((resolve) => { resolve(value); }));
    return new PromiseCtor((resolve, reject) => {
        const iterator = generator.apply(thisArg, _arguments || []);
        // Resume the generator with `next(value)` or `throw(reason)`; any synchronous
        // failure while resuming or chaining rejects the overall promise.
        const resume = (method, input) => {
            try {
                const result = iterator[method](input);
                if (result.done) {
                    resolve(result.value);
                }
                else {
                    toPromise(result.value).then((value) => resume('next', value), (reason) => resume('throw', reason));
                }
            }
            catch (err) {
                reject(err);
            }
        };
        resume('next', undefined);
    });
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.testTransaction = void 0;
const node_async_hooks_1 = require("node:async_hooks");
const pg_1 = require("pg");
// Original (unpatched) pg methods, captured when this module loads so the wrappers can
// delegate to them and unpatch() can put them back.
const { connect, query } = pg_1.Client.prototype;
const { connect: poolConnect, query: poolQuery } = pg_1.Pool.prototype;
// AsyncLocalStorage created lazily by parallel(); its store holds the current parallel id.
let asyncLocalStorage;
// Counter from which parallel() derives the id for each new context.
let parallelId = 0;
/**
 * Build the key under which a client's transaction state is tracked.
 *
 * The key is the client's host, port, user and database separated by spaces. When the
 * current AsyncLocalStorage store (the parallel id) is truthy, ` parallel:<id>` is appended.
 *
 * @param {object} client - pg client exposing `connectionParameters`.
 * @returns {string} State key.
 */
const getClientId = (client) => {
    const { host, port, user, database } = client.connectionParameters;
    const store = asyncLocalStorage == null ? undefined : asyncLocalStorage.getStore();
    const suffix = store ? ' parallel:' + store : '';
    return `${host} ${port} ${user} ${database}${suffix}`;
};
// Initial `prependStartTransaction` value for client states created later by getState();
// start() sets it to true so newly seen clients also open a transaction.
let prependStartTransaction = false;
// Per-client transaction state keyed by getClientId(); reset to {} by unpatch().
let clientStates = {};
/**
 * Look up, or lazily create, the transaction-tracking state for `self`.
 *
 * States are keyed by getClientId(self). Every client object with the same connection
 * parameters, in the same parallel context, therefore shares one state. When a truthy
 * parallel id is active, the new state gets its own freshly constructed pg Client.
 * Otherwise `self` is used as the state's client.
 *
 * @param {object} self - Client whose connect/query is being intercepted.
 * @returns {{client: object, transactionId: number, prependStartTransaction: boolean}}
 *   The shared state; patchedConnect later adds `connectPromise` to it.
 */
const getState = (self) => {
    const thisId = getClientId(self);
    const existing = clientStates[thisId];
    if (existing) {
        return existing;
    }
    const parallelId = asyncLocalStorage == null ? undefined : asyncLocalStorage.getStore();
    let client = self;
    if (parallelId) {
        // Pass the parameters to the constructor. pg's Client appears to copy host, port,
        // user, password, etc. out of its config at construction time, so only assigning
        // `connectionParameters` afterwards would leave the new client connecting with
        // defaults. NOTE(review): based on pg's Client constructor — confirm on upgrade.
        client = new pg_1.Client(self.connectionParameters);
        // Keep the very same object so getClientId() sees identical parameters.
        client.connectionParameters = self.connectionParameters;
    }
    const state = {
        client,
        transactionId: 0,
        prependStartTransaction,
    };
    clientStates[thisId] = state;
    return state;
};
/**
 * Replacement for `Client.prototype.connect`.
 *
 * Connects the state's client at most once and shares that single connect attempt
 * (`state.connectPromise`) with every later caller for the same state.
 *
 * When a callback is given, the outcome is delivered only through it: `cb(err)` on failure,
 * `cb(undefined, client)` on success. The returned promise then never rejects. Callers such
 * as pg's Pool pass a callback and ignore the promise, so a rejection there would go
 * unhandled. Without a callback, the returned promise rejects on connection failure.
 *
 * @param {Function} [callback] - Optional Node-style `(err, client)` callback.
 * @returns {Promise<void>}
 */
function patchedConnect(callback) {
    return __awaiter(this, void 0, void 0, function* () {
        // @types/pg says there is no second parameter, but actually pg itself relies on it
        const cb = callback;
        const state = getState(this);
        if (!state.connectPromise) {
            // First connect for this state: open the real connection once and cache the attempt.
            state.connectPromise = new Promise((resolve, reject) => {
                connect.call(state.client, (err) => {
                    if (err) {
                        reject(err);
                    }
                    else {
                        resolve();
                    }
                });
            });
        }
        let failed = false;
        let connectError;
        try {
            yield state.connectPromise;
        }
        catch (err) {
            if (!cb) {
                throw err;
            }
            failed = true;
            connectError = err;
        }
        if (!cb) {
            return;
        }
        // Invoke the callback outside the try so an exception thrown by the callback itself
        // is not mistaken for a connection error.
        if (failed) {
            cb(connectError);
        }
        else {
            cb(undefined, state.client);
        }
    });
}
/**
 * Replacement for `Pool.prototype.connect`.
 *
 * Caps the pool at a single client (`options.max = 1`) before delegating to the original
 * `Pool.prototype.connect`. This mirrors its dual API: with no callback it returns whatever
 * the original returns; with a callback it returns `undefined`.
 *
 * @param {Function} [cb] - Optional checkout callback, forwarded unchanged.
 */
function patchedPoolConnect(cb) {
    this.options.max = 1;
    if (!cb) {
        return poolConnect.call(this);
    }
    // @ts-expect-error whatever
    poolConnect.call(this, cb);
    return undefined;
}
/**
 * Decide how a transaction-control statement must be rewritten so that it nests inside the
 * per-test transaction. Updates `state.transactionId`, the current nesting depth.
 *
 * - BEGIN / START TRANSACTION: at depth 0 it runs as-is and opens depth 1; deeper levels
 *   become `SAVEPOINT "n"`.
 * - COMMIT / ROLLBACK: at depth > 1 they become `RELEASE` / `ROLLBACK TO SAVEPOINT "n"`; at
 *   depth 1 they run as-is and close the transaction.
 * - Anything else, including `ROLLBACK TO` a savepoint the caller created itself, is left
 *   untouched.
 *
 * @param {object} state - Client state from getState().
 * @param {string} sql - Trimmed, upper-cased statement text.
 * @returns {string | undefined} Replacement SQL, or undefined to run the statement unchanged.
 * @throws {Error} On COMMIT/ROLLBACK while no transaction is open.
 */
function rewriteTransactionControl(state, sql) {
    if (sql.startsWith('START TRANSACTION') || sql.startsWith('BEGIN')) {
        if (state.transactionId > 0) {
            return `SAVEPOINT "${state.transactionId++}"`;
        }
        state.transactionId = 1;
        return undefined;
    }
    const isCommit = sql.startsWith('COMMIT');
    // `ROLLBACK [WORK | TRANSACTION] TO [SAVEPOINT] name` only rewinds to a savepoint; it
    // does not end a transaction, so it must not be counted as one.
    const isRollback = !isCommit
        && sql.startsWith('ROLLBACK')
        && !/^ROLLBACK(?:\s+(?:WORK|TRANSACTION))?\s+TO\b/.test(sql);
    if (!isCommit && !isRollback) {
        return undefined;
    }
    if (state.transactionId === 0) {
        throw new Error(`Trying to ${isCommit ? 'COMMIT' : 'ROLLBACK'} outside of transaction`);
    }
    if (state.transactionId > 1) {
        const savePoint = --state.transactionId;
        return `${isCommit ? 'RELEASE' : 'ROLLBACK TO'} SAVEPOINT "${savePoint}"`;
    }
    state.transactionId = 0;
    return undefined;
}
/**
 * Replacement for `Client.prototype.query`.
 *
 * Non-SELECT statements first trigger the deferred BEGIN requested by start(). Their
 * transaction-control statements are then rewritten into savepoints (see
 * rewriteTransactionControl). This lets code under test open, commit and roll back its own
 * transactions without ever leaving the per-test transaction. The statement then runs on
 * the state's client after making sure it is connected.
 *
 * A query config object's `text` is rewritten in place when a replacement is needed. Inputs
 * without a string `text` (e.g. stream-style submittables) are passed through unchanged.
 *
 * @param {string | object} inputArg - SQL text, or a query config / submittable object.
 * @param {...*} args - Remaining `client.query` arguments (values, callback), passed through.
 * @returns {Promise<*>} Whatever the original `query` returns.
 */
function patchedQuery(inputArg, ...args) {
    return __awaiter(this, void 0, void 0, function* () {
        const state = getState(this);
        let input = inputArg;
        const rawText = typeof input === 'string'
            ? input
            : input === null || input === void 0 ? void 0 : input.text;
        const sql = (typeof rawText === 'string' ? rawText : '')
            .trim()
            .toUpperCase();
        // Don't wrap in transactions for selects as they won't mutate
        if (!sql.startsWith('SELECT')) {
            if (state.prependStartTransaction) {
                // Clear the flag before awaiting so concurrent queries don't issue a second BEGIN.
                state.prependStartTransaction = false;
                yield this.query('BEGIN');
            }
            const replacingSql = rewriteTransactionControl(state, sql);
            if (replacingSql) {
                if (typeof input === 'string') {
                    input = replacingSql;
                }
                else {
                    input.text = replacingSql;
                }
            }
        }
        yield pg_1.Client.prototype.connect.call(this);
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        return query.call(state.client, input, ...args);
    });
}
/**
 * Replacement for `Pool.prototype.query`.
 *
 * Checks a client out of the pool, runs the query on it, and releases the client only after
 * the query has finished. It accepts the same arguments as `client.query`, including a
 * trailing Node-style callback.
 *
 * @param {...*} args - Arguments forwarded to `client.query`.
 * @returns {Promise<*>} The query result; `undefined` in callback style.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function patchedPoolQuery(...args) {
    return __awaiter(this, void 0, void 0, function* () {
        const client = yield this.connect();
        const lastIndex = args.length - 1;
        const callback = args[lastIndex];
        if (typeof callback !== 'function') {
            try {
                // Await before releasing: `return client.query(...)` would run `finally`, and
                // release the client, while the query is still in flight.
                // eslint-disable-next-line @typescript-eslint/no-explicit-any
                return yield client.query(...args);
            }
            finally {
                client.release();
            }
        }
        // Callback style: the result arrives through the callback rather than the returned
        // promise, so release the client just before handing the result over.
        let released = false;
        const releaseOnce = () => {
            if (!released) {
                released = true;
                client.release();
            }
        };
        const queryArgs = [...args.slice(0, lastIndex), (err, result) => {
                releaseOnce();
                callback(err, result);
            }];
        try {
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
            return yield client.query(...queryArgs);
        }
        catch (err) {
            // The query was rejected before pg took it, so the wrapped callback never runs.
            releaseOnce();
            throw err;
        }
    });
}
let started = 0;
exports.testTransaction = {
patch() {
pg_1.Client.prototype.connect = patchedConnect;
pg_1.Pool.prototype.connect = patchedPoolConnect;
pg_1.Client.prototype.query = patchedQuery;
pg_1.Pool.prototype.query = patchedPoolQuery;
},
unpatch() {
clientStates = {};
pg_1.Client.prototype.connect = connect;
pg_1.Client.prototype.query = query;
pg_1.Pool.prototype.connect = poolConnect;
pg_1.Pool.prototype.query = poolQuery;
},
start() {
started++;
if (pg_1.Client.prototype.connect !== patchedConnect) {
exports.testTransaction.patch();
}
prependStartTransaction = true;
for (const state of Object.values(clientStates)) {
state.prependStartTransaction = true;
}
},
rollback() {
return __awaiter(this, void 0, void 0, function* () {
yield Promise.all(Object.entries(clientStates).map(([id, state]) => __awaiter(this, void 0, void 0, function* () {
var _a, _b;
if (state.transactionId > 0) {
yield ((_a = state.client) === null || _a === void 0 ? void 0 : _a.query('ROLLBACK'));
}
else if (state.transactionId === 0 && / parallel:\d+$/.test(id)) {
yield ((_b = state.client) === null || _b === void 0 ? void 0 : _b.end());
}
})));
if (!--started) {
exports.testTransaction.unpatch();
}
});
},
parallel(fn) {
exports.testTransaction.patch();
asyncLocalStorage !== null && asyncLocalStorage !== void 0 ? asyncLocalStorage : (asyncLocalStorage = new node_async_hooks_1.AsyncLocalStorage());
return asyncLocalStorage.run(parallelId++, fn);
},
close() {
return __awaiter(this, void 0, void 0, function* () {
started = 0;
yield Promise.all(Object.values(clientStates).map((state) => state.client.end()));
exports.testTransaction.unpatch();
});
},
};