@markandrus/effect-derive
Version:
Derive Covariant (Functor), Foldable, and Traversable instances, as well as base functors, for algebraic data types (ADTs)
134 lines (130 loc) • 5.29 kB
JavaScript
import { Node } from 'ts-morph';
import { deriveTypeLambda } from './deriveTypeLambda';
import { OutFile } from './OutFile';
import { createRegistryMatcher } from './RegistryMatcher';
// Names given to the free (non-hole) type parameters in the generated `map`
// signatures. At most three type parameters are accepted and one of them is
// the hole, so two placeholders are enough.
const tyParamPlaceholders = ['C', 'D'];

/**
 * Derive a `map` function and a `Covariant` instance for the type alias
 * `forType`, mapping over its rightmost type parameter.
 *
 * @param inFilePath - module that declares `forType`; when non-null, a local
 *   import of `forType` from it is added to the output, and it is forwarded to
 *   `deriveTypeLambda` when a TypeLambda also has to be derived.
 * @param forType - name of the type alias being derived for.
 * @param discriminator - name of the tag property; mandatory when the alias is
 *   a union, and used to emit a `switch` over that property.
 * @param registries - `registries.covariant` receives `[holeIndex, 'map']` for
 *   `forType`; `registries.typeLambda` decides whether a TypeLambda must be
 *   derived alongside.
 * @param node - ts-morph node for the alias (presumably a
 *   TypeAliasDeclaration; it must expose getTypeParameters and
 *   getTypeNodeOrThrow — confirm against callers).
 * @returns whatever `outFile.addDeclarations(...)` returns (presumably the
 *   OutFile itself, given the chained calls on it — confirm).
 * @throws {Error} when there are fewer than 1 or more than 3 type parameters,
 *   when a union is given without a discriminator, or when the alias is
 *   neither a union nor a type literal.
 */
export function deriveCovariant(inFilePath, forType, discriminator, registries, node) {
  const outFile = new OutFile();
  const tyParams = node.getTypeParameters();
  const paramCount = tyParams.length;
  if (paramCount < 1) {
    throw new Error('At least one type parameter is required to derive Covariant');
  }
  if (paramCount > 3) {
    throw new Error("At most 3 type parameters are supported when deriving Covariant, due to limitations in effect's HKT encoding");
  }

  // Haskell convention: the last type parameter is the one that gets mapped.
  const holeIndex = paramCount - 1;
  const hole = tyParams[holeIndex];

  // Registered before the matcher is built; presumably this lets the matcher
  // recognise recursive references to `forType` (verify in RegistryMatcher).
  registries.covariant.set(forType, [holeIndex, 'map']);
  const matcher = createRegistryMatcher(registries.covariant);

  // Free parameter names listed last-to-first, e.g. "D, C" for three params.
  const freeNames = tyParamPlaceholders.slice(0, holeIndex).reverse().join(', ');
  const freePrefix = freeNames === '' ? '' : `${freeNames}, `;
  const freeGenerics = freeNames === '' ? '' : `<${freeNames}>`;

  const rootNode = node.getTypeNodeOrThrow();
  let members;
  if (Node.isUnionTypeNode(rootNode)) {
    if (discriminator == null) {
      throw new Error('--discriminator is required for union types');
    }
    members = rootNode.getTypeNodes();
  } else if (Node.isTypeLiteral(rootNode)) {
    members = [rootNode];
  } else {
    throw new Error(`Type alias "${forType}" must be a union or type literal`);
  }

  const body = handleTypeNodes(matcher, forType, discriminator, hole.getName(), members);

  outFile
    .addPackageAsteriskImport('@effect/typeclass/Covariant', 'covariant')
    .addPackageImport('effect/Function', 'dual');
  if (!registries.typeLambda.has(forType)) {
    outFile.merge(deriveTypeLambda(inFilePath, forType, registries.typeLambda, node));
  }
  if (inFilePath != null) {
    outFile.addLocalImport(inFilePath, forType, true);
  }

  const declarations = `\
export const map: {
<A, B>(f: (a: A) => B): ${freeGenerics}(self: ${forType}<${freePrefix}A>) => ${forType}<${freePrefix}B>
<${freePrefix}A, B>(self: ${forType}<${freePrefix}A>, f: (a: A) => B): ${forType}<${freePrefix}B>
} = dual(
2,
<${freePrefix}A, B>(self: ${forType}<${freePrefix}A>, f: (a: A) => B): ${forType}<${freePrefix}B> => {
${body}
}
)
const imap = covariant.imap<${forType}TypeLambda>(map)
export const Covariant: covariant.Covariant<${forType}TypeLambda> = {
imap,
map
}
`;
  return outFile.addDeclarations(declarations);
}
/**
 * Render the body of the generated `map` implementation.
 *
 * Each node in `tyNodes` is rendered by `handleTypeNode`, and the results are
 * concatenated in order. Without a discriminator, that concatenation is
 * returned as-is. With one, it is wrapped in a `switch` over
 * `self[discriminator]` whose `default` branch throws on an unknown tag.
 *
 * @param matcher - registry matcher forwarded to `handleTypeNode`.
 * @param forType - alias name, forwarded for error messages.
 * @param discriminator - tag property name, or null/undefined for none.
 * @param tyParam - name of the hole type parameter.
 * @param tyNodes - array of member type nodes (one per union member).
 * @returns generated TypeScript source text.
 */
function handleTypeNodes(matcher, forType, discriminator, tyParam, tyNodes) {
  const cases = tyNodes
    .map((member) => handleTypeNode(matcher, forType, discriminator, tyParam, member))
    .join('');
  if (discriminator != null) {
    const key = JSON.stringify(discriminator);
    return `\
switch (self[${key}]) {
${cases} default:
throw new Error(\`Unknown tag "\${self[${key}]}"\`)
}`;
  }
  return cases;
}
/**
 * Render the `map` branch for a single object-type member of the alias.
 *
 * For every property other than the discriminator, `matcher(tyParam, type)`
 * decides how that property is mapped:
 * - null/undefined: the property is left untouched;
 * - []: the property is the hole itself, so it becomes `f(self[name])`;
 * - [m0, m1, ...]: the property is mapped via nested calls,
 *   `m0(self[name], _ => m1(_, ... f))`.
 * When nothing is mapped, the branch returns `self` unchanged; otherwise it
 * returns a spread copy with the mapped properties overridden.
 *
 * @param matcher - function `(tyParam, typeNode) => mapper names | null`.
 * @param forType - alias name, used in error messages.
 * @param discriminator - tag property name, or null/undefined for none.
 * @param tyParam - name of the hole type parameter.
 * @param tyNode - the member's type node; must be a type literal.
 * @returns a `return ...` statement, or a `case <tag>:` clause when a
 *   discriminator is given.
 * @throws {Error} if the node is not a type literal, has a non-property
 *   member, has a non-literal discriminator type, or lacks the discriminator.
 */
function handleTypeNode(matcher, forType, discriminator, tyParam, tyNode) {
  if (!Node.isTypeLiteral(tyNode)) {
    throw new Error(`Every member of the union type "${forType}" must be a TypeLiteral`);
  }

  // Build the expression that maps one property through its chain of mappers.
  const renderMapped = (accessor, mappers) => {
    if (mappers.length === 0) {
      return `f(${accessor})`;
    }
    const [first, ...rest] = mappers;
    const callback = rest.reduceRight((inner, fn) => `_ => ${fn}(_, ${inner})`, 'f');
    return `${first}(${accessor}, ${callback})`;
  };

  let tagText;
  const fieldUpdates = [];
  for (const member of tyNode.getMembers()) {
    if (!Node.isPropertySignature(member)) {
      throw new Error(`Expected a PropertySignature; got ${member.getKindName()}`);
    }
    const name = member.getName();
    const typeNode = member.getTypeNodeOrThrow();

    if (discriminator != null && name === discriminator) {
      if (!Node.isLiteralTypeNode(typeNode)) {
        throw new Error(`Expected discriminator "${discriminator}" to be a LiteralType; got ${typeNode.getKindName()}`);
      }
      // Kept verbatim (quotes included) so it can be used as a `case` label.
      tagText = typeNode.getText();
      continue;
    }

    // NOTE(review): optional properties (`name?: A`) are not special-cased
    // here, so the generated code calls the mapper even when the field may be
    // absent — confirm this is intended.
    const mappers = matcher(tyParam, typeNode);
    if (mappers != null) {
      const key = JSON.stringify(name);
      fieldUpdates.push(`, ${key}: ${renderMapped(`self[${key}]`, mappers)}`);
    }
  }

  if (discriminator != null && tagText == null) {
    throw new Error(`Missing a discriminator "${discriminator}"`);
  }

  const overrides = fieldUpdates.join('');
  const result = overrides === '' ? 'self' : `{ ...self${overrides} }`;
  if (discriminator != null) {
    return `\
case ${tagText}:
return ${result}
`;
  }
  return `\
return ${result}`;
}