@stryke/prisma-trpc-generator
Version:
A fork of prisma-trpc-generator, adapted to run as ESM with Prisma v6.
237 lines (235 loc) • 9.52 kB
JavaScript
import { lowerCaseFirst } from "../string-format/src/lower-case-first.mjs";
import { getPrismaInternals } from "../utils/get-prisma-internals.mjs";
import { getJSDocs, getZodDocElements } from "./docs-helpers.mjs";
import path from "node:path";
import { StructureKind, VariableDeclarationKind } from "ts-morph";
//#region src/zod/model-helpers.ts
/**
 * Reports whether any field of a model is a relation to another model.
 *
 * @param model - DMMF-style model whose `fields` array is scanned.
 * @returns `true` as soon as one field passes `checkIsModelRelationField`, otherwise `false`.
 */
function checkModelHasModelRelation(model) {
  return model.fields.some((field) => checkIsModelRelationField(field));
}
/**
 * Reports whether any field of a model is a to-many (list) relation.
 *
 * @param model - DMMF-style model whose `fields` array is scanned.
 * @returns `true` as soon as one field passes `checkIsManyModelRelationField`, otherwise `false`.
 */
function checkModelHasManyModelRelation(model) {
  return model.fields.some((field) => checkIsManyModelRelationField(field));
}
/**
 * Tests whether a field is a relation to another model.
 *
 * A field counts as a relation only when its `kind` is `"object"` AND it
 * carries a non-empty `relationName`.
 *
 * @param modelField - DMMF-style field descriptor.
 * @returns A boolean.
 */
function checkIsModelRelationField(modelField) {
  if (modelField.kind !== "object") return false;
  return Boolean(modelField.relationName);
}
/**
 * Tests whether a field is a list-valued relation to another model.
 *
 * @param modelField - DMMF-style field descriptor.
 * @returns `false` for non-relation fields; otherwise the field's `isList` value as-is.
 */
function checkIsManyModelRelationField(modelField) {
  if (!checkIsModelRelationField(modelField)) return false;
  return modelField.isList;
}
/**
 * Looks up a model by exact name.
 *
 * @param models - Array of models, each with a `name` property.
 * @param modelName - Name to match with strict equality.
 * @returns The first model whose `name` matches, or `undefined` when none does.
 */
function findModelByName(models, modelName) {
  for (const candidate of models) {
    if (candidate.name === modelName) return candidate;
  }
  return undefined;
}
/**
 * Writes each entry of `array` through a code writer, calling
 * `conditionalNewLine(newLine)` after every entry.
 *
 * @param writer - Writer exposing chainable `write` and `conditionalNewLine` methods.
 * @param array - Lines to write, in order.
 * @param newLine - Forwarded to `conditionalNewLine` after each line (defaults to `true`).
 */
const writeArray = (writer, array, newLine = true) => {
  for (const line of array) {
    writer.write(line).conditionalNewLine(newLine);
  }
};
/**
 * Builds the naming helpers for generated Zod schema identifiers.
 *
 * - `modelName(name)`: the scalar-only schema name. It is prefixed with `_`
 *   when `relationModel === "default"`.
 * - `relatedModelName(name)`: the relation-aware schema name. It is the bare
 *   name in `"default"` mode, and `Related<name>` otherwise.
 *
 * With `modelCase === "camelCase"` the first character of the (possibly
 * `Related`-prefixed) name is lower-cased. `modelSuffix` is appended
 * verbatim in every case.
 *
 * @param config - Object providing `modelCase`, `modelSuffix` and `relationModel`.
 * @returns An object with `modelName` and `relatedModelName` functions.
 */
const useModelNames = ({ modelCase, modelSuffix, relationModel }) => {
  const isDefaultRelation = relationModel === "default";
  const formatModelName = (name, prefix = "") => {
    const cased = modelCase === "camelCase" ? name.slice(0, 1).toLowerCase() + name.slice(1) : name;
    return `${prefix}${cased}${modelSuffix}`;
  };
  const modelName = (name) => formatModelName(name, isDefaultRelation ? "_" : "");
  const relatedModelName = (name) => {
    const base = name.toString();
    return formatModelName(isDefaultRelation ? base : `Related${base}`);
  };
  return { modelName, relatedModelName };
};
/**
 * Normalizes a filesystem-relative path into an ES module specifier.
 *
 * The input is processed in these steps:
 * - Strip a leading Windows long-path prefix (`\\?\`).
 * - Convert backslashes to forward slashes.
 * - Collapse repeated slashes.
 *
 * Paths that pass through `node_modules` are reduced to the bare package
 * specifier after the last `/node_modules/`. Paths that already start with
 * `../` are kept. Everything else gets a leading `./`.
 *
 * @param input - Path, typically produced by `path.relative`.
 * @returns The module specifier string.
 */
const dotSlash = (input) => {
  const withoutLongPathPrefix = input.replace(/^\\\\\?\\/, "");
  const forwardSlashed = withoutLongPathPrefix.replace(/\\/g, "/");
  const normalized = forwardSlashed.replace(/\/{2,}/g, "/");
  const nodeModulesMarker = `/node_modules/`;
  if (normalized.includes(nodeModulesMarker)) {
    const segments = normalized.split(nodeModulesMarker);
    return segments[segments.length - 1];
  }
  return normalized.startsWith(`../`) ? normalized : `./${normalized}`;
};
/**
 * Splits `input` into consecutive sub-arrays of at most `size` elements.
 *
 * The last chunk holds the remainder when `input.length` is not a multiple
 * of `size`. The input array is not mutated.
 *
 * @param input - Array to split.
 * @param size - Maximum chunk length; must be a positive integer.
 * @returns The chunks in their original order; `[]` for an empty input.
 * @throws {RangeError} When `size` is not a positive integer (a zero or
 *   fractional size cannot produce well-formed chunks).
 */
const chunk = (input, size) => {
  if (!Number.isInteger(size) || size < 1) {
    throw new RangeError(`chunk size must be a positive integer, received ${String(size)}`);
  }
  const chunks = [];
  // One slice per chunk keeps this linear in input.length, instead of
  // re-spreading the whole accumulator for every element.
  for (let start = 0; start < input.length; start += size) {
    chunks.push(input.slice(start, start + size));
  }
  return chunks;
};
/**
 * Decides whether a relation-aware ("related") schema must be generated for a model.
 *
 * @param model - Model whose `fields` are inspected for `kind === "object"`.
 * @param config - Generator config; `relationModel === false` disables related schemas.
 * @returns `true` only when the model has an object field and related models are enabled.
 */
const needsRelatedModel = (model, config) => {
  const hasRelationField = model.fields.some(({ kind }) => kind === "object");
  return hasRelationField ? config.relationModel !== false : false;
};
/**
 * Adds the import declarations a generated model schema file needs.
 *
 * - `zod` is always imported as `z`.
 * - The user's `config.imports` module is imported as `imports`, resolved
 *   relative to the Prisma schema file.
 * - `Decimal` is imported from `decimal.js` when `useDecimalJs` is set and
 *   the model has a Decimal field.
 * - Prisma enums referenced by the model are imported from the
 *   `prisma-client-js` output.
 * - Related models and schemas are imported from the sibling `./index` barrel.
 *
 * @param model - Model the file is generated for.
 * @param sourceFile - ts-morph source file that receives the imports.
 * @param config - Generator config (`imports`, `useDecimalJs`, `relationModel`, naming options).
 * @param options - Prisma generator options (`generator.output`, `schemaPath`, `otherGenerators`).
 * @throws {Error} When the model uses enums but no `prisma-client-js`
 *   generator with an output path is configured.
 */
const writeImportsForModel = async (model, sourceFile, config, options) => {
  const outputPath = (await getPrismaInternals()).parseEnvValue(options.generator.output);
  const { relatedModelName } = useModelNames(config);
  const importList = [{
    kind: StructureKind.ImportDeclaration,
    namespaceImport: "z",
    moduleSpecifier: "zod"
  }];
  if (config.imports) importList.push({
    kind: StructureKind.ImportDeclaration,
    namespaceImport: "imports",
    moduleSpecifier: dotSlash(path.relative(outputPath, path.resolve(path.dirname(options.schemaPath), config.imports)))
  });
  if (config.useDecimalJs && model.fields.some((f) => f.type === "Decimal")) importList.push({
    kind: StructureKind.ImportDeclaration,
    namedImports: ["Decimal"],
    moduleSpecifier: "decimal.js"
  });
  const enumFields = model.fields.filter((f) => f.kind === "enum");
  if (enumFields.length > 0) {
    // Resolve the Prisma client location only when enums actually need importing,
    // so models without enums do not require a `prisma-client-js` generator.
    const clientGenerator = options.otherGenerators.find((each) => each.provider.value === "prisma-client-js");
    const clientPath = clientGenerator?.output?.value;
    if (!clientPath) {
      throw new Error(`Cannot import enums for model "${model.name}": no "prisma-client-js" generator with an output path was found in the Prisma schema.`);
    }
    // Enums are referenced as runtime values (`z.nativeEnum(...)`), so this is a value import.
    importList.push({
      kind: StructureKind.ImportDeclaration,
      moduleSpecifier: dotSlash(path.relative(outputPath, clientPath)),
      // Several fields can share one enum type; duplicate import specifiers would not compile.
      namedImports: [...new Set(enumFields.map((f) => f.type))]
    });
  }
  const relationFields = model.fields.filter((f) => f.kind === "object");
  if (config.relationModel !== false && relationFields.length > 0) {
    // Self-relations are skipped: this model's own interface and related schema
    // are declared in the same file by generateRelatedSchemaForModel.
    const filteredFields = relationFields.filter((f) => f.type !== model.name);
    if (filteredFields.length > 0) importList.push({
      kind: StructureKind.ImportDeclaration,
      moduleSpecifier: "./index",
      namedImports: Array.from(new Set(filteredFields.flatMap((f) => [`${f.type}`, relatedModelName(f.type)])))
    });
  }
  sourceFile.addImportDeclarations(importList);
};
/**
 * Extracts the argument of a `custom(...)` element from a field's Zod doc comment.
 *
 * @param docString - Raw field documentation.
 * @returns The text between `custom(` and the final character of the first
 *   matching element, or `undefined` when no element starts with `custom(`.
 */
const computeCustomSchema = (docString) => {
  const prefix = "custom(";
  const customElement = getZodDocElements(docString).find((element) => element.startsWith(prefix));
  if (customElement === undefined) return undefined;
  return customElement.slice(prefix.length, -1);
};
/**
 * Collects the Zod modifier elements from a field's doc comment.
 *
 * @param docString - Raw field documentation.
 * @returns Every element from `getZodDocElements` except those starting with `custom(`.
 */
const computeModifiers = (docString) => {
  const elements = getZodDocElements(docString);
  return elements.filter((element) => element.startsWith("custom(") === false);
};
/**
 * Prisma scalar type -> [base Zod expression, ...extra modifiers].
 * Scalar types not listed here fall back to `z.unknown()`.
 */
const SCALAR_ZOD_TYPES = new Map([
  ["String", ["z.string()"]],
  ["Int", ["z.number()", "int()"]],
  ["BigInt", ["z.bigint()"]],
  ["DateTime", ["z.date()"]],
  ["Float", ["z.number()"]],
  ["Decimal", ["z.number()"]],
  ["Json", ["jsonSchema"]],
  ["Boolean", ["z.boolean()"]],
  ["Bytes", ["z.unknown()"]]
]);
/**
 * Builds the Zod expression (as source text) for a single model field.
 *
 * The base type is chosen as follows:
 * - scalars: via `SCALAR_ZOD_TYPES`;
 * - enums: `z.nativeEnum(<Enum>)`;
 * - relations: `getRelatedModelName(field.type)`.
 *
 * Modifiers are then appended in this order:
 * 1. `array()` for list fields.
 * 2. Any modifiers from the field's doc comment. A `custom(...)` element in
 *    the doc comment replaces the base type.
 * 3. `nullish()` for optional non-Json fields.
 *
 * @param field - Model field descriptor (`kind`, `type`, `isList`, `isRequired`, `documentation`).
 * @param getRelatedModelName - Maps a related model type to its schema identifier.
 * @returns The Zod expression source, e.g. `z.number().int().nullish()`.
 */
const getZodConstructor = (field, getRelatedModelName = (name) => name.toString()) => {
  let zodType = "z.unknown()";
  // The leading "" makes join(".") emit a "." before the first real modifier.
  const modifiers = [""];
  switch (field.kind) {
    case "scalar": {
      const mapping = SCALAR_ZOD_TYPES.get(field.type);
      if (mapping !== undefined) {
        const [baseType, ...baseModifiers] = mapping;
        zodType = baseType;
        modifiers.push(...baseModifiers);
      }
      break;
    }
    case "enum":
      zodType = `z.nativeEnum(${field.type})`;
      break;
    case "object":
      zodType = getRelatedModelName(field.type);
      break;
  }
  if (field.isList) modifiers.push("array()");
  const { documentation } = field;
  if (documentation) {
    zodType = computeCustomSchema(documentation) ?? zodType;
    modifiers.push(...computeModifiers(documentation));
  }
  const isNullable = !field.isRequired && field.type !== "Json";
  if (isNullable) modifiers.push("nullish()");
  return `${zodType}${modifiers.join(".")}`;
};
/**
 * Emits helper schemas that some scalar types depend on.
 *
 * - When any field is `Json`: emits the `Literal`/`Json` types plus
 *   `literalSchema` and `jsonSchema`. `null` is allowed in literals unless
 *   `config.prismaJsonNullability` is set.
 * - When `config.useDecimalJs` is set and any field is `Decimal`: emits a
 *   Decimal-coercing Zod expression.
 *
 * @param model - Model whose field types are inspected.
 * @param sourceFile - ts-morph source file receiving the statements.
 * @param config - Generator config (`prismaJsonNullability`, `useDecimalJs`).
 */
const writeTypeSpecificSchemas = (model, sourceFile, config) => {
  const hasFieldOfType = (type) => model.fields.some((field) => field.type === type);
  if (hasFieldOfType("Json")) {
    sourceFile.addStatements((writer) => {
      const literalNull = config.prismaJsonNullability ? "" : "| null";
      const schemaNull = config.prismaJsonNullability ? "" : ", z.null()";
      writer.newLine();
      writeArray(writer, [
        "// Helper schema for JSON fields",
        `type Literal = boolean | number | string${literalNull}`,
        "type Json = Literal | { [key: string]: Json } | Json[]",
        `const literalSchema = z.union([z.string(), z.number(), z.boolean()${schemaNull}])`,
        "const jsonSchema: z.ZodSchema<Json> = z.lazy(() => z.union([literalSchema, z.array(jsonSchema), z.record(jsonSchema)]))"
      ]);
    });
  }
  // NOTE(review): the Decimal helper below is emitted as an unnamed expression
  // statement, and getZodConstructor maps Decimal fields to `z.number()`, so no
  // generated model schema references it. Consider binding it to a name and
  // using that name for Decimal fields.
  if (config.useDecimalJs && hasFieldOfType("Decimal")) {
    sourceFile.addStatements((writer) => {
      writer.newLine();
      writeArray(writer, [
        "// Helper schema for Decimal fields",
        "z",
        ".instanceof(Decimal)",
        ".or(z.string())",
        ".or(z.number())",
        ".refine((value) => {",
        "  try {",
        "    return new Decimal(value);",
        "  } catch (error) {",
        "    return false;",
        "  }",
        "})",
        ".transform((value) => new Decimal(value));"
      ]);
    });
  }
};
/**
 * Emits the exported scalar-only `z.object({...})` schema for a model.
 *
 * Relation (`kind === "object"`) fields are excluded. Each remaining field
 * is preceded by its JSDoc lines and written as `name: <zod expression>,`.
 *
 * @param model - Model to generate the schema for.
 * @param sourceFile - ts-morph source file receiving the variable statement.
 * @param config - Generator config used for schema naming.
 */
const generateSchemaForModel = (model, sourceFile, config) => {
  const { modelName } = useModelNames(config);
  const writeShape = (writer) => {
    for (const field of model.fields) {
      if (field.kind === "object") continue;
      writeArray(writer, getJSDocs(field.documentation));
      writer.write(`${field.name}: ${getZodConstructor(field)}`).write(",").newLine();
    }
  };
  sourceFile.addVariableStatement({
    declarationKind: VariableDeclarationKind.Const,
    isExported: true,
    leadingTrivia: (writer) => writer.blankLineIfLastNot(),
    declarations: [{
      name: modelName(model.name),
      initializer: (writer) => {
        writer.write("z.object(").inlineBlock(() => writeShape(writer)).write(")");
      }
    }]
  });
};
/**
 * Emits the relation-aware type and schema for a model. It writes three things:
 *
 * 1. An exported interface named after the model. It extends the inferred
 *    scalar schema type and declares one property per relation field.
 * 2. A doc banner.
 * 3. An exported lazily-evaluated schema that `.extend`s the scalar schema
 *    with the relation fields.
 *
 * @param model - Model to generate the related schema for.
 * @param sourceFile - ts-morph source file receiving the declarations.
 * @param config - Generator config used for schema naming.
 */
const generateRelatedSchemaForModel = (model, sourceFile, config) => {
  const { modelName, relatedModelName } = useModelNames(config);
  const baseSchemaName = modelName(model.name);
  const relatedSchemaName = relatedModelName(model.name);
  const relationFields = model.fields.filter((field) => field.kind === "object");
  const toInterfaceProperty = (field) => {
    const listSuffix = field.isList ? "[]" : "";
    const nullSuffix = field.isRequired ? "" : " | null";
    return {
      hasQuestionToken: !field.isRequired,
      name: field.name,
      type: `${field.type}${listSuffix}${nullSuffix}`
    };
  };
  sourceFile.addInterface({
    name: `${model.name}`,
    isExported: true,
    extends: [`z.infer<typeof ${baseSchemaName}>`],
    properties: relationFields.map((field) => toInterfaceProperty(field))
  });
  const banner = [
    "",
    "/**",
    ` * ${relatedSchemaName} contains all relations on your model in addition to the scalars`,
    " *",
    " * NOTE: Lazy required in case of potential circular dependencies within schema",
    " */"
  ];
  sourceFile.addStatements((writer) => writeArray(writer, banner));
  sourceFile.addVariableStatement({
    declarationKind: VariableDeclarationKind.Const,
    isExported: true,
    declarations: [{
      name: relatedSchemaName,
      type: `z.ZodSchema<${model.name}>`,
      initializer: (writer) => {
        writer.write(`z.lazy(() => ${baseSchemaName}.extend(`).inlineBlock(() => {
          for (const field of relationFields) {
            writeArray(writer, getJSDocs(field.documentation));
            writer.write(`${field.name}: ${getZodConstructor(field, relatedModelName)}`).write(",").newLine();
          }
        }).write("))");
      }
    }]
  });
};
/**
 * Fills a model's schema file.
 *
 * It writes, in order:
 * 1. The imports.
 * 2. The type-specific helper schemas.
 * 3. The scalar schema.
 * 4. The related schema, when `needsRelatedModel` says one is needed.
 *
 * @param model - Model being generated.
 * @param sourceFile - ts-morph source file to populate.
 * @param config - Generator config.
 * @param options - Prisma generator options.
 */
const populateModelFile = async (model, sourceFile, config, options) => {
  // writeImportsForModel is async (it awaits the Prisma internals), so it is
  // awaited before the synchronous writers run.
  await writeImportsForModel(model, sourceFile, config, options);
  writeTypeSpecificSchemas(model, sourceFile, config);
  generateSchemaForModel(model, sourceFile, config);
  if (!needsRelatedModel(model, config)) return;
  generateRelatedSchemaForModel(model, sourceFile, config);
};
/**
 * Re-exports every model's schema module from the index (barrel) file.
 *
 * @param models - Models whose `name` determines the module path
 *   `./<lowerCaseFirst(name)>.schema`.
 * @param indexFile - ts-morph source file receiving the export declarations.
 */
const generateBarrelFile = (models, indexFile) => {
  for (const { name } of models) {
    indexFile.addExportDeclaration({ moduleSpecifier: `./${lowerCaseFirst(name)}.schema` });
  }
};
//#endregion
export { checkIsModelRelationField, checkModelHasManyModelRelation, checkModelHasModelRelation, chunk, findModelByName, generateBarrelFile, populateModelFile };
//# sourceMappingURL=model-helpers.mjs.map