UNPKG

@codama/renderers-rust

Version:
1,119 lines (1,114 loc) 44.2 kB
import {
  CodamaError,
  CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE,
  logWarn,
  logError,
  CODAMA_ERROR__UNEXPECTED_NODE_KIND,
} from '@codama/errors';
import {
  REGISTERED_TYPE_NODE_KINDS,
  definedTypeNode,
  pascalCase,
  snakeCase,
  parseDocs,
  resolveNestedTypeNode,
  isNode,
  remainderCountNode,
  fixedCountNode,
  prefixedCountNode,
  arrayTypeNode,
  numberTypeNode,
  getAllPrograms,
  getAllAccounts,
  getAllInstructionsWithSubs,
  getAllDefinedTypes,
  VALUE_NODES,
  structTypeNodeFromInstructionArgumentNodes,
  isNodeFilter,
  camelCase,
  kebabCase,
  titleCase,
  assertIsNode,
  bytesValueNode,
  numberValueNode,
  arrayValueNode,
  isScalarEnum,
} from '@codama/nodes';
import { createRenderMap, mergeRenderMaps, addToRenderMap, deleteDirectory, writeRenderMapVisitor } from '@codama/renderers-core';
import {
  pipe,
  mergeVisitor,
  extendVisitor,
  visit,
  LinkableDictionary,
  NodeStack,
  staticVisitor,
  recordNodeStackVisitor,
  recordLinkablesOnFirstVisitVisitor,
  rootNodeVisitor,
} from '@codama/visitors-core';
import { getBase64Encoder, getBase58Encoder, getBase16Encoder, getUtf8Encoder } from '@solana/codecs-strings';
import { dirname, join } from 'path';
import { fileURLToPath } from 'url';
import nunjucks from 'nunjucks';
import { spawnSync } from 'child_process';

// src/utils/getOwn.ts

/**
 * Reads `key` from a plain lookup object only if it is an own property.
 * Without this guard, names such as "constructor" or "toString" would resolve
 * to inherited Object.prototype members and be mistaken for user overrides.
 * @param {object} record - Lookup object (may be user-provided).
 * @param {string} key - Property to read.
 * @returns {*} The own value, or undefined when absent.
 */
function getOwn(record, key) {
  return Object.prototype.hasOwnProperty.call(record, key) ? record[key] : undefined;
}

// src/ImportMap.ts

/** Default prefixes rewritten by `ImportMap.resolveDependencyMap` (user entries take precedence). */
const DEFAULT_MODULE_MAP = {
  generated: "crate::generated",
  generatedAccounts: "crate::generated::accounts",
  generatedErrors: "crate::generated::errors",
  generatedInstructions: "crate::generated::instructions",
  // `getImportFromFactory` returns "generatedPrograms" for program links and the root
  // visitor renders "programs.rs", so this prefix needs a mapping too.
  generatedPrograms: "crate::generated::programs",
  generatedTypes: "crate::generated::types",
  hooked: "crate::hooked",
  mplEssentials: "mpl_toolbox",
  mplToolbox: "mpl_toolbox",
};

/**
 * A set of Rust `use` paths plus optional aliases. Paths may start with a symbolic
 * module key (e.g. `generatedTypes::Foo`), which is resolved when rendering.
 */
class ImportMap {
  _imports = new Set();
  _aliases = new Map();

  get imports() {
    return this._imports;
  }

  get aliases() {
    return this._aliases;
  }

  /** Adds one path or an iterable of paths. */
  add(imports) {
    const newImports = typeof imports === "string" ? [imports] : imports;
    newImports.forEach((i) => this._imports.add(i));
    return this;
  }

  /** Removes one path or an iterable of paths. */
  remove(imports) {
    const importsToRemove = typeof imports === "string" ? [imports] : imports;
    importsToRemove.forEach((i) => this._imports.delete(i));
    return this;
  }

  /** Copies all imports and aliases of the other maps into this one. */
  mergeWith(...others) {
    others.forEach((other) => {
      this.add(other._imports);
      other._aliases.forEach((alias, importName) => this.addAlias(importName, alias));
    });
    return this;
  }

  /** Merges the `imports` of a type/value manifest. */
  mergeWithManifest(manifest) {
    return this.mergeWith(manifest.imports);
  }

  addAlias(importName, alias) {
    this._aliases.set(importName, alias);
    return this;
  }

  isEmpty() {
    return this._imports.size === 0;
  }

  /**
   * Returns a new map whose paths have their leading module key replaced by the
   * mapped Rust path (defaults merged with `dependencies`).
   */
  resolveDependencyMap(dependencies) {
    const dependencyMap = { ...DEFAULT_MODULE_MAP, ...dependencies };
    const newImportMap = new ImportMap();
    const resolveDependency = (i) => {
      const dependencyKey = Object.keys(dependencyMap).find((key) => i.startsWith(`${key}::`));
      if (!dependencyKey) return i;
      return dependencyMap[dependencyKey] + i.slice(dependencyKey.length);
    };
    this._imports.forEach((i) => newImportMap.add(resolveDependency(i)));
    this._aliases.forEach((alias, i) => newImportMap.addAlias(resolveDependency(i), alias));
    return newImportMap;
  }

  /** Renders one `use` statement per resolved path, newline-separated. */
  toString(dependencies) {
    const resolvedMap = this.resolveDependencyMap(dependencies);
    const importStatements = [...resolvedMap.imports].map((i) => {
      const alias = resolvedMap.aliases.get(i);
      return alias ? `use ${i} as ${alias};` : `use ${i};`;
    });
    return importStatements.join("\n");
  }
}

// src/utils/renderValueNode.ts

/** Decodes a bytesValueNode's data using its declared encoding (base64 by default). */
function getBytesFromBytesValueNode(node) {
  switch (node.encoding) {
    case "utf8":
      return getUtf8Encoder().encode(node.data);
    case "base16":
      return getBase16Encoder().encode(node.data);
    case "base58":
      return getBase58Encoder().encode(node.data);
    case "base64":
    default:
      return getBase64Encoder().encode(node.data);
  }
}

/**
 * Renders a value node as a Rust expression.
 * @param value - Value node to render.
 * @param getImportFrom - Resolves the module key for link nodes.
 * @param {boolean} [useStr=false] - Render strings as `&str` literals instead of `String::from(...)`.
 * @returns {{imports: ImportMap, render: string}}
 */
function renderValueNode(value, getImportFrom, useStr = false) {
  return visit(value, renderValueNodeVisitor(getImportFrom, useStr));
}

/** Visitor producing `{ imports, render }` fragments for value nodes (methods recurse via `this`). */
function renderValueNodeVisitor(getImportFrom, useStr = false) {
  return {
    visitArrayValue(node) {
      const list = node.items.map((v) => visit(v, this));
      return {
        imports: new ImportMap().mergeWith(...list.map((c) => c.imports)),
        render: `[${list.map((c) => c.render).join(", ")}]`,
      };
    },
    visitBooleanValue(node) {
      return { imports: new ImportMap(), render: JSON.stringify(node.boolean) };
    },
    visitBytesValue(node) {
      const bytes = getBytesFromBytesValueNode(node);
      const numbers = Array.from(bytes).map(numberValueNode);
      return visit(arrayValueNode(numbers), this);
    },
    visitConstantValue(node) {
      if (isNode(node.value, "bytesValueNode")) {
        return visit(node.value, this);
      }
      if (isNode(node.type, "stringTypeNode") && isNode(node.value, "stringValueNode")) {
        return visit(bytesValueNode(node.type.encoding, node.value.string), this);
      }
      if (isNode(node.type, "numberTypeNode") && isNode(node.value, "numberValueNode")) {
        const numberManifest = visit(node.value, this);
        const { format, endian } = node.type;
        const byteFunction = endian === "le" ? "to_le_bytes" : "to_be_bytes";
        // Suffix the literal with its Rust type (e.g. `42u32`) so the byte conversion is well-typed.
        return { ...numberManifest, render: `${numberManifest.render}${format}.${byteFunction}()` };
      }
      throw new Error("Unsupported constant value type.");
    },
    visitEnumValue(node) {
      const imports = new ImportMap();
      const enumName = pascalCase(node.enum.name);
      const variantName = pascalCase(node.variant);
      imports.add(`${getImportFrom(node.enum)}::${enumName}`);
      if (!node.value) {
        return { imports, render: `${enumName}::${variantName}` };
      }
      const enumValue = visit(node.value, this);
      return {
        imports: imports.mergeWith(enumValue.imports),
        render: `${enumName}::${variantName} ${enumValue.render}`,
      };
    },
    visitMapEntryValue(node) {
      const mapKey = visit(node.key, this);
      const mapValue = visit(node.value, this);
      return {
        imports: mapKey.imports.mergeWith(mapValue.imports),
        // `HashMap::from` takes an array of `(K, V)` tuples, not two-element arrays.
        render: `(${mapKey.render}, ${mapValue.render})`,
      };
    },
    visitMapValue(node) {
      const map = node.entries.map((entry) => visit(entry, this));
      const imports = new ImportMap().add("std::collections::HashMap");
      return {
        imports: imports.mergeWith(...map.map((c) => c.imports)),
        render: `HashMap::from([${map.map((c) => c.render).join(", ")}])`,
      };
    },
    visitNoneValue() {
      return { imports: new ImportMap(), render: "None" };
    },
    visitNumberValue(node) {
      return { imports: new ImportMap(), render: node.number.toString() };
    },
    visitPublicKeyValue(node) {
      return {
        imports: new ImportMap().add("solana_pubkey"),
        // Fully qualified so the macro resolves without a separate `use solana_pubkey::pubkey;`.
        render: `solana_pubkey::pubkey!("${node.publicKey}")`,
      };
    },
    visitSetValue(node) {
      const set = node.items.map((v) => visit(v, this));
      const imports = new ImportMap().add("std::collections::HashSet");
      return {
        imports: imports.mergeWith(...set.map((c) => c.imports)),
        render: `HashSet::from([${set.map((c) => c.render).join(", ")}])`,
      };
    },
    visitSomeValue(node) {
      const child = visit(node.value, this);
      return { ...child, render: `Some(${child.render})` };
    },
    visitStringValue(node) {
      const literal = JSON.stringify(node.string);
      return { imports: new ImportMap(), render: useStr ? literal : `String::from(${literal})` };
    },
    visitStructFieldValue(node) {
      const structValue = visit(node.value, this);
      // Struct fields are declared in snake_case by the type manifest visitor.
      return { imports: structValue.imports, render: `${snakeCase(node.name)}: ${structValue.render}` };
    },
    visitStructValue(node) {
      const struct = node.fields.map((field) => visit(field, this));
      return {
        imports: new ImportMap().mergeWith(...struct.map((c) => c.imports)),
        render: `{ ${struct.map((c) => c.render).join(", ")} }`,
      };
    },
    visitTupleValue(node) {
      const tuple = node.items.map((v) => visit(v, this));
      return {
        imports: new ImportMap().mergeWith(...tuple.map((c) => c.imports)),
        render: `(${tuple.map((c) => c.render).join(", ")})`,
      };
    },
  };
}

// src/utils/discriminatorConstant.ts

/** Merges fragment imports and combines their renders with `merge`. */
function mergeFragments(fragments, merge) {
  const imports = fragments.reduce((acc, frag) => acc.mergeWith(frag.imports), new ImportMap());
  const rendered = merge(fragments.map((frag) => frag.render));
  return { imports, render: rendered };
}

/** Renders `pub const` declarations for every supported discriminator in `scope`. */
function getDiscriminatorConstants(scope) {
  const fragments = scope.discriminatorNodes
    .map((node) => getDiscriminatorConstant(node, scope))
    .filter(Boolean);
  return mergeFragments(fragments, (r) => r.join("\n\n"));
}

/** Returns a constant fragment for constant/field discriminators, or null for other kinds. */
function getDiscriminatorConstant(discriminatorNode, scope) {
  switch (discriminatorNode.kind) {
    case "constantDiscriminatorNode":
      return getConstantDiscriminatorConstant(discriminatorNode, scope);
    case "fieldDiscriminatorNode":
      return getFieldDiscriminatorConstant(discriminatorNode, scope);
    default:
      return null;
  }
}

function getConstantDiscriminatorConstant(discriminatorNode, scope) {
  const { discriminatorNodes, getImportFrom, prefix, typeManifestVisitor } = scope;
  const index = discriminatorNodes.filter(isNodeFilter("constantDiscriminatorNode")).indexOf(discriminatorNode);
  // The first constant discriminator gets no suffix; later ones are numbered from _2.
  const suffix = index <= 0 ? "" : `_${index + 1}`;
  const name = camelCase(`${prefix}_discriminator${suffix}`);
  const typeManifest = visit(discriminatorNode.constant.type, typeManifestVisitor);
  const value = renderValueNode(discriminatorNode.constant.value, getImportFrom);
  return getConstant(name, typeManifest, value);
}

function getFieldDiscriminatorConstant(discriminatorNode, scope) {
  const { fields, prefix, getImportFrom, typeManifestVisitor } = scope;
  const field = fields.find((f) => f.name === discriminatorNode.name);
  if (!field || !field.defaultValue || !isNode(field.defaultValue, VALUE_NODES)) {
    return null;
  }
  const name = camelCase(`${prefix}_${discriminatorNode.name}`);
  const typeManifest = visit(field.type, typeManifestVisitor);
  const value = renderValueNode(field.defaultValue, getImportFrom);
  return getConstant(name, typeManifest, value);
}

function getConstant(name, typeManifest, value) {
  const type = { imports: typeManifest.imports, render: typeManifest.type };
  return mergeFragments([type, value], ([t, v]) => `pub const ${snakeCase(name).toUpperCase()}: ${t} = ${v};`);
}

// src/utils/linkOverrides.ts

/**
 * Builds a resolver returning the module key (e.g. "generatedTypes") that a link or
 * resolver node should be imported from, honoring per-name overrides.
 * @throws {CodamaError} CODAMA_ERROR__UNEXPECTED_NODE_KIND for unsupported node kinds.
 */
function getImportFromFactory(overrides) {
  const linkOverrides = {
    accounts: overrides.accounts ?? {},
    definedTypes: overrides.definedTypes ?? {},
    instructions: overrides.instructions ?? {},
    pdas: overrides.pdas ?? {},
    programs: overrides.programs ?? {},
    resolvers: overrides.resolvers ?? {},
  };
  return (node) => {
    const kind = node.kind;
    switch (kind) {
      case "accountLinkNode":
        return getOwn(linkOverrides.accounts, node.name) ?? "generatedAccounts";
      case "definedTypeLinkNode":
        return getOwn(linkOverrides.definedTypes, node.name) ?? "generatedTypes";
      case "instructionLinkNode":
        return getOwn(linkOverrides.instructions, node.name) ?? "generatedInstructions";
      case "pdaLinkNode":
        return getOwn(linkOverrides.pdas, node.name) ?? "generatedAccounts";
      case "programLinkNode":
        return getOwn(linkOverrides.programs, node.name) ?? "generatedPrograms";
      case "resolverValueNode":
        return getOwn(linkOverrides.resolvers, node.name) ?? "hooked";
      default:
        throw new CodamaError(CODAMA_ERROR__UNEXPECTED_NODE_KIND, {
          expectedKinds: [
            "accountLinkNode",
            "definedTypeLinkNode",
            "instructionLinkNode",
            "pdaLinkNode",
            "programLinkNode",
            "resolverValueNode",
          ],
          kind,
          node,
        });
    }
  };
}

// src/utils/render.ts

/**
 * Renders doc lines as `///` comments. The result ends with a newline so the item
 * that follows starts on its own line instead of being swallowed by the comment.
 * @param {string[]} docs
 * @returns {string} Empty string when there are no docs.
 */
function rustDocblock(docs) {
  if (docs.length <= 0) return "";
  const lines = docs.map((doc) => `/// ${doc}`);
  return `${lines.join("\n")}\n`;
}

/** Renders a nunjucks template from the bundled `templates` directory with the case/doc filters. */
const render = (template, context, options) => {
  const templates = join(dirname(fileURLToPath(import.meta.url)), "templates");
  const env = nunjucks.configure(templates, { autoescape: false, trimBlocks: true, ...options });
  env.addFilter("pascalCase", pascalCase);
  env.addFilter("camelCase", camelCase);
  env.addFilter("snakeCase", snakeCase);
  env.addFilter("kebabCase", kebabCase);
  env.addFilter("titleCase", titleCase);
  env.addFilter("rustDocblock", rustDocblock);
  env.addFilter("hasTrait", (traits, ...traitNames) => {
    if (typeof traits !== "string") return false;
    return traitNames.some((traitName) => traits.includes(traitName));
  });
  return env.render(template, context);
};

// src/utils/traitOptions.ts

const DEFAULT_TRAIT_OPTIONS = {
  baseDefaults: [
    "borsh::BorshSerialize",
    "borsh::BorshDeserialize",
    "serde::Serialize",
    "serde::Deserialize",
    "Clone",
    "Debug",
    "Eq",
    "PartialEq",
  ],
  dataEnumDefaults: [],
  featureFlags: { serde: ["serde::Serialize", "serde::Deserialize"] },
  overrides: {},
  scalarEnumDefaults: ["Copy", "PartialOrd", "Hash", "num_derive::FromPrimitive"],
  structDefaults: [],
  useFullyQualifiedName: false,
};

const SERDE_TRAITS = new Set(["serde::Serialize", "serde::Deserialize", "Serialize", "Deserialize"]);

function getTraitsFromNodeFactory(options = {}) {
  return (node) => getTraitsFromNode(node, options);
}

/**
 * Renders the `#[derive(...)]` / `#[cfg_attr(feature, derive(...))]` lines for a node.
 * Aliases get no traits.
 * @returns {{imports: ImportMap, render: string}}
 */
function getTraitsFromNode(node, userOptions = {}) {
  assertIsNode(node, ["accountNode", "definedTypeNode", "instructionNode"]);
  const options = { ...DEFAULT_TRAIT_OPTIONS, ...userOptions };
  const nodeType = getNodeType(node);
  if (nodeType === "alias") {
    return { imports: new ImportMap(), render: "" };
  }
  const allTraits = getAllTraits(node, nodeType, options);
  const [partitionedUnfeatured, featuredTraits] = partitionTraitsInFeatures(allTraits, options.featureFlags);
  const imports = new ImportMap();
  const unfeaturedTraits = options.useFullyQualifiedName
    ? partitionedUnfeatured
    : extractFullyQualifiedNames(partitionedUnfeatured, imports);
  const traitLines = [
    ...(unfeaturedTraits.length > 0 ? [`#[derive(${unfeaturedTraits.join(", ")})]\n`] : []),
    ...Object.entries(featuredTraits).map(
      ([feature, traits]) => `#[cfg_attr(feature = "${feature}", derive(${traits.join(", ")}))]\n`
    ),
  ];
  return { imports, render: traitLines.join("") };
}

/** Classifies a node as "struct", "scalarEnum", "dataEnum", or "alias". */
function getNodeType(node) {
  if (isNode(node, ["accountNode", "instructionNode"])) return "struct";
  if (isNode(node.type, "structTypeNode")) return "struct";
  if (isNode(node.type, "enumTypeNode")) {
    return isScalarEnum(node.type) ? "scalarEnum" : "dataEnum";
  }
  return "alias";
}

/** Returns the node's override traits (keys compared in camelCase) or the type defaults. */
function getAllTraits(node, nodeType, options) {
  const sanitizedOverrides = Object.fromEntries(
    Object.entries(options.overrides ?? {}).map(([key, value]) => [camelCase(key), value])
  );
  const nodeOverrides = getOwn(sanitizedOverrides, node.name);
  return nodeOverrides === undefined ? getDefaultTraits(nodeType, options) : nodeOverrides;
}

function getDefaultTraits(nodeType, options) {
  switch (nodeType) {
    case "dataEnum":
      return [...options.baseDefaults, ...options.dataEnumDefaults];
    case "scalarEnum":
      return [...options.baseDefaults, ...options.scalarEnumDefaults];
    case "struct":
      return [...options.baseDefaults, ...options.structDefaults];
  }
}

/**
 * Splits traits into unfeatured ones and a `{ feature: traits[] }` record. A trait
 * listed under several features is assigned to the first one.
 */
function partitionTraitsInFeatures(traits, featureFlags) {
  const reverseFeatureFlags = new Map();
  for (const [feature, featureTraits] of Object.entries(featureFlags ?? {})) {
    for (const trait of featureTraits) {
      if (!reverseFeatureFlags.has(trait)) reverseFeatureFlags.set(trait, feature);
    }
  }
  const unfeaturedTraits = [];
  const featuredTraits = {};
  for (const trait of traits) {
    const feature = reverseFeatureFlags.get(trait);
    if (feature === undefined) {
      unfeaturedTraits.push(trait);
    } else {
      if (!featuredTraits[feature]) featuredTraits[feature] = [];
      featuredTraits[feature].push(trait);
    }
  }
  return [unfeaturedTraits, featuredTraits];
}

/** Imports `a::b::Trait` paths into `imports` and returns the bare trait names. */
function extractFullyQualifiedNames(traits, imports) {
  return traits.map((trait) => {
    const index = trait.lastIndexOf("::");
    if (index === -1) return trait;
    imports.add(trait);
    return trait.slice(index + 2);
  });
}

/**
 * Returns a `#[serde(with = ...)]` field attribute (feature-gated when serde is) if the
 * node derives serde traits, otherwise an empty string.
 */
function getSerdeFieldAttribute(serdeWith, node, userOptions = {}) {
  assertIsNode(node, ["accountNode", "definedTypeNode", "instructionNode"]);
  const options = { ...DEFAULT_TRAIT_OPTIONS, ...userOptions };
  const nodeType = getNodeType(node);
  if (nodeType === "alias") {
    return "";
  }
  const allTraits = getAllTraits(node, nodeType, options);
  if (!allTraits.some((t) => SERDE_TRAITS.has(t))) {
    return "";
  }
  const [, featuredTraits] = partitionTraitsInFeatures(allTraits, options.featureFlags);
  const serdeFeature = Object.entries(featuredTraits).find(([, traits]) => traits.some((t) => SERDE_TRAITS.has(t)));
  if (serdeFeature) {
    return `#[cfg_attr(feature = "${serdeFeature[0]}", serde(with = "${serdeWith}"))]\n`;
  }
  return `#[serde(with = "${serdeWith}")]\n`;
}

// src/getTypeManifestVisitor.ts

/**
 * Creates a visitor that maps type nodes to `{ imports, nestedStructs, type }` manifests
 * describing the equivalent Rust (Borsh-compatible) type.
 * @param options.getImportFrom - Module-key resolver for links.
 * @param options.getTraitsFromNode - Trait renderer for accounts/defined types/nested structs.
 * @param [options.traitOptions] - Forwarded to serde field attribute generation.
 * @param [options.parentName] - Initial parent name used to name structs/enums.
 * @param [options.nestedStruct] - When true, struct types are hoisted into `nestedStructs`.
 */
function getTypeManifestVisitor(options) {
  const { getImportFrom, getTraitsFromNode: traitsFromNode, traitOptions } = options;
  // Mutable traversal context: saved and restored around nested visits.
  let parentName = options.parentName ?? null;
  let nestedStruct = options.nestedStruct ?? false;
  let inlineStruct = false;
  let parentSize = null;
  let parentNode = null;

  return pipe(
    mergeVisitor(
      () => ({ imports: new ImportMap(), nestedStructs: [], type: "" }),
      (_, values) => ({ ...mergeManifests(values), type: values.map((v) => v.type).join("\n") }),
      { keys: [...REGISTERED_TYPE_NODE_KINDS, "definedTypeLinkNode", "definedTypeNode", "accountNode"] }
    ),
    (v) =>
      extendVisitor(v, {
        visitAccount(account, { self }) {
          parentName = pascalCase(account.name);
          parentNode = account;
          const manifest = visit(account.data, self);
          const traits = traitsFromNode(account);
          manifest.imports.mergeWith(traits.imports);
          parentName = null;
          parentNode = null;
          return { ...manifest, type: traits.render + manifest.type };
        },

        visitArrayType(arrayType, { self }) {
          const childManifest = visit(arrayType.item, self);
          if (isNode(arrayType.count, "fixedCountNode")) {
            return { ...childManifest, type: `[${childManifest.type}; ${arrayType.count.value}]` };
          }
          if (isNode(arrayType.count, "remainderCountNode")) {
            childManifest.imports.add("kaigan::types::RemainderVec");
            return { ...childManifest, type: `RemainderVec<${childManifest.type}>` };
          }
          const prefix = resolveNestedTypeNode(arrayType.count.prefix);
          if (prefix.endian === "le") {
            switch (prefix.format) {
              case "u32":
                // Borsh's native Vec uses a u32 length prefix.
                return { ...childManifest, type: `Vec<${childManifest.type}>` };
              case "u8":
              case "u16":
              case "u64": {
                const prefixFormat = prefix.format.toUpperCase();
                childManifest.imports.add(`kaigan::types::${prefixFormat}PrefixVec`);
                return { ...childManifest, type: `${prefixFormat}PrefixVec<${childManifest.type}>` };
              }
              case "shortU16":
                childManifest.imports.add("solana_short_vec::ShortVec");
                return { ...childManifest, type: `ShortVec<${childManifest.type}>` };
              default:
                throw new Error(`Array prefix not supported: ${prefix.format}`);
            }
          }
          throw new Error("Array size not supported by Borsh");
        },

        visitBooleanType(booleanType) {
          const resolvedSize = resolveNestedTypeNode(booleanType.size);
          if (resolvedSize.format === "u8" && resolvedSize.endian === "le") {
            return { imports: new ImportMap(), nestedStructs: [], type: "bool" };
          }
          throw new Error("Bool size not supported by Borsh");
        },

        visitBytesType(_bytesType, { self }) {
          // Bytes are rendered as a u8 array sized by the enclosing fixed-size/size-prefix node.
          let arraySize = remainderCountNode();
          if (typeof parentSize === "number") {
            arraySize = fixedCountNode(parentSize);
          } else if (parentSize && typeof parentSize === "object") {
            arraySize = prefixedCountNode(parentSize);
          }
          return visit(arrayTypeNode(numberTypeNode("u8"), arraySize), self);
        },

        visitDefinedType(definedType, { self }) {
          parentName = pascalCase(definedType.name);
          parentNode = definedType;
          const manifest = visit(definedType.type, self);
          const traits = traitsFromNode(definedType);
          manifest.imports.mergeWith(traits.imports);
          parentName = null;
          parentNode = null;
          const renderedType = isNode(definedType.type, ["enumTypeNode", "structTypeNode"])
            ? manifest.type
            : `pub type ${pascalCase(definedType.name)} = ${manifest.type};`;
          return { ...manifest, type: `${traits.render}${renderedType}` };
        },

        visitDefinedTypeLink(node) {
          const pascalCaseDefinedType = pascalCase(node.name);
          const importFrom = getImportFrom(node);
          return {
            imports: new ImportMap().add(`${importFrom}::${pascalCaseDefinedType}`),
            nestedStructs: [],
            type: pascalCaseDefinedType,
          };
        },

        visitEnumEmptyVariantType(enumEmptyVariantType) {
          const name = pascalCase(enumEmptyVariantType.name);
          return { imports: new ImportMap(), nestedStructs: [], type: `${name},` };
        },

        visitEnumStructVariantType(enumStructVariantType, { self }) {
          const name = pascalCase(enumStructVariantType.name);
          const originalParentName = parentName;
          if (!originalParentName) {
            throw new Error("Enum struct variant type must have a parent name.");
          }
          inlineStruct = true;
          parentName = pascalCase(originalParentName) + name;
          const typeManifest = visit(enumStructVariantType.struct, self);
          inlineStruct = false;
          parentName = originalParentName;
          return { ...typeManifest, type: `${name} ${typeManifest.type},` };
        },

        visitEnumTupleVariantType(enumTupleVariantType, { self }) {
          const name = pascalCase(enumTupleVariantType.name);
          const originalParentName = parentName;
          if (!originalParentName) {
            throw new Error("Enum tuple variant type must have a parent name.");
          }
          parentName = pascalCase(originalParentName) + name;
          const childManifest = visit(enumTupleVariantType.tuple, self);
          parentName = originalParentName;
          let derive = "";
          if (parentNode && childManifest.type === "(Pubkey)") {
            derive = getSerdeFieldAttribute("serde_with::As::<serde_with::DisplayFromStr>", parentNode, traitOptions);
          } else if (parentNode && childManifest.type === "(Vec<Pubkey>)") {
            derive = getSerdeFieldAttribute(
              "serde_with::As::<Vec<serde_with::DisplayFromStr>>",
              parentNode,
              traitOptions
            );
          }
          return { ...childManifest, type: `${derive}${name}${childManifest.type},` };
        },

        visitEnumType(enumType, { self }) {
          const originalParentName = parentName;
          if (!originalParentName) {
            throw new Error("Enum type must have a parent name.");
          }
          const variants = enumType.variants.map((variant) => visit(variant, self));
          const variantNames = variants.map((variant) => variant.type).join("\n");
          return {
            ...mergeManifests(variants),
            type: `pub enum ${pascalCase(originalParentName)} {\n${variantNames}\n}`,
          };
        },

        visitFixedSizeType(fixedSizeType, { self }) {
          parentSize = fixedSizeType.size;
          const manifest = visit(fixedSizeType.type, self);
          parentSize = null;
          return manifest;
        },

        visitMapType(mapType, { self }) {
          const key = visit(mapType.key, self);
          const value = visit(mapType.value, self);
          const mergedManifest = mergeManifests([key, value]);
          mergedManifest.imports.add("std::collections::HashMap");
          return { ...mergedManifest, type: `HashMap<${key.type}, ${value.type}>` };
        },

        visitNumberType(numberType) {
          if (numberType.endian !== "le") {
            throw new Error("Number endianness not supported by Borsh");
          }
          if (numberType.format === "shortU16") {
            return {
              imports: new ImportMap().add("solana_short_vec::ShortU16"),
              nestedStructs: [],
              type: "ShortU16",
            };
          }
          return { imports: new ImportMap(), nestedStructs: [], type: numberType.format };
        },

        visitOptionType(optionType, { self }) {
          const childManifest = visit(optionType.item, self);
          const optionPrefix = resolveNestedTypeNode(optionType.prefix);
          if (optionPrefix.format === "u8" && optionPrefix.endian === "le") {
            return { ...childManifest, type: `Option<${childManifest.type}>` };
          }
          throw new Error("Option size not supported by Borsh");
        },

        visitPublicKeyType() {
          return {
            imports: new ImportMap().add("solana_pubkey::Pubkey"),
            nestedStructs: [],
            type: "Pubkey",
          };
        },

        visitRemainderOptionType(node) {
          throw new CodamaError(CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, { kind: node.kind, node });
        },

        visitSetType(setType, { self }) {
          const childManifest = visit(setType.item, self);
          childManifest.imports.add("std::collections::HashSet");
          return { ...childManifest, type: `HashSet<${childManifest.type}>` };
        },

        visitSizePrefixType(sizePrefixType, { self }) {
          parentSize = resolveNestedTypeNode(sizePrefixType.prefix);
          const manifest = visit(sizePrefixType.type, self);
          parentSize = null;
          return manifest;
        },

        visitStringType() {
          if (!parentSize) {
            return {
              imports: new ImportMap().add(`kaigan::types::RemainderStr`),
              nestedStructs: [],
              type: `RemainderStr`,
            };
          }
          if (typeof parentSize === "number") {
            return { imports: new ImportMap(), nestedStructs: [], type: `[u8; ${parentSize}]` };
          }
          if (isNode(parentSize, "numberTypeNode") && parentSize.endian === "le") {
            switch (parentSize.format) {
              case "u32":
                return { imports: new ImportMap(), nestedStructs: [], type: "String" };
              case "u8":
              case "u16":
              case "u64": {
                const prefix = parentSize.format.toUpperCase();
                return {
                  imports: new ImportMap().add(`kaigan::types::${prefix}PrefixString`),
                  nestedStructs: [],
                  type: `${prefix}PrefixString`,
                };
              }
              default:
                throw new Error(`String size not supported: ${parentSize.format}`);
            }
          }
          throw new Error("String size not supported by Borsh");
        },

        visitStructFieldType(structFieldType, { self }) {
          const originalParentName = parentName;
          const originalInlineStruct = inlineStruct;
          const originalNestedStruct = nestedStruct;
          if (!originalParentName) {
            throw new Error("Struct field type must have a parent name.");
          }
          // Any struct found inside a field is hoisted as a named nested struct.
          parentName = pascalCase(originalParentName) + pascalCase(structFieldType.name);
          nestedStruct = true;
          inlineStruct = false;
          const fieldManifest = visit(structFieldType.type, self);
          parentName = originalParentName;
          inlineStruct = originalInlineStruct;
          nestedStruct = originalNestedStruct;

          const fieldName = snakeCase(structFieldType.name);
          const docblock = rustDocblock(parseDocs(structFieldType.docs));
          const resolvedNestedType = resolveNestedTypeNode(structFieldType.type);
          let derive = "";
          if (parentNode) {
            if (fieldManifest.type === "Pubkey") {
              derive = getSerdeFieldAttribute("serde_with::As::<serde_with::DisplayFromStr>", parentNode, traitOptions);
            } else if (fieldManifest.type === "Vec<Pubkey>") {
              derive = getSerdeFieldAttribute(
                "serde_with::As::<Vec<serde_with::DisplayFromStr>>",
                parentNode,
                traitOptions
              );
            } else if (
              isNode(resolvedNestedType, "arrayTypeNode") &&
              isNode(resolvedNestedType.count, "fixedCountNode") &&
              resolvedNestedType.count.value > 32
            ) {
              // serde only implements (de)serialization for arrays up to 32 elements.
              derive = getSerdeFieldAttribute("serde_big_array::BigArray", parentNode, traitOptions);
            } else if (
              isNode(resolvedNestedType, ["bytesTypeNode", "stringTypeNode"]) &&
              isNode(structFieldType.type, "fixedSizeTypeNode") &&
              structFieldType.type.size > 32
            ) {
              derive = getSerdeFieldAttribute("serde_with::As::<serde_with::Bytes>", parentNode, traitOptions);
            }
          }
          // Enum struct variants cannot carry `pub` on their fields.
          const visibility = inlineStruct ? "" : "pub ";
          return {
            ...fieldManifest,
            type: `${docblock}${derive}${visibility}${fieldName}: ${fieldManifest.type},`,
          };
        },

        visitStructType(structType, { self }) {
          const originalParentName = parentName;
          if (!originalParentName) {
            throw new Error("Struct type must have a parent name.");
          }
          const fields = structType.fields.map((field) => visit(field, self));
          const fieldTypes = fields.map((field) => field.type).join("\n");
          const mergedManifest = mergeManifests(fields);
          const structName = pascalCase(originalParentName);
          if (nestedStruct) {
            const nestedTraits = traitsFromNode(definedTypeNode({ name: originalParentName, type: structType }));
            mergedManifest.imports.mergeWith(nestedTraits.imports);
            return {
              ...mergedManifest,
              nestedStructs: [
                ...mergedManifest.nestedStructs,
                `${nestedTraits.render}pub struct ${structName} {\n${fieldTypes}\n}`,
              ],
              type: structName,
            };
          }
          if (inlineStruct) {
            return { ...mergedManifest, type: `{\n${fieldTypes}\n}` };
          }
          return { ...mergedManifest, type: `pub struct ${structName} {\n${fieldTypes}\n}` };
        },

        visitTupleType(tupleType, { self }) {
          const items = tupleType.items.map((item) => visit(item, self));
          return {
            ...mergeManifests(items),
            type: `(${items.map((item) => item.type).join(", ")})`,
          };
        },

        visitZeroableOptionType(node) {
          throw new CodamaError(CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, { kind: node.kind, node });
        },
      })
  );
}

/** Merges imports and concatenates nested structs of several manifests (type left to caller). */
function mergeManifests(manifests) {
  return {
    imports: new ImportMap().mergeWith(...manifests.map((m) => m.imports)),
    nestedStructs: manifests.flatMap((m) => m.nestedStructs),
  };
}

// src/getRenderMapVisitor.ts

/**
 * Creates a visitor that turns a root node into a render map of generated Rust files
 * (accounts, instructions, types, errors, programs and module files).
 */
function getRenderMapVisitor(options = {}) {
  const linkables = new LinkableDictionary();
  const stack = new NodeStack();
  let program = null;
  const renderParentInstructions = options.renderParentInstructions ?? false;
  const dependencyMap = options.dependencyMap ?? {};
  const getImportFrom = getImportFromFactory(options.linkOverrides ?? {});
  const traitsFromNode = getTraitsFromNodeFactory(options.traitOptions);
  const typeManifestVisitor = getTypeManifestVisitor({
    getImportFrom,
    getTraitsFromNode: traitsFromNode,
    traitOptions: options.traitOptions,
  });
  const anchorTraits = options.anchorTraits ?? true;

  return pipe(
    staticVisitor(() => createRenderMap(), {
      keys: ["rootNode", "programNode", "instructionNode", "accountNode", "definedTypeNode"],
    }),
    (v) =>
      extendVisitor(v, {
        visitAccount(node) {
          const typeManifest = visit(node, typeManifestVisitor);
          const fields = resolveNestedTypeNode(node.data).fields;
          const discriminatorConstants = getDiscriminatorConstants({
            discriminatorNodes: node.discriminators ?? [],
            fields,
            getImportFrom,
            prefix: node.name,
            typeManifestVisitor,
          });
          const seedsImports = new ImportMap();
          const pda = node.pda ? linkables.get([...stack.getPath(), node.pda]) : undefined;
          const pdaSeeds = pda?.seeds ?? [];
          const seeds = pdaSeeds.map((seed) => {
            if (isNode(seed, "variablePdaSeedNode")) {
              const seedManifest = visit(seed.type, typeManifestVisitor);
              seedsImports.mergeWith(seedManifest.imports);
              return { ...seed, resolvedType: resolveNestedTypeNode(seed.type), typeManifest: seedManifest };
            }
            if (isNode(seed.value, "programIdValueNode")) {
              return seed;
            }
            const seedManifest = visit(seed.type, typeManifestVisitor);
            const valueManifest = renderValueNode(seed.value, getImportFrom, true);
            seedsImports.mergeWith(valueManifest.imports);
            return {
              ...seed,
              resolvedType: resolveNestedTypeNode(seed.type),
              typeManifest: seedManifest,
              valueManifest,
            };
          });
          const hasVariableSeeds = pdaSeeds.filter(isNodeFilter("variablePdaSeedNode")).length > 0;
          const constantSeeds = seeds
            .filter(isNodeFilter("constantPdaSeedNode"))
            .filter((seed) => !isNode(seed.value, "programIdValueNode"));
          const { imports } = typeManifest;
          if (hasVariableSeeds) {
            imports.mergeWith(seedsImports);
          }
          return createRenderMap(`accounts/${snakeCase(node.name)}.rs`, {
            content: render("accountsPage.njk", {
              account: node,
              anchorTraits,
              constantSeeds,
              discriminatorConstants: discriminatorConstants.render,
              hasVariableSeeds,
              imports: imports
                .mergeWith(discriminatorConstants.imports)
                .remove(`generatedAccounts::${pascalCase(node.name)}`)
                .toString(dependencyMap),
              pda,
              program,
              seeds,
              typeManifest,
            }),
          });
        },

        visitDefinedType(node) {
          const typeManifest = visit(node, typeManifestVisitor);
          const imports = new ImportMap().mergeWithManifest(typeManifest);
          return createRenderMap(`types/${snakeCase(node.name)}.rs`, {
            content: render("definedTypesPage.njk", {
              definedType: node,
              imports: imports.remove(`generatedTypes::${pascalCase(node.name)}`).toString(dependencyMap),
              typeManifest,
            }),
          });
        },

        visitInstruction(node) {
          const imports = new ImportMap();
          const accountsAndArgsConflicts = getConflictsForInstructionAccountsAndArgs(node);
          if (accountsAndArgsConflicts.length > 0) {
            logWarn(
              `[Rust] Accounts and args of instruction [${node.name}] have the following conflicting attributes [${accountsAndArgsConflicts.join(", ")}]. Thus, the conflicting arguments will be suffixed with "_arg". You may want to rename the conflicting attributes.`
            );
          }
          const discriminatorConstants = getDiscriminatorConstants({
            discriminatorNodes: node.discriminators ?? [],
            fields: node.arguments,
            getImportFrom,
            prefix: node.name,
            typeManifestVisitor,
          });
          const instructionDataName = `${pascalCase(node.name)}InstructionData`;
          const instructionArgs = [];
          let hasArgs = false;
          let hasOptional = false;
          node.arguments.forEach((argument) => {
            // Fresh visitor per argument: its traversal state must not leak between arguments.
            const argumentVisitor = getTypeManifestVisitor({
              getImportFrom,
              getTraitsFromNode: traitsFromNode,
              nestedStruct: true,
              parentName: instructionDataName,
              traitOptions: options.traitOptions,
            });
            const manifest = visit(argument.type, argumentVisitor);
            imports.mergeWith(manifest.imports);
            const innerOptionType = isNode(argument.type, "optionTypeNode")
              ? manifest.type.slice("Option<".length, -1)
              : null;
            const hasDefaultValue = !!argument.defaultValue && isNode(argument.defaultValue, VALUE_NODES);
            let renderValue = null;
            if (hasDefaultValue) {
              const { imports: argImports, render: value } = renderValueNode(argument.defaultValue, getImportFrom);
              imports.mergeWith(argImports);
              renderValue = value;
            }
            const isOmitted = argument.defaultValueStrategy === "omitted";
            hasArgs = hasArgs || !isOmitted;
            hasOptional = hasOptional || (hasDefaultValue && !isOmitted);
            const name = accountsAndArgsConflicts.includes(argument.name) ? `${argument.name}_arg` : argument.name;
            instructionArgs.push({
              default: hasDefaultValue && isOmitted,
              innerOptionType,
              name,
              optional: hasDefaultValue && !isOmitted,
              type: manifest.type,
              value: renderValue,
            });
          });
          const struct = structTypeNodeFromInstructionArgumentNodes(node.arguments);
          const structVisitor = getTypeManifestVisitor({
            getImportFrom,
            getTraitsFromNode: traitsFromNode,
            parentName: instructionDataName,
            traitOptions: options.traitOptions,
          });
          const typeManifest = visit(struct, structVisitor);
          const dataTraits = traitsFromNode(node);
          imports.mergeWith(dataTraits.imports);
          return createRenderMap(`instructions/${snakeCase(node.name)}.rs`, {
            content: render("instructionsPage.njk", {
              dataTraits: dataTraits.render,
              discriminatorConstants: discriminatorConstants.render,
              hasArgs,
              hasOptional,
              imports: imports
                .mergeWith(discriminatorConstants.imports)
                .remove(`generatedInstructions::${pascalCase(node.name)}`)
                .toString(dependencyMap),
              instruction: node,
              instructionArgs,
              program,
              typeManifest,
            }),
          });
        },

        visitProgram(node, { self }) {
          program = node;
          let renders = mergeRenderMaps([
            ...node.accounts.map((account) => visit(account, self)),
            ...node.definedTypes.map((type) => visit(type, self)),
            ...getAllInstructionsWithSubs(node, { leavesOnly: !renderParentInstructions }).map((ix) => visit(ix, self)),
          ]);
          if (node.errors.length > 0) {
            renders = addToRenderMap(renders, `errors/${snakeCase(node.name)}.rs`, {
              content: render("errorsPage.njk", {
                errors: node.errors,
                imports: new ImportMap().toString(dependencyMap),
                program: node,
              }),
            });
          }
          program = null;
          return renders;
        },

        visitRoot(node, { self }) {
          const programsToExport = getAllPrograms(node);
          const accountsToExport = getAllAccounts(node);
          const instructionsToExport = getAllInstructionsWithSubs(node, { leavesOnly: !renderParentInstructions });
          const definedTypesToExport = getAllDefinedTypes(node);
          const hasAnythingToExport =
            programsToExport.length > 0 ||
            accountsToExport.length > 0 ||
            instructionsToExport.length > 0 ||
            definedTypesToExport.length > 0;
          const ctx = {
            accountsToExport,
            definedTypesToExport,
            hasAnythingToExport,
            instructionsToExport,
            programsToExport,
            root: node,
          };
          const renderIf = (condition, template) => (condition ? { content: render(template, ctx) } : undefined);
          return mergeRenderMaps([
            createRenderMap({
              ["accounts/mod.rs"]: renderIf(accountsToExport.length > 0, "accountsMod.njk"),
              ["errors/mod.rs"]: renderIf(programsToExport.length > 0, "errorsMod.njk"),
              ["instructions/mod.rs"]: renderIf(instructionsToExport.length > 0, "instructionsMod.njk"),
              ["mod.rs"]: { content: render("rootMod.njk", ctx) },
              ["programs.rs"]: renderIf(programsToExport.length > 0, "programsMod.njk"),
              ["shared.rs"]: renderIf(accountsToExport.length > 0, "sharedPage.njk"),
              ["types/mod.rs"]: renderIf(definedTypesToExport.length > 0, "definedTypesMod.njk"),
            }),
            ...getAllPrograms(node).map((p) => visit(p, self)),
          ]);
        },
      }),
    (v) => recordNodeStackVisitor(v, stack),
    (v) => recordLinkablesOnFirstVisitVisitor(v, linkables)
  );
}

/** Returns names shared by an instruction's accounts and arguments, each listed once. */
function getConflictsForInstructionAccountsAndArgs(instruction) {
  const allNames = [
    ...instruction.accounts.map((account) => account.name),
    ...instruction.arguments.map((argument) => argument.name),
  ];
  const seen = new Set();
  const duplicates = new Set();
  for (const name of allNames) {
    if (seen.has(name)) duplicates.add(name);
    else seen.add(name);
  }
  return [...duplicates];
}

// src/renderVisitor.ts

/**
 * Root visitor that writes the generated Rust files to `path`, optionally deleting the
 * folder first (default) and running `cargo fmt` on `options.crateFolder`.
 */
function renderVisitor(path, options = {}) {
  return rootNodeVisitor((root) => {
    if (options.deleteFolderBeforeRendering ?? true) {
      deleteDirectory(path);
    }
    visit(root, writeRenderMapVisitor(getRenderMapVisitor(options), path));
    if (options.formatCode) {
      if (options.crateFolder) {
        runFormatter(
          "cargo",
          [options.toolchain, "fmt", "--manifest-path", `${options.crateFolder}/Cargo.toml`].filter(Boolean)
        );
      } else {
        logWarn("No crate folder specified, skipping formatting.");
      }
    }
  });
}

/**
 * Runs the formatter synchronously and forwards its output to the logger. Never throws:
 * a missing binary is a warning; any other spawn failure is logged as an error.
 */
function runFormatter(cmd, args) {
  const { stdout, stderr, error } = spawnSync(cmd, args);
  if (error?.message?.includes("ENOENT")) {
    logWarn(`Could not find ${cmd}, skipping formatting.`);
    return;
  }
  // stdout/stderr are null when the process failed to spawn at all.
  const out = stdout ? stdout.toString() : "";
  const errOut = stderr ? stderr.toString() : "";
  if (out.length > 0) {
    logWarn(`(cargo-fmt) ${out}`);
  }
  if (errOut.length > 0) {
    logError(`(cargo-fmt) ${errOut}`);
  }
  if (error && out.length === 0 && errOut.length === 0) {
    logError(`(cargo-fmt) ${error}`);
  }
}

export { ImportMap, renderVisitor as default, getRenderMapVisitor, getTypeManifestVisitor, renderVisitor };