@authzkit/prisma-tenant-guard
Version:
Tenant guardrails for Prisma clients with assist/assert/strict enforcement and CLI tooling
539 lines (453 loc) • 14.4 kB
text/typescript
import { readFile } from 'node:fs/promises';
import { dirname, resolve } from 'node:path';
import { pathToFileURL } from 'node:url';
import {
createTenantClient,
TenantGuardError,
type CreateTenantClientOptions,
type Mode,
type TenantMeta
} from './index.js';
export interface TenantGuardCliOptions {
cwd?: string;
configPath?: string;
mode?: Mode;
silent?: boolean;
}
export interface TenantGuardCliResult {
status: 'ok' | 'error';
command: string;
messages: string[];
}
interface ParsedArgs {
command: string;
flags: Record<string, string | boolean>;
positionals: string[];
}
interface TenantGuardConfig extends CreateTenantClientOptions {
metaFile?: string;
schemaPath?: string;
}
export async function runTenantGuardCli(
argv: string[],
options: TenantGuardCliOptions = {},
): Promise<TenantGuardCliResult> {
const parsed = parseArgs(argv);
if (parsed.flags.help === true) {
return emitResult(
'ok',
parsed.command,
[
'Usage: authzkit-tenant-guard <command> [--config <path>] [--mode <mode>]',
'',
'Commands: check | plan | rls | smoke',
],
options,
);
}
const cwd = options.cwd ?? process.cwd();
const configPath = resolve(
cwd,
(parsed.flags.config as string | undefined) ??
options.configPath ??
'tenant-guard.config.json',
);
let config: TenantGuardConfig;
try {
config = await loadConfig(configPath);
} catch (error) {
return emitResult(
'error',
parsed.command,
[formatError(error, `Failed to read config at ${configPath}`)],
options,
);
}
if (parsed.flags.mode && typeof parsed.flags.mode === 'string') {
config.mode = parsed.flags.mode as Mode;
}
if (options.mode) {
config.mode = options.mode;
}
if (parsed.flags.tenant && typeof parsed.flags.tenant === 'string') {
config.tenantId = parsed.flags.tenant;
}
const command = parsed.command ?? 'check';
try {
switch (command) {
case 'check':
return emitResult('ok', 'check', runCheck(config), options);
case 'plan':
return emitResult('ok', 'plan', runPlan(config), options);
case 'rls':
return emitResult('ok', 'rls', runRls(config), options);
case 'smoke':
return emitResult('ok', 'smoke', await runSmoke(config), options);
default:
return emitResult('error', command, [`Unknown command: ${command}`], options);
}
} catch (error) {
if (error instanceof TenantGuardError) {
return emitResult('error', command, [formatError(error)], options);
}
return emitResult(
'error',
command,
[formatError(error, 'Unexpected error')],
options,
);
}
}
function runCheck(config: TenantGuardConfig): string[] {
const messages: string[] = [];
const errors: string[] = [];
if (typeof config.tenantId !== 'string' || config.tenantId.length === 0) {
errors.push('tenantId must be a non-empty string');
}
if (!config.meta || Object.keys(config.meta).length === 0) {
errors.push('meta must define at least one model');
}
for (const [model, modelMeta] of Object.entries(config.meta ?? {})) {
const tenantField = modelMeta.tenantField ?? 'tenantId';
if (tenantField.length === 0) {
errors.push(`${model}: tenantField must be a non-empty string`);
}
if (modelMeta.nestedTargets) {
for (const [relation, target] of Object.entries(modelMeta.nestedTargets)) {
const candidate = resolveNestedTarget(config.meta, target, 'create');
if (!candidate) {
errors.push(
`${model}.${relation}: nested target does not resolve to a known model`,
);
}
}
}
}
const mode = config.mode ?? 'strict';
const rlsEnabled = config.rls?.enabled === true;
if (mode === 'strict' && !rlsEnabled) {
errors.push('strict mode requires rls.enabled=true to guarantee tenant isolation');
}
if (rlsEnabled) {
const varName = config.rls?.varName ?? 'authzkit.tenant_id';
if (!varName || typeof varName !== 'string' || varName.trim().length === 0) {
errors.push('rls.varName must be a non-empty string when RLS is enabled');
}
const missingComposite = Object.entries(config.meta ?? {}).filter(
([, meta]) => !meta.compositeSelector,
);
if (missingComposite.length > 0) {
errors.push(
`rls.enabled requires compositeSelector for: ${missingComposite.map(([name]) => name).join(', ')}`,
);
}
}
if (errors.length > 0) {
throw new Error(errors.join('\n'));
}
messages.push(`tenantId: ${config.tenantId}`);
messages.push(`mode: ${mode}`);
messages.push(`models: ${Object.keys(config.meta ?? {}).join(', ')}`);
if (config.metaFile) {
messages.push(`meta source: ${config.metaFile}`);
}
return messages;
}
function runPlan(config: TenantGuardConfig): string[] {
const lines: string[] = [];
lines.push(`Plan for tenant guard enforcement (mode: ${config.mode ?? 'strict'})`);
for (const [model, modelMeta] of Object.entries(config.meta ?? {})) {
const tenantField = modelMeta.tenantField ?? 'tenantId';
let header = `- ${model}: tenant field → ${tenantField}`;
if (modelMeta.compositeSelector) {
header += `, composite selector → ${modelMeta.compositeSelector}`;
} else {
header += ', composite selector ✖';
}
lines.push(header);
if (modelMeta.nestedTargets) {
const nestedLines: string[] = [];
for (const [relation, target] of Object.entries(modelMeta.nestedTargets)) {
const createTarget = resolveNestedTarget(config.meta, target, 'create');
if (createTarget) {
nestedLines.push(` • ${relation}.create → ${createTarget}`);
}
const updateTarget = resolveNestedTarget(config.meta, target, 'update');
if (updateTarget && updateTarget !== createTarget) {
nestedLines.push(` • ${relation}.update → ${updateTarget}`);
}
}
if (nestedLines.length > 0) {
lines.push(...nestedLines);
}
}
}
if (config.rls?.enabled) {
lines.push(`RLS enabled with var ${config.rls.varName ?? 'authzkit.tenant_id'}`);
}
return lines;
}
function runRls(config: TenantGuardConfig): string[] {
if (!config.rls?.enabled) {
return ['RLS is disabled. Enable via rls.enabled in config.'];
}
const varName = config.rls.varName ?? 'authzkit.tenant_id';
const lines = [
'Row-level security guidance:',
`- Ensure each table has policies referencing current_setting('${varName}')`,
`- Wrap mutations using withTenantRLS(prisma, tenantId, fn, '${varName}')`,
`- Example policy: USING (tenant_id = current_setting('${varName}')::text)`,
];
const missingComposite = Object.entries(config.meta ?? {}).filter(
([, meta]) => !meta.compositeSelector,
);
if (missingComposite.length > 0) {
lines.push(
`- Missing composite selectors for: ${missingComposite.map(([name]) => name).join(', ')}`,
);
}
return lines;
}
async function runSmoke(config: TenantGuardConfig): Promise<string[]> {
const messages: string[] = [];
const fake = new FakePrisma();
const guarded = createTenantClient(fake, config);
const client = guarded as unknown as GuardedFakePrisma;
const okPayload = {
model: 'Post',
operation: 'create',
args: { data: { tenantId: config.tenantId } },
};
await client.__execute(okPayload);
messages.push('✔ create with matching tenant allowed');
const badPayload = {
model: 'Post',
operation: 'create',
args: { data: { tenantId: `${config.tenantId}-other` } },
};
let rejected = false;
try {
await client.__execute(badPayload);
} catch (error) {
if (error instanceof TenantGuardError) {
rejected = true;
messages.push(`✔ cross-tenant create rejected (${error.details.code})`);
} else {
throw error;
}
}
if (!rejected) {
throw new Error('Cross-tenant create was not rejected');
}
return messages;
}
function resolveNestedTarget(
meta: TenantMeta,
target: NestedTargetLike,
operation: string,
): string | undefined {
if (typeof target === 'string') {
return meta[target] ? target : undefined;
}
const record = target as Record<string, string | undefined> & { $default?: string };
if (record.$default && meta[record.$default]) {
return record.$default;
}
const opTarget = record[operation];
if (opTarget && meta[opTarget]) {
return opTarget;
}
return undefined;
}
type NestedTargetLike =
| string
| ({ $default?: string } & Record<string, string | undefined>);
async function loadConfig(configPath: string): Promise<TenantGuardConfig> {
const absoluteConfigPath = resolve(configPath);
const configDir = dirname(absoluteConfigPath);
const content = await readFile(absoluteConfigPath, 'utf8');
const parsed = JSON.parse(content) as TenantGuardConfig;
if (!parsed.tenantId || typeof parsed.tenantId !== 'string') {
throw new Error('Config must include tenantId');
}
let meta = parsed.meta;
if (!meta && parsed.metaFile) {
const metaPath = resolve(configDir, parsed.metaFile);
meta = await loadMetaFile(metaPath);
parsed.metaFile = metaPath;
}
if (!meta) {
throw new Error('Config must include meta or metaFile');
}
parsed.meta = validateMeta(meta, parsed.metaFile ?? 'inline meta');
return parsed;
}
async function loadMetaFile(metaPath: string): Promise<TenantMeta> {
if (metaPath.endsWith('.json')) {
const raw = await readFile(metaPath, 'utf8');
try {
return JSON.parse(raw) as TenantMeta;
} catch (error) {
const reason = error instanceof Error ? error.message : String(error);
throw new Error(`Failed to parse meta JSON at ${metaPath}: ${reason}`);
}
}
if (metaPath.endsWith('.ts')) {
throw new Error(
`Cannot import TypeScript meta file (${metaPath}). Compile to JS or generate JSON.`,
);
}
if (
metaPath.endsWith('.js') ||
metaPath.endsWith('.mjs') ||
metaPath.endsWith('.cjs')
) {
const mod = await import(pathToFileURL(metaPath).href);
const candidate = mod.default ?? mod.meta ?? mod.tenantMeta;
if (!candidate) {
throw new Error(
`Meta module at ${metaPath} does not export default/meta/tenantMeta`,
);
}
return candidate as TenantMeta;
}
throw new Error(`Unsupported meta file extension: ${metaPath}`);
}
function validateMeta(meta: unknown, source: string): TenantMeta {
if (!meta || typeof meta !== 'object') {
throw new Error(`Meta from ${source} must be an object`);
}
return meta as TenantMeta;
}
function parseArgs(argv: string[]): ParsedArgs {
const args = [...argv];
const flags: Record<string, string | boolean> = {};
const positionals: string[] = [];
let command: string | undefined;
while (args.length > 0) {
const token = args.shift();
if (!token) {
continue;
}
if (token.startsWith('--')) {
const body = token.slice(2);
if (!body) {
continue;
}
const eqIndex = body.indexOf('=');
if (eqIndex !== -1) {
const key = body.slice(0, eqIndex);
const value = body.slice(eqIndex + 1);
if (key) {
flags[key] = value;
}
continue;
}
const name = body;
if (name === 'help') {
flags.help = true;
continue;
}
const peek = args[0];
if (peek && !peek.startsWith('-')) {
const value = args.shift();
if (value !== undefined) {
flags[name] = value;
}
} else {
flags[name] = true;
}
continue;
}
if (token.startsWith('-') && token.length > 1) {
const name = token.slice(1);
const peek = args[0];
if (peek && !peek.startsWith('-')) {
const value = args.shift();
if (value !== undefined) {
flags[name] = value;
}
} else {
flags[name] = true;
}
continue;
}
if (!command) {
command = token;
continue;
}
positionals.push(token);
}
return {
command: command ?? 'check',
flags,
positionals,
};
}
function emitResult(
status: 'ok' | 'error',
command: string,
messages: string[],
options: TenantGuardCliOptions,
): TenantGuardCliResult {
if (!options.silent) {
for (const line of messages) {
console.log(line);
}
}
return { status, command, messages };
}
function formatError(error: unknown, prefix?: string): string {
if (error instanceof Error) {
if (error instanceof TenantGuardError) {
const detail = error.details;
return `${prefix ? `${prefix}: ` : ''}${error.message} [${detail.code} at ${detail.path}]`;
}
return `${prefix ? `${prefix}: ` : ''}${error.message}`;
}
return `${prefix ? `${prefix}: ` : ''}${String(error)}`;
}
interface GuardHandlerPayload {
model: string;
operation: string;
args: Record<string, unknown>;
query: (args: Record<string, unknown>) => Promise<unknown>;
}
type GuardHandler = (payload: GuardHandlerPayload) => Promise<unknown>;
class FakePrisma {
$extends(extension: unknown): FakePrisma {
if (!extension || typeof extension !== 'object') {
throw new Error('Invalid extension passed to FakePrisma');
}
const query = (extension as Record<string, unknown>).query;
if (!query || typeof query !== 'object') {
throw new Error('Extension missing query handler');
}
const allModels = (query as Record<string, unknown>).$allModels;
if (!allModels || typeof allModels !== 'object') {
throw new Error('Extension missing $allModels');
}
const handler = (allModels as { $allOperations?: GuardHandler }).$allOperations;
if (typeof handler !== 'function') {
throw new Error('Extension missing $allOperations handler');
}
return new GuardedFakePrisma(handler);
}
}
class GuardedFakePrisma extends FakePrisma {
private readonly handler: GuardHandler;
constructor(handler: GuardHandler) {
super();
this.handler = handler;
}
async __execute(payload: {
model: string;
operation: string;
args: Record<string, unknown>;
}) {
return this.handler({
...payload,
query: async (args: Record<string, unknown>) => args,
});
}
}