@stryke/prisma-trpc-generator
Version:
A fork of the prisma-trpc-generator code to work in ESM with Prisma v6.
325 lines (320 loc) • 12.6 kB
JavaScript
import { lowerCaseFirst } from "./string-format/src/lower-case-first.mjs";
import { project } from "./project.mjs";
import "./utils/get-prisma-internals.mjs";
import { generateBarrelFile, populateModelFile } from "./zod/model-helpers.mjs";
import { joinPaths } from "@stryke/path/join-paths";
import { relativePath } from "@stryke/path/file-path-fns";
//#region src/helpers.ts
/**
 * Choose which generated procedure builder routers should be built on.
 *
 * Shield wins over middleware; when neither option is enabled the plain
 * public procedure is used.
 *
 * @param config - Generator configuration (reads `withShield` and `withMiddleware`).
 * @returns `"shieldedProcedure"`, `"protectedProcedure"` or `"publicProcedure"`.
 */
const getProcedureName = (config) => {
  if (config.withShield) {
    return "shieldedProcedure";
  }
  if (config.withMiddleware) {
    return "protectedProcedure";
  }
  return "publicProcedure";
};
/**
 * Add `import { t[, <procedure builder>] } from "../trpc"` to a router file.
 *
 * @param args.sourceFile - Source file that receives the import declaration.
 * @param args.config - Optional generator config; when given, the builder name
 *   returned by `getProcedureName` is imported next to `t`.
 */
const generateCreateRouterImport = ({ sourceFile, config }) => {
  const namedImports = config ? ["t", getProcedureName(config)] : ["t"];
  sourceFile.addImportDeclaration({
    moduleSpecifier: "../trpc",
    namedImports
  });
};
/**
 * Import `<modelNamePlural>Router` from the sibling `./<model>.router` module.
 *
 * @param sourceFile - Source file that receives the import declaration.
 * @param modelNamePlural - Prefix of the imported router binding.
 * @param modelNameCamelCase - Model name; lower-cased-first to build the module path.
 */
const generateRouterImport = (sourceFile, modelNamePlural, modelNameCamelCase) => {
  const routerBinding = `${modelNamePlural}Router`;
  const routerModule = `./${lowerCaseFirst(modelNameCamelCase)}.router`;
  sourceFile.addImportDeclaration({
    namedImports: [routerBinding],
    moduleSpecifier: routerModule
  });
};
/**
 * Append the contents of the generated `trpc` module to `sourceFile`.
 *
 * The emitted statements are: imports, the `t` instance, optional
 * middleware / shield wiring, the server-side caller factory, the procedure
 * builders and, with `withNext`, a server-action helper.
 *
 * Statements are appended in file order and only re-indented at the end by
 * `formatText`. The order of the `addStatements` calls is therefore
 * significant. In particular, the `.use(...)` fragments must directly follow
 * the procedure declaration they chain onto.
 *
 * @param sourceFile - Source file that receives the generated statements.
 * @param config - Generator config; reads `withShield`, `withMiddleware`,
 *   `withNext`, `contextPath` and `trpcOptions`.
 * @param options - Not read here; kept for call-site compatibility.
 * @param outputDir - Directory of the generated file; configured paths are
 *   converted to module specifiers relative to it.
 */
async function generateTRPCExports(sourceFile, config, options, outputDir) {
  // Turn a configured path into an import specifier relative to outputDir.
  const toSpecifier = (target) => relativePath(outputDir, joinPaths(outputDir, target));
  const contextSpecifier = toSpecifier(config.contextPath);

  if (config.withShield) {
    sourceFile.addImportDeclaration({
      moduleSpecifier: toSpecifier(typeof config.withShield === "string" ? config.withShield : "shield"),
      namedImports: ["permissions"]
    });
  }
  sourceFile.addStatements(`import type { Context } from '${contextSpecifier}';`);
  if (config.trpcOptions) {
    const optionsSpecifier = typeof config.trpcOptions === "string" ? toSpecifier(config.trpcOptions) : "./options";
    sourceFile.addStatements(`import trpcOptions from '${optionsSpecifier}';`);
  }
  if (config.withNext) {
    sourceFile.addStatements(`import { createContext } from '${contextSpecifier}';
import { initTRPC } from '@trpc/server';
import { createTRPCServerActionHandler } from '@stryke/trpc-next/action-handler';
import { cookies } from "next/headers";`);
  } else {
    // `initTRPC` is referenced unconditionally below. Previously only the
    // withNext branch imported it, leaving non-Next output with an
    // unresolved identifier.
    sourceFile.addStatements(`import { initTRPC } from '@trpc/server';`);
  }
  sourceFile.addStatements(`
export const t = initTRPC.context<Context>().create(${config.trpcOptions ? "trpcOptions" : ""});`);

  // Middlewares to chain onto the protected/shielded procedure, in order.
  const middlewares = [];
  if (config.withMiddleware && typeof config.withMiddleware === "boolean") {
    sourceFile.addStatements(`
export const globalMiddleware = t.middleware(async ({ ctx, next }) => {
console.log('inside middleware!')
return next()
});`);
    middlewares.push({
      type: "global",
      value: `.use(globalMiddleware)`
    });
  }
  if (config.withMiddleware && typeof config.withMiddleware === "string") {
    // A string value is the path of a user-supplied middleware module.
    sourceFile.addStatements(`
import middleware from '${toSpecifier(config.withMiddleware)}';
`);
    sourceFile.addStatements(`
export const globalMiddleware = t.middleware(middleware);`);
    middlewares.push({
      type: "global",
      value: `.use(globalMiddleware)`
    });
  }
  if (config.withShield) {
    sourceFile.addStatements(`
export const permissionsMiddleware = t.middleware(permissions);
`);
    middlewares.push({
      type: "shield",
      value: `
.use(permissions)`
    });
  }
  sourceFile.addStatements(`
/**
* Create a server-side caller
* @see https://trpc.io/docs/server/server-side-calls
*/
export const createCallerFactory = t.createCallerFactory;`);
  sourceFile.addStatements(`
export const publicProcedure = t.procedure; `);
  if (middlewares.length > 0) {
    const procName = getProcedureName(config);
    middlewares.forEach((middleware, i) => {
      // The declaration is emitted once, then each middleware appends a
      // `.use(...)` fragment that continues the same expression.
      if (i === 0) sourceFile.addStatements(`
export const ${procName} = t.procedure`);
      sourceFile.addStatements(`.use(${middleware.type === "shield" ? "permissionsMiddleware" : "globalMiddleware"})`);
    });
  }
  if (config.withNext) {
    sourceFile.addStatements(`
export const createAction: ReturnType<typeof createTRPCServerActionHandler> =
createTRPCServerActionHandler(cookies, t, createContext);
`);
  }
  sourceFile.formatText({ indentSize: 2 });
}
/**
 * Append one router entry that forwards its input to a Prisma Client method.
 *
 * The entry has the shape `<key>: <builder>[.input(<schema>)].<query|mutation>(...)`.
 *
 * @param sourceFile - Router source file that receives the statement.
 * @param name - Procedure name that contains the model name (presumably
 *   e.g. `findManyUser`; confirm against the caller).
 * @param typeName - zod input schema name; referenced in lower-camel-case form.
 * @param modelName - Prisma model name. Its first occurrence is removed from
 *   `name` for the short key, and its lower-cased-first form selects the
 *   `ctx.prisma` delegate.
 * @param opType - Operation name; the first `One` substring is removed to get
 *   the client method.
 * @param baseOpType - Operation used to choose between `query` and `mutation`.
 * @param config - Generator config (`withZod`, `showModelNameInProcedure`,
 *   and the shield/middleware flags).
 */
function generateProcedure(sourceFile, name, typeName, modelName, opType, baseOpType, config) {
  const shortName = name.replace(modelName, "");
  const procedureKey = config.showModelNameInProcedure ? name : shortName;
  const builder = getProcedureName(config);
  const inputStage = config.withZod ? `.input(${lowerCaseFirst(typeName)})` : "";
  const procedureType = getProcedureTypeByOpName(baseOpType);
  const delegate = lowerCaseFirst(modelName);
  const prismaMethod = opType.replace("One", "");

  // Without zod the input is untyped. With zod, groupBy forwards only the
  // fields it accepts.
  let callArgs;
  if (!config.withZod) {
    callArgs = "input as any";
  } else if (shortName === "groupBy") {
    callArgs = "{ where: input.where, orderBy: input.orderBy, by: input.by, having: input.having, take: input.take, skip: input.skip }";
  } else {
    callArgs = "input";
  }

  sourceFile.addStatements(
    `${procedureKey}: ${builder}\n` +
      `${inputStage}.${procedureType}(async ({ ctx, input }) => {\n` +
      `const ${name} = await ctx.prisma.${delegate}.${prismaMethod}(${callArgs});\n` +
      `return ${name};\n` +
      `}),`
  );
}
/**
 * Emit the de-duplicated zod schema imports needed by a model's router.
 *
 * Import lines are collected in first-seen order and emitted as one
 * newline-joined statement block.
 *
 * @param sourceFile - Router source file that receives the imports.
 * @param modelName - Prisma model name.
 * @param modelActions - Operation names exposed for the model.
 */
function generateRouterSchemaImports(sourceFile, modelName, modelActions) {
  const uniqueImports = new Set();
  for (const action of modelActions) {
    uniqueImports.add(getRouterSchemaImportByOpName(action, modelName));
  }
  sourceFile.addStatements(Array.from(uniqueImports).join("\n"));
}
/**
 * Build the import line for the zod schema that backs one router procedure.
 *
 * `…OrThrow` variants share their base operation's schema. `…ManyAndReturn`
 * variants share the `…Many` schema, matching `getInputTypeByOpName`. Both
 * suffixes are therefore stripped before the lookup.
 *
 * @param opName - Operation name (e.g. `findUniqueOrThrow`, `createManyAndReturn`).
 * @param modelName - Prisma model name.
 * @returns The import statement, or `""` when the operation has no known schema.
 */
const getRouterSchemaImportByOpName = (opName, modelName) => {
  // Strip only "AndReturn". Stripping "ManyAndReturn" produced the unknown ops
  // "create"/"update", which logged noise and silently dropped the schema import.
  const opType = opName.replace("OrThrow", "").replace("AndReturn", "");
  const inputType = getInputTypeByOpName(opType, modelName);
  if (!inputType) return "";
  return `import { ${lowerCaseFirst(inputType)} } from "../schemas/${lowerCaseFirst(opType)}${modelName}.schema"; `;
};
/**
 * Suffix appended to a model name to form its zod input schema name, keyed by
 * operation. The `…AndReturn` variants reuse the `…Many` schemas.
 */
const INPUT_SCHEMA_SUFFIX_BY_OP = new Map([
  ["findUnique", "FindUniqueSchema"],
  ["findFirst", "FindFirstSchema"],
  ["findMany", "FindManySchema"],
  ["findRaw", "FindRawObjectSchema"],
  ["createOne", "CreateOneSchema"],
  ["createMany", "CreateManySchema"],
  ["createManyAndReturn", "CreateManySchema"],
  ["deleteOne", "DeleteOneSchema"],
  ["deleteMany", "DeleteManySchema"],
  ["updateOne", "UpdateOneSchema"],
  ["updateMany", "UpdateManySchema"],
  ["updateManyAndReturn", "UpdateManySchema"],
  ["upsertOne", "UpsertSchema"],
  ["aggregate", "AggregateSchema"],
  ["aggregateRaw", "AggregateRawObjectSchema"],
  ["groupBy", "GroupBySchema"]
]);
/**
 * Resolve the zod input schema name for a model operation.
 *
 * @param opName - Operation name (e.g. `findMany`, `upsertOne`).
 * @param modelName - Prisma model name used as the schema prefix.
 * @returns `<modelName><Suffix>`, or `undefined` (after logging) for unknown operations.
 */
const getInputTypeByOpName = (opName, modelName) => {
  const suffix = INPUT_SCHEMA_SUFFIX_BY_OP.get(opName);
  if (suffix === undefined) {
    console.log("getInputTypeByOpName: ", {
      opName,
      modelName
    });
    return undefined;
  }
  return `${modelName}${suffix}`;
};
/** Operations exposed as tRPC queries. */
const PROCEDURE_QUERY_OPS = new Set(["findUnique", "findFirst", "findMany", "findRaw", "aggregate", "aggregateRaw", "groupBy"]);
/** Operations exposed as tRPC mutations. */
const PROCEDURE_MUTATION_OPS = new Set([
  "createOne",
  "createMany",
  "createManyAndReturn",
  "deleteOne",
  "updateOne",
  "deleteMany",
  "updateMany",
  "updateManyAndReturn",
  "upsertOne"
]);
/**
 * Classify an operation as a tRPC `query` or `mutation`.
 *
 * @param opName - Base operation name.
 * @returns `"query"`, `"mutation"`, or `undefined` (after logging) when unknown.
 */
const getProcedureTypeByOpName = (opName) => {
  if (PROCEDURE_QUERY_OPS.has(opName)) return "query";
  if (PROCEDURE_MUTATION_OPS.has(opName)) return "mutation";
  console.log("getProcedureTypeByOpName: ", { opName });
  return undefined;
};
function resolveModelsComments(models, hiddenModels) {
const modelAttributeRegex = /(?:@@Gen\.)+[A-z]+\(.+\)/;
const attributeNameRegex = /\.+[A-Z]+\(+/i;
const attributeArgsRegex = /\(+[A-Z]+:.+\)/i;
for (const model of models) if (model.documentation) {
const attribute = model.documentation?.match(modelAttributeRegex)?.[0];
if (attribute?.match(attributeNameRegex)?.[0]?.slice(1, -1) !== "model") continue;
const rawAttributeArgs = attribute?.match(attributeArgsRegex)?.[0]?.slice(1, -1);
const parsedAttributeArgs = {};
if (rawAttributeArgs) {
const rawAttributeArgsParts = rawAttributeArgs.split(":").map((it) => it.trim()).map((part) => part.startsWith("[") ? part : part.split(",")).flat().map((it) => it.trim());
for (let i = 0; i < rawAttributeArgsParts.length; i += 2) {
const key = rawAttributeArgsParts[i];
const value = rawAttributeArgsParts[i + 1];
parsedAttributeArgs[key] = JSON.parse(value);
}
}
if (parsedAttributeArgs.hide) hiddenModels.push(model.name);
}
}
/**
 * Return a newline-terminated import line for the generated shield module.
 *
 * @param type - `"trpc"`, `"trpc-shield"` or `"context"`.
 * @param newPath - Module specifier for the `Context` type; used only for `"context"`.
 * @returns The import statement, or `""` for any other `type`.
 */
const getImports = (type, newPath) => {
  switch (type) {
    case "trpc":
      return "import * as trpc from '@trpc/server';\n";
    case "trpc-shield":
      return "import { shield, allow } from '@stryke/trpc-next/shield';\n";
    case "context":
      return `import type { Context } from '${newPath}';\n`;
    default:
      return "";
  }
};
/**
 * Wrap shield rule text in an object literal: `{`, a space-prefixed body, `}`.
 *
 * @param args.shieldItemLines - Either an array of `key: rule` lines (joined
 *   with a comma and newline) or an already-assembled string inserted verbatim.
 * @returns The object-literal text, using `\n` line endings throughout.
 */
const wrapWithObject = ({ shieldItemLines }) => {
  // Join with "\n" instead of "\r\n". Every other line break emitted for the
  // shield file is "\n", so CRLF here produced mixed line endings.
  const body = Array.isArray(shieldItemLines) ? shieldItemLines.join(",\n") : shieldItemLines;
  return `{\n ${body}\n}`;
};
/**
 * Wrap an object literal in a `shield<Context>(...)` call, one part per line.
 *
 * @param args.shieldObjectTextWrapped - Rules object text (e.g. from `wrapWithObject`).
 * @returns The call expression text.
 */
const wrapWithTrpcShieldCall = ({ shieldObjectTextWrapped }) =>
  ["shield<Context>(", ` ${shieldObjectTextWrapped}`, ")"].join("\n");
/**
 * Turn a shield call expression into the exported `permissions` declaration.
 *
 * @param args.shieldObjectText - The `shield<Context>(...)` expression text.
 * @returns A single `export const permissions ... = <expr>;` statement.
 */
const wrapWithExport = ({ shieldObjectText }) => {
  const declarationHead = "export const permissions: ReturnType<typeof shield<Context>> =";
  return `${declarationHead} ${shieldObjectText};`;
};
/**
 * Build the source of the generated shield module, granting `allow` to every
 * listed procedure.
 *
 * @param procedures - Procedure names grouped as `queries`, `mutations` and `subscriptions`.
 * @param config - Generator config; `contextPath` locates the `Context` type.
 * @param options - Not read here; kept for call-site compatibility.
 * @param outputDir - Directory the context import specifier is made relative to.
 * @returns The module text, or `""` when all three lists are empty.
 */
const constructShield = async ({ queries, mutations, subscriptions }, config, options, outputDir) => {
  const sections = [
    ["query", queries],
    ["mutation", mutations],
    ["subscription", subscriptions]
  ];
  // Each non-empty section becomes `<kind>: { <name>: allow, ... },`, with no separator between sections.
  const rootItems = sections
    .filter(([, names]) => names.length > 0)
    .map(([kind, names]) => `${kind}: ${wrapWithObject({ shieldItemLines: names.map((name) => `${name}: allow`) })},`)
    .join("");
  if (rootItems.length === 0) return "";

  const shieldImport = getImports("trpc-shield");
  const contextImport = getImports("context", relativePath(outputDir, joinPaths(outputDir, config.contextPath)));
  const permissionsExport = wrapWithExport({
    shieldObjectText: wrapWithTrpcShieldCall({
      shieldObjectTextWrapped: wrapWithObject({ shieldItemLines: rootItems })
    })
  });
  return [shieldImport, contextImport, "\n\n", permissionsExport].join("");
};
/**
 * Build the source of the default tRPC options module.
 *
 * The generated module contains an `errorFormatter` that exposes flattened
 * `ZodError`s on BAD_REQUEST errors. With `withNext`, it also imports and
 * registers the `@stryke/trpc-next` transformer.
 *
 * @param config - Generator config; reads `withNext` and `contextPath`.
 * @param options - Not read here; kept for call-site compatibility.
 * @param outputDir - Directory the context import specifier is made relative to.
 * @returns The module text, terminated by a newline.
 */
const constructDefaultOptions = (config, options, outputDir) => {
  const contextImportPath = relativePath(outputDir, joinPaths(outputDir, config.contextPath));
  const lines = ["import { ZodError } from 'zod';"];
  if (config.withNext) lines.push("import { transformer } from \"@stryke/trpc-next/shared\";");
  lines.push(
    "import type {",
    "DataTransformerOptions,",
    "RootConfig",
    "} from \"@trpc/server/unstable-core-do-not-import\";",
    `import type { Context } from "${contextImportPath}";`,
    "interface RuntimeConfigOptions<",
    "TContext extends object,",
    "TMeta extends object = object",
    "> extends Partial<",
    "Omit<",
    "RootConfig<{",
    "ctx: TContext;",
    "meta: TMeta;",
    "errorShape: any;",
    "transformer: any;",
    "}>,",
    "\"$types\" | \"transformer\"",
    ">",
    "> {",
    "/**",
    "* Use a data transformer",
    "* @see https://trpc.io/docs/v11/data-transformers",
    "*/",
    "transformer?: DataTransformerOptions;",
    "}",
    "const options: RuntimeConfigOptions<Context> = {"
  );
  if (config.withNext) lines.push(" transformer,");
  lines.push(
    "errorFormatter({ shape, error }) {",
    "return {",
    "...shape,",
    "data: {",
    "...shape.data,",
    "zodError:",
    "error.code === \"BAD_REQUEST\" && error.cause instanceof ZodError",
    "? error.cause.flatten()",
    ": null",
    "}",
    "};",
    "}",
    "};",
    "export default options;",
    ""
  );
  return lines.join("\n");
};
/**
 * Generate zod model schema files plus an `index.ts` barrel under `outputPath`.
 *
 * @param models - Models to generate schemas for; `name` drives the file name.
 * @param outputPath - Directory that receives the generated files.
 * @param config - Generator config, forwarded to `populateModelFile`.
 * @param options - Generator options, forwarded to `populateModelFile`.
 */
const constructZodModels = async (models, outputPath, config, options) => {
  const barrel = project.createSourceFile(`${outputPath}/index.ts`, {}, { overwrite: true });
  generateBarrelFile(models, barrel);
  barrel.formatText({ indentSize: 2 });

  // One schema file per model; files are populated concurrently.
  const writeModelSchema = async (model) => {
    const schemaFile = project.createSourceFile(`${outputPath}/${lowerCaseFirst(model.name)}.schema.ts`, {}, { overwrite: true });
    await populateModelFile(model, schemaFile, config, options);
    schemaFile.formatText({ indentSize: 2 });
  };
  await Promise.all(models.map(writeModelSchema));
};
//#endregion
export { constructDefaultOptions, constructShield, constructZodModels, generateCreateRouterImport, generateProcedure, generateRouterImport, generateRouterSchemaImports, generateTRPCExports, getInputTypeByOpName, resolveModelsComments };
//# sourceMappingURL=helpers.mjs.map