@codama/renderers-rust
Version:
Renders Rust clients for your programs
1,119 lines (1,114 loc) • 44.2 kB
JavaScript
import { CodamaError, CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, logWarn, logError, CODAMA_ERROR__UNEXPECTED_NODE_KIND } from '@codama/errors';
import { REGISTERED_TYPE_NODE_KINDS, definedTypeNode, pascalCase, snakeCase, parseDocs, resolveNestedTypeNode, isNode, remainderCountNode, fixedCountNode, prefixedCountNode, arrayTypeNode, numberTypeNode, getAllPrograms, getAllAccounts, getAllInstructionsWithSubs, getAllDefinedTypes, VALUE_NODES, structTypeNodeFromInstructionArgumentNodes, isNodeFilter, camelCase, kebabCase, titleCase, assertIsNode, bytesValueNode, numberValueNode, arrayValueNode, isScalarEnum } from '@codama/nodes';
import { createRenderMap, mergeRenderMaps, addToRenderMap, deleteDirectory, writeRenderMapVisitor } from '@codama/renderers-core';
import { pipe, mergeVisitor, extendVisitor, visit, LinkableDictionary, NodeStack, staticVisitor, recordNodeStackVisitor, recordLinkablesOnFirstVisitVisitor, rootNodeVisitor } from '@codama/visitors-core';
import { getBase64Encoder, getBase58Encoder, getBase16Encoder, getUtf8Encoder } from '@solana/codecs-strings';
import { dirname, join } from 'path';
import { fileURLToPath } from 'url';
import nunjucks from 'nunjucks';
import { spawnSync } from 'child_process';
// src/ImportMap.ts
/**
 * Default resolution table for the symbolic import prefixes used by the renderer
 * (e.g. `generatedTypes::Foo`). `ImportMap#resolveDependencyMap` rewrites any import
 * that starts with `<key>::` so that it starts with the mapped Rust path instead.
 * User-supplied dependency maps are spread over this object, so they take precedence.
 */
var DEFAULT_MODULE_MAP = {
  generated: "crate::generated",
  generatedAccounts: "crate::generated::accounts",
  generatedErrors: "crate::generated::errors",
  generatedInstructions: "crate::generated::instructions",
  // `getImportFromFactory` falls back to the "generatedPrograms" prefix for program links;
  // without an entry here such imports would be emitted unresolved. Mapped like the sibling
  // generated modules — NOTE(review): confirm the programs module path for custom layouts.
  generatedPrograms: "crate::generated::programs",
  generatedTypes: "crate::generated::types",
  hooked: "crate::hooked",
  mplEssentials: "mpl_toolbox",
  mplToolbox: "mpl_toolbox"
};
var ImportMap = class _ImportMap {
_imports = /* @__PURE__ */ new Set();
_aliases = /* @__PURE__ */ new Map();
get imports() {
return this._imports;
}
get aliases() {
return this._aliases;
}
add(imports) {
const newImports = typeof imports === "string" ? [imports] : imports;
newImports.forEach((i) => this._imports.add(i));
return this;
}
remove(imports) {
const importsToRemove = typeof imports === "string" ? [imports] : imports;
importsToRemove.forEach((i) => this._imports.delete(i));
return this;
}
mergeWith(...others) {
others.forEach((other) => {
this.add(other._imports);
other._aliases.forEach((alias, importName) => this.addAlias(importName, alias));
});
return this;
}
mergeWithManifest(manifest) {
return this.mergeWith(manifest.imports);
}
addAlias(importName, alias) {
this._aliases.set(importName, alias);
return this;
}
isEmpty() {
return this._imports.size === 0;
}
resolveDependencyMap(dependencies) {
const dependencyMap = { ...DEFAULT_MODULE_MAP, ...dependencies };
const newImportMap = new _ImportMap();
const resolveDependency = (i) => {
const dependencyKey = Object.keys(dependencyMap).find((key) => i.startsWith(`${key}::`));
if (!dependencyKey) return i;
const dependencyValue = dependencyMap[dependencyKey];
return dependencyValue + i.slice(dependencyKey.length);
};
this._imports.forEach((i) => newImportMap.add(resolveDependency(i)));
this._aliases.forEach((alias, i) => newImportMap.addAlias(resolveDependency(i), alias));
return newImportMap;
}
toString(dependencies) {
const resolvedMap = this.resolveDependencyMap(dependencies);
const importStatements = [...resolvedMap.imports].map((i) => {
const alias = resolvedMap.aliases.get(i);
if (alias) return `use ${i} as ${alias};`;
return `use ${i};`;
});
return importStatements.join("\n");
}
};
function getBytesFromBytesValueNode(node) {
switch (node.encoding) {
case "utf8":
return getUtf8Encoder().encode(node.data);
case "base16":
return getBase16Encoder().encode(node.data);
case "base58":
return getBase58Encoder().encode(node.data);
case "base64":
default:
return getBase64Encoder().encode(node.data);
}
}
function renderValueNode(value, getImportFrom, useStr = false) {
return visit(value, renderValueNodeVisitor(getImportFrom, useStr));
}
function renderValueNodeVisitor(getImportFrom, useStr = false) {
return {
visitArrayValue(node) {
const list = node.items.map((v) => visit(v, this));
return {
imports: new ImportMap().mergeWith(...list.map((c) => c.imports)),
render: `[${list.map((c) => c.render).join(", ")}]`
};
},
visitBooleanValue(node) {
return {
imports: new ImportMap(),
render: JSON.stringify(node.boolean)
};
},
visitBytesValue(node) {
const bytes = getBytesFromBytesValueNode(node);
const numbers = Array.from(bytes).map(numberValueNode);
return visit(arrayValueNode(numbers), this);
},
visitConstantValue(node) {
if (isNode(node.value, "bytesValueNode")) {
return visit(node.value, this);
}
if (isNode(node.type, "stringTypeNode") && isNode(node.value, "stringValueNode")) {
return visit(bytesValueNode(node.type.encoding, node.value.string), this);
}
if (isNode(node.type, "numberTypeNode") && isNode(node.value, "numberValueNode")) {
const numberManifest = visit(node.value, this);
const { format, endian } = node.type;
const byteFunction = endian === "le" ? "to_le_bytes" : "to_be_bytes";
numberManifest.render = `${numberManifest.render}${format}.${byteFunction}()`;
return numberManifest;
}
throw new Error("Unsupported constant value type.");
},
visitEnumValue(node) {
const imports = new ImportMap();
const enumName = pascalCase(node.enum.name);
const variantName = pascalCase(node.variant);
const importFrom = getImportFrom(node.enum);
imports.add(`${importFrom}::${enumName}`);
if (!node.value) {
return { imports, render: `${enumName}::${variantName}` };
}
const enumValue = visit(node.value, this);
const fields = enumValue.render;
return {
imports: imports.mergeWith(enumValue.imports),
render: `${enumName}::${variantName} ${fields}`
};
},
visitMapEntryValue(node) {
const mapKey = visit(node.key, this);
const mapValue = visit(node.value, this);
return {
imports: mapKey.imports.mergeWith(mapValue.imports),
render: `[${mapKey.render}, ${mapValue.render}]`
};
},
visitMapValue(node) {
const map = node.entries.map((entry) => visit(entry, this));
const imports = new ImportMap().add("std::collection::HashMap");
return {
imports: imports.mergeWith(...map.map((c) => c.imports)),
render: `HashMap::from([${map.map((c) => c.render).join(", ")}])`
};
},
visitNoneValue() {
return {
imports: new ImportMap(),
render: "None"
};
},
visitNumberValue(node) {
return {
imports: new ImportMap(),
render: node.number.toString()
};
},
visitPublicKeyValue(node) {
return {
imports: new ImportMap().add("solana_pubkey"),
render: `pubkey!("${node.publicKey}")`
};
},
visitSetValue(node) {
const set = node.items.map((v) => visit(v, this));
const imports = new ImportMap().add("std::collection::HashSet");
return {
imports: imports.mergeWith(...set.map((c) => c.imports)),
render: `HashSet::from([${set.map((c) => c.render).join(", ")}])`
};
},
visitSomeValue(node) {
const child = visit(node.value, this);
return {
...child,
render: `Some(${child.render})`
};
},
visitStringValue(node) {
return {
imports: new ImportMap(),
render: useStr ? `${JSON.stringify(node.string)}` : `String::from(${JSON.stringify(node.string)})`
};
},
visitStructFieldValue(node) {
const structValue = visit(node.value, this);
return {
imports: structValue.imports,
render: `${node.name}: ${structValue.render}`
};
},
visitStructValue(node) {
const struct = node.fields.map((field) => visit(field, this));
return {
imports: new ImportMap().mergeWith(...struct.map((c) => c.imports)),
render: `{ ${struct.map((c) => c.render).join(", ")} }`
};
},
visitTupleValue(node) {
const tuple = node.items.map((v) => visit(v, this));
return {
imports: new ImportMap().mergeWith(...tuple.map((c) => c.imports)),
render: `(${tuple.map((c) => c.render).join(", ")})`
};
}
};
}
// src/utils/discriminatorConstant.ts
function mergeFragments(fragments, merge) {
const imports = fragments.reduce((acc, frag) => acc.mergeWith(frag.imports), new ImportMap());
const render2 = merge(fragments.map((frag) => frag.render));
return { imports, render: render2 };
}
function getDiscriminatorConstants(scope) {
const fragments = scope.discriminatorNodes.map((node) => getDiscriminatorConstant(node, scope)).filter(Boolean);
return mergeFragments(fragments, (r) => r.join("\n\n"));
}
function getDiscriminatorConstant(discriminatorNode, scope) {
switch (discriminatorNode.kind) {
case "constantDiscriminatorNode":
return getConstantDiscriminatorConstant(discriminatorNode, scope);
case "fieldDiscriminatorNode":
return getFieldDiscriminatorConstant(discriminatorNode, scope);
default:
return null;
}
}
function getConstantDiscriminatorConstant(discriminatorNode, scope) {
const { discriminatorNodes, getImportFrom, prefix, typeManifestVisitor } = scope;
const index = discriminatorNodes.filter(isNodeFilter("constantDiscriminatorNode")).indexOf(discriminatorNode);
const suffix = index <= 0 ? "" : `_${index + 1}`;
const name = camelCase(`${prefix}_discriminator${suffix}`);
const typeManifest = visit(discriminatorNode.constant.type, typeManifestVisitor);
const value = renderValueNode(discriminatorNode.constant.value, getImportFrom);
return getConstant(name, typeManifest, value);
}
function getFieldDiscriminatorConstant(discriminatorNode, scope) {
const { fields, prefix, getImportFrom, typeManifestVisitor } = scope;
const field = fields.find((f) => f.name === discriminatorNode.name);
if (!field || !field.defaultValue || !isNode(field.defaultValue, VALUE_NODES)) {
return null;
}
const name = camelCase(`${prefix}_${discriminatorNode.name}`);
const typeManifest = visit(field.type, typeManifestVisitor);
const value = renderValueNode(field.defaultValue, getImportFrom);
return getConstant(name, typeManifest, value);
}
function getConstant(name, typeManifest, value) {
const type = { imports: typeManifest.imports, render: typeManifest.type };
return mergeFragments([type, value], ([t, v]) => `pub const ${snakeCase(name).toUpperCase()}: ${t} = ${v};`);
}
function getImportFromFactory(overrides) {
const linkOverrides = {
accounts: overrides.accounts ?? {},
definedTypes: overrides.definedTypes ?? {},
instructions: overrides.instructions ?? {},
pdas: overrides.pdas ?? {},
programs: overrides.programs ?? {},
resolvers: overrides.resolvers ?? {}
};
return (node) => {
const kind = node.kind;
switch (kind) {
case "accountLinkNode":
return linkOverrides.accounts[node.name] ?? "generatedAccounts";
case "definedTypeLinkNode":
return linkOverrides.definedTypes[node.name] ?? "generatedTypes";
case "instructionLinkNode":
return linkOverrides.instructions[node.name] ?? "generatedInstructions";
case "pdaLinkNode":
return linkOverrides.pdas[node.name] ?? "generatedAccounts";
case "programLinkNode":
return linkOverrides.programs[node.name] ?? "generatedPrograms";
case "resolverValueNode":
return linkOverrides.resolvers[node.name] ?? "hooked";
default:
throw new CodamaError(CODAMA_ERROR__UNEXPECTED_NODE_KIND, {
expectedKinds: [
"AccountLinkNode",
"DefinedTypeLinkNode",
"InstructionLinkNode",
"PdaLinkNode",
"ProgramLinkNode",
"resolverValueNode"
],
kind,
node
});
}
};
}
/**
 * Turns documentation lines into a Rust `///` doc comment, one comment line per entry,
 * each terminated by a newline. Returns "" for an empty list.
 */
function rustDocblock(docs) {
  if (docs.length <= 0) {
    return "";
  }
  return docs.map((line) => `/// ${line}\n`).join("");
}
/**
 * Renders a nunjucks template from the `templates` directory located next to this module.
 * Each call configures an environment (autoescape off, trimBlocks on, overridable via
 * `options`) and registers the case-conversion, `rustDocblock` and `hasTrait` filters.
 */
var render = (template, context, options) => {
  const templatesDir = join(dirname(fileURLToPath(import.meta.url)), "templates");
  const env = nunjucks.configure(templatesDir, { autoescape: false, trimBlocks: true, ...options });
  const filters = { pascalCase, camelCase, snakeCase, kebabCase, titleCase, rustDocblock };
  for (const [filterName, filter] of Object.entries(filters)) {
    env.addFilter(filterName, filter);
  }
  // `hasTrait` is true when the rendered traits string contains any of the given names.
  env.addFilter(
    "hasTrait",
    (traits, ...traitNames) => typeof traits === "string" && traitNames.some((traitName) => traits.includes(traitName))
  );
  return env.render(template, context);
};
/**
 * Default trait configuration; `getTraitsFromNode` and `getSerdeFieldAttribute`
 * shallow-merge user options over it.
 */
var DEFAULT_TRAIT_OPTIONS = {
  // Derived by every struct and enum (aliases never get derives).
  baseDefaults: [
    "borsh::BorshSerialize", "borsh::BorshDeserialize",
    "serde::Serialize", "serde::Deserialize",
    "Clone", "Debug", "Eq", "PartialEq"
  ],
  // Extra derives for non-scalar enums.
  dataEnumDefaults: [],
  // Traits listed here are emitted behind `#[cfg_attr(feature = "<key>", ...)]`.
  featureFlags: { serde: ["serde::Serialize", "serde::Deserialize"] },
  // Per-node replacement trait lists, keyed by node name.
  overrides: {},
  // Extra derives for scalar (fieldless) enums.
  scalarEnumDefaults: ["Copy", "PartialOrd", "Hash", "num_derive::FromPrimitive"],
  // Extra derives for structs, accounts and instruction data.
  structDefaults: [],
  // When false, `a::b::Trait` is imported and derived by its short name.
  useFullyQualifiedName: false
};
function getTraitsFromNodeFactory(options = {}) {
return (node) => getTraitsFromNode(node, options);
}
/**
 * Renders the `#[derive(...)]` and `#[cfg_attr(feature = ..., derive(...))]` lines for an
 * account, defined type or instruction, plus the imports those derives need.
 *
 * Traits come from the override whose camelCased key equals the node name when one
 * exists, otherwise from the defaults for the node's kind (`getDefaultTraits`). Traits
 * listed under a feature flag go into a feature-gated `cfg_attr` line instead of the plain
 * derive. Type aliases get no attributes.
 */
function getTraitsFromNode(node, userOptions = {}) {
  assertIsNode(node, ["accountNode", "definedTypeNode", "instructionNode"]);
  const options = { ...DEFAULT_TRAIT_OPTIONS, ...userOptions };
  const nodeType = getNodeType(node);
  if (nodeType === "alias") {
    return { imports: new ImportMap(), render: "" };
  }
  const overridesByName = Object.fromEntries(
    Object.entries(options.overrides).map(([name, traits]) => [camelCase(name), traits])
  );
  const override = overridesByName[node.name];
  const requestedTraits = override !== void 0 ? override : getDefaultTraits(nodeType, options);
  const [plainTraits, traitsByFeature] = partitionTraitsInFeatures(requestedTraits, options.featureFlags);
  const imports = new ImportMap();
  // Unless fully qualified names are requested, `a::b::Trait` is imported and derived as `Trait`.
  const derivable = options.useFullyQualifiedName ? plainTraits : extractFullyQualifiedNames(plainTraits, imports);
  const deriveLine = derivable.length > 0 ? `#[derive(${derivable.join(", ")})]\n` : "";
  const featureLines = Object.entries(traitsByFeature).map(
    ([feature, traits]) => `#[cfg_attr(feature = "${feature}", derive(${traits.join(", ")}))]\n`
  );
  return { imports, render: deriveLine + featureLines.join("") };
}
function getNodeType(node) {
if (isNode(node, ["accountNode", "instructionNode"])) return "struct";
if (isNode(node.type, "structTypeNode")) return "struct";
if (isNode(node.type, "enumTypeNode")) {
return isScalarEnum(node.type) ? "scalarEnum" : "dataEnum";
}
return "alias";
}
/**
 * Returns `baseDefaults` followed by the kind-specific defaults for "dataEnum",
 * "scalarEnum" or "struct"; any other node type yields `undefined`.
 */
function getDefaultTraits(nodeType, options) {
  let kindDefaults;
  if (nodeType === "dataEnum") {
    kindDefaults = options.dataEnumDefaults;
  } else if (nodeType === "scalarEnum") {
    kindDefaults = options.scalarEnumDefaults;
  } else if (nodeType === "struct") {
    kindDefaults = options.structDefaults;
  } else {
    return undefined;
  }
  return [...options.baseDefaults, ...kindDefaults];
}
/**
 * Splits `traits` into `[unfeatured, byFeature]`: traits that no feature flag lists stay
 * in the `unfeatured` array (in order); the rest are grouped under their feature name.
 * A trait listed under several features belongs to the first feature that lists it.
 */
function partitionTraitsInFeatures(traits, featureFlags) {
  const featureOfTrait = {};
  for (const [feature, flaggedTraits] of Object.entries(featureFlags)) {
    for (const trait of flaggedTraits) {
      if (!featureOfTrait[trait]) {
        featureOfTrait[trait] = feature;
      }
    }
  }
  const unfeatured = [];
  const byFeature = {};
  for (const trait of traits) {
    const feature = featureOfTrait[trait];
    if (feature === void 0) {
      unfeatured.push(trait);
      continue;
    }
    if (!byFeature[feature]) {
      byFeature[feature] = [];
    }
    byFeature[feature].push(trait);
  }
  return [unfeatured, byFeature];
}
function extractFullyQualifiedNames(traits, imports) {
return traits.map((trait) => {
const index = trait.lastIndexOf("::");
if (index === -1) return trait;
imports.add(trait);
return trait.slice(index + 2);
});
}
/**
 * Returns the serde `with = "<serdeWith>"` field attribute line for a field of `node`,
 * or "" when `node` is an alias or its traits include neither serde Serialize nor
 * Deserialize (qualified or not). When those serde traits are feature-gated, the
 * attribute is wrapped in a `cfg_attr` for the same feature.
 */
function getSerdeFieldAttribute(serdeWith, node, userOptions = {}) {
  assertIsNode(node, ["accountNode", "definedTypeNode", "instructionNode"]);
  const options = { ...DEFAULT_TRAIT_OPTIONS, ...userOptions };
  const nodeType = getNodeType(node);
  if (nodeType === "alias") {
    return "";
  }
  const overridesByName = Object.fromEntries(
    Object.entries(options.overrides).map(([name, traits]) => [camelCase(name), traits])
  );
  const override = overridesByName[node.name];
  const allTraits = override !== void 0 ? override : getDefaultTraits(nodeType, options);
  const isSerialize = (t) => t === "serde::Serialize" || t === "Serialize";
  const isDeserialize = (t) => t === "serde::Deserialize" || t === "Deserialize";
  if (!allTraits.some(isSerialize) && !allTraits.some(isDeserialize)) {
    return "";
  }
  const [, traitsByFeature] = partitionTraitsInFeatures(allTraits, options.featureFlags);
  const serdeFeature = Object.keys(traitsByFeature).find(
    (feature) => traitsByFeature[feature].some((t) => isSerialize(t) || isDeserialize(t))
  );
  return serdeFeature
    ? `#[cfg_attr(feature = "${serdeFeature}", serde(with = "${serdeWith}"))]\n`
    : `#[serde(with = "${serdeWith}")]\n`;
}
// src/getTypeManifestVisitor.ts
/**
 * Builds the visitor that turns Codama type nodes into Rust type manifests.
 *
 * Each visit returns `{ imports, nestedStructs, type }`:
 * - `imports`: an ImportMap of the `use` paths the rendered Rust needs;
 * - `nestedStructs`: extra `pub struct` definitions produced for struct types found
 *   inside fields while `nestedStruct` mode is on;
 * - `type`: the rendered Rust type expression or item definition.
 *
 * The visitor keeps closure state (`parentName`, `nestedStruct`, `inlineStruct`,
 * `parentSize`, `parentNode`) that parents set before visiting children and reset after.
 *
 * @param options.getImportFrom - Resolves the import prefix of linked nodes.
 * @param options.getTraitsFromNode - Returns `{ imports, render }` derive attributes for a node.
 * @param options.traitOptions - Trait options forwarded to `getSerdeFieldAttribute`.
 * @param options.parentName - Optional initial name for top-level structs/enums.
 * @param options.nestedStruct - Optional; when true, struct types become separately named structs.
 */
function getTypeManifestVisitor(options) {
  const { getImportFrom, getTraitsFromNode: getTraitsFromNode2, traitOptions } = options;
  // Name of the struct/enum being rendered; children derive their own names from it.
  let parentName = options.parentName ?? null;
  // When true, struct types are hoisted into named `pub struct`s listed in `nestedStructs`.
  let nestedStruct = options.nestedStruct ?? false;
  // When true (enum struct variants), the struct renders as a bare `{ ... }` with non-`pub` fields.
  let inlineStruct = false;
  // Size from an enclosing fixedSizeTypeNode (number) or sizePrefixTypeNode (resolved prefix type node).
  let parentSize = null;
  // Account/defined type currently being rendered; used to choose serde field attributes.
  let parentNode = null;
  return pipe(
    mergeVisitor(
      () => ({ imports: new ImportMap(), nestedStructs: [], type: "" }),
      (_, values) => ({
        ...mergeManifests(values),
        type: values.map((v) => v.type).join("\n")
      }),
      { keys: [...REGISTERED_TYPE_NODE_KINDS, "definedTypeLinkNode", "definedTypeNode", "accountNode"] }
    ),
    (v) => extendVisitor(v, {
      visitAccount(account, { self }) {
        parentName = pascalCase(account.name);
        parentNode = account;
        const manifest = visit(account.data, self);
        const traits = getTraitsFromNode2(account);
        manifest.imports.mergeWith(traits.imports);
        parentName = null;
        parentNode = null;
        return {
          ...manifest,
          type: traits.render + manifest.type
        };
      },
      visitArrayType(arrayType, { self }) {
        const childManifest = visit(arrayType.item, self);
        if (isNode(arrayType.count, "fixedCountNode")) {
          return {
            ...childManifest,
            type: `[${childManifest.type}; ${arrayType.count.value}]`
          };
        }
        if (isNode(arrayType.count, "remainderCountNode")) {
          childManifest.imports.add("kaigan::types::RemainderVec");
          return {
            ...childManifest,
            type: `RemainderVec<${childManifest.type}>`
          };
        }
        // Prefixed count: a little-endian u32 prefix renders as a plain `Vec`; other
        // little-endian prefixes use the kaigan / solana_short_vec wrapper types.
        const prefix = resolveNestedTypeNode(arrayType.count.prefix);
        if (prefix.endian === "le") {
          switch (prefix.format) {
            case "u32":
              return {
                ...childManifest,
                type: `Vec<${childManifest.type}>`
              };
            case "u8":
            case "u16":
            case "u64": {
              const prefixFormat = prefix.format.toUpperCase();
              childManifest.imports.add(`kaigan::types::${prefixFormat}PrefixVec`);
              return {
                ...childManifest,
                type: `${prefixFormat}PrefixVec<${childManifest.type}>`
              };
            }
            case "shortU16": {
              childManifest.imports.add("solana_short_vec::ShortVec");
              return {
                ...childManifest,
                type: `ShortVec<${childManifest.type}>`
              };
            }
            default:
              throw new Error(`Array prefix not supported: ${prefix.format}`);
          }
        }
        throw new Error("Array size not supported by Borsh");
      },
      visitBooleanType(booleanType) {
        const resolvedSize = resolveNestedTypeNode(booleanType.size);
        if (resolvedSize.format === "u8" && resolvedSize.endian === "le") {
          return {
            imports: new ImportMap(),
            nestedStructs: [],
            type: "bool"
          };
        }
        throw new Error("Bool size not supported by Borsh");
      },
      visitBytesType(_bytesType, { self }) {
        // Bytes render as a u8 array whose count comes from the enclosing size wrapper, if any.
        let arraySize = remainderCountNode();
        if (typeof parentSize === "number") {
          arraySize = fixedCountNode(parentSize);
        } else if (parentSize && typeof parentSize === "object") {
          arraySize = prefixedCountNode(parentSize);
        }
        const arrayType = arrayTypeNode(numberTypeNode("u8"), arraySize);
        return visit(arrayType, self);
      },
      visitDefinedType(definedType, { self }) {
        parentName = pascalCase(definedType.name);
        parentNode = definedType;
        const manifest = visit(definedType.type, self);
        const traits = getTraitsFromNode2(definedType);
        manifest.imports.mergeWith(traits.imports);
        parentName = null;
        parentNode = null;
        // Enums and structs render their own item header; anything else becomes a type alias.
        const renderedType = isNode(definedType.type, ["enumTypeNode", "structTypeNode"]) ? manifest.type : `pub type ${pascalCase(definedType.name)} = ${manifest.type};`;
        return { ...manifest, type: `${traits.render}${renderedType}` };
      },
      visitDefinedTypeLink(node) {
        const pascalCaseDefinedType = pascalCase(node.name);
        const importFrom = getImportFrom(node);
        return {
          imports: new ImportMap().add(`${importFrom}::${pascalCaseDefinedType}`),
          nestedStructs: [],
          type: pascalCaseDefinedType
        };
      },
      visitEnumEmptyVariantType(enumEmptyVariantType) {
        const name = pascalCase(enumEmptyVariantType.name);
        return {
          imports: new ImportMap(),
          nestedStructs: [],
          type: `${name},`
        };
      },
      visitEnumStructVariantType(enumStructVariantType, { self }) {
        const name = pascalCase(enumStructVariantType.name);
        const originalParentName = parentName;
        const originalInlineStruct = inlineStruct;
        if (!originalParentName) {
          throw new Error("Enum struct variant type must have a parent name.");
        }
        inlineStruct = true;
        parentName = pascalCase(originalParentName) + name;
        const typeManifest = visit(enumStructVariantType.struct, self);
        // Restore the saved state, mirroring visitStructFieldType.
        inlineStruct = originalInlineStruct;
        parentName = originalParentName;
        return {
          ...typeManifest,
          type: `${name} ${typeManifest.type},`
        };
      },
      visitEnumTupleVariantType(enumTupleVariantType, { self }) {
        const name = pascalCase(enumTupleVariantType.name);
        const originalParentName = parentName;
        if (!originalParentName) {
          throw new Error("Enum tuple variant type must have a parent name.");
        }
        parentName = pascalCase(originalParentName) + name;
        const childManifest = visit(enumTupleVariantType.tuple, self);
        parentName = originalParentName;
        // Single-Pubkey (or Vec<Pubkey>) tuple variants get a serde_with DisplayFromStr attribute.
        let derive = "";
        if (parentNode && childManifest.type === "(Pubkey)") {
          derive = getSerdeFieldAttribute(
            "serde_with::As::<serde_with::DisplayFromStr>",
            parentNode,
            traitOptions
          );
        } else if (parentNode && childManifest.type === "(Vec<Pubkey>)") {
          derive = getSerdeFieldAttribute(
            "serde_with::As::<Vec<serde_with::DisplayFromStr>>",
            parentNode,
            traitOptions
          );
        }
        return {
          ...childManifest,
          type: `${derive}${name}${childManifest.type},`
        };
      },
      visitEnumType(enumType, { self }) {
        const originalParentName = parentName;
        if (!originalParentName) {
          throw new Error("Enum type must have a parent name.");
        }
        const variants = enumType.variants.map((variant) => visit(variant, self));
        const variantNames = variants.map((variant) => variant.type).join("\n");
        const mergedManifest = mergeManifests(variants);
        return {
          ...mergedManifest,
          type: `pub enum ${pascalCase(originalParentName)} {\n${variantNames}\n}`
        };
      },
      visitFixedSizeType(fixedSizeType, { self }) {
        parentSize = fixedSizeType.size;
        const manifest = visit(fixedSizeType.type, self);
        parentSize = null;
        return manifest;
      },
      visitMapType(mapType, { self }) {
        const key = visit(mapType.key, self);
        const value = visit(mapType.value, self);
        const mergedManifest = mergeManifests([key, value]);
        mergedManifest.imports.add("std::collections::HashMap");
        return {
          ...mergedManifest,
          type: `HashMap<${key.type}, ${value.type}>`
        };
      },
      visitNumberType(numberType) {
        if (numberType.endian !== "le") {
          throw new Error("Number endianness not supported by Borsh");
        }
        if (numberType.format === "shortU16") {
          return {
            imports: new ImportMap().add("solana_short_vec::ShortU16"),
            nestedStructs: [],
            type: "ShortU16"
          };
        }
        return {
          imports: new ImportMap(),
          nestedStructs: [],
          type: numberType.format
        };
      },
      visitOptionType(optionType, { self }) {
        const childManifest = visit(optionType.item, self);
        const optionPrefix = resolveNestedTypeNode(optionType.prefix);
        if (optionPrefix.format === "u8" && optionPrefix.endian === "le") {
          return {
            ...childManifest,
            type: `Option<${childManifest.type}>`
          };
        }
        throw new Error("Option size not supported by Borsh");
      },
      visitPublicKeyType() {
        return {
          imports: new ImportMap().add("solana_pubkey::Pubkey"),
          nestedStructs: [],
          type: "Pubkey"
        };
      },
      visitRemainderOptionType(node) {
        throw new CodamaError(CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, { kind: node.kind, node });
      },
      visitSetType(setType, { self }) {
        const childManifest = visit(setType.item, self);
        childManifest.imports.add("std::collections::HashSet");
        return {
          ...childManifest,
          type: `HashSet<${childManifest.type}>`
        };
      },
      visitSizePrefixType(sizePrefixType, { self }) {
        parentSize = resolveNestedTypeNode(sizePrefixType.prefix);
        const manifest = visit(sizePrefixType.type, self);
        parentSize = null;
        return manifest;
      },
      visitStringType() {
        if (!parentSize) {
          return {
            imports: new ImportMap().add(`kaigan::types::RemainderStr`),
            nestedStructs: [],
            type: `RemainderStr`
          };
        }
        if (typeof parentSize === "number") {
          return {
            imports: new ImportMap(),
            nestedStructs: [],
            type: `[u8; ${parentSize}]`
          };
        }
        if (isNode(parentSize, "numberTypeNode") && parentSize.endian === "le") {
          switch (parentSize.format) {
            case "u32":
              return {
                imports: new ImportMap(),
                nestedStructs: [],
                type: "String"
              };
            case "u8":
            case "u16":
            case "u64": {
              const prefix = parentSize.format.toUpperCase();
              return {
                imports: new ImportMap().add(`kaigan::types::${prefix}PrefixString`),
                nestedStructs: [],
                type: `${prefix}PrefixString`
              };
            }
            default:
              throw new Error(`String size not supported: ${parentSize.format}`);
          }
        }
        throw new Error("String size not supported by Borsh");
      },
      visitStructFieldType(structFieldType, { self }) {
        const originalParentName = parentName;
        const originalInlineStruct = inlineStruct;
        const originalNestedStruct = nestedStruct;
        if (!originalParentName) {
          throw new Error("Struct field type must have a parent name.");
        }
        // Any struct type found inside this field is hoisted as `<Parent><Field>`.
        parentName = pascalCase(originalParentName) + pascalCase(structFieldType.name);
        nestedStruct = true;
        inlineStruct = false;
        const fieldManifest = visit(structFieldType.type, self);
        parentName = originalParentName;
        inlineStruct = originalInlineStruct;
        nestedStruct = originalNestedStruct;
        const fieldName = snakeCase(structFieldType.name);
        const docblock = rustDocblock(parseDocs(structFieldType.docs));
        const resolvedNestedType = resolveNestedTypeNode(structFieldType.type);
        let derive = "";
        if (parentNode) {
          if (fieldManifest.type === "Pubkey") {
            derive = getSerdeFieldAttribute(
              "serde_with::As::<serde_with::DisplayFromStr>",
              parentNode,
              traitOptions
            );
          } else if (fieldManifest.type === "Vec<Pubkey>") {
            derive = getSerdeFieldAttribute(
              "serde_with::As::<Vec<serde_with::DisplayFromStr>>",
              parentNode,
              traitOptions
            );
          } else if (isNode(resolvedNestedType, "arrayTypeNode") && isNode(resolvedNestedType.count, "fixedCountNode") && resolvedNestedType.count.value > 32) {
            // NOTE(review): the > 32 thresholds presumably track serde's built-in array support limit.
            derive = getSerdeFieldAttribute("serde_big_array::BigArray", parentNode, traitOptions);
          } else if (isNode(resolvedNestedType, ["bytesTypeNode", "stringTypeNode"]) && isNode(structFieldType.type, "fixedSizeTypeNode") && structFieldType.type.size > 32) {
            derive = getSerdeFieldAttribute(
              "serde_with::As::<serde_with::Bytes>",
              parentNode,
              traitOptions
            );
          }
        }
        return {
          ...fieldManifest,
          type: inlineStruct ? `${docblock}${derive}${fieldName}: ${fieldManifest.type},` : `${docblock}${derive}pub ${fieldName}: ${fieldManifest.type},`
        };
      },
      visitStructType(structType, { self }) {
        const originalParentName = parentName;
        if (!originalParentName) {
          throw new Error("Struct type must have a parent name.");
        }
        const fields = structType.fields.map((field) => visit(field, self));
        const fieldTypes = fields.map((field) => field.type).join("\n");
        const mergedManifest = mergeManifests(fields);
        if (nestedStruct) {
          // Hoist the struct into its own named definition and reference it by name.
          const nestedTraits = getTraitsFromNode2(
            definedTypeNode({ name: originalParentName, type: structType })
          );
          mergedManifest.imports.mergeWith(nestedTraits.imports);
          return {
            ...mergedManifest,
            nestedStructs: [
              ...mergedManifest.nestedStructs,
              `${nestedTraits.render}pub struct ${pascalCase(originalParentName)} {\n${fieldTypes}\n}`
            ],
            type: pascalCase(originalParentName)
          };
        }
        if (inlineStruct) {
          return { ...mergedManifest, type: `{\n${fieldTypes}\n}` };
        }
        return {
          ...mergedManifest,
          type: `pub struct ${pascalCase(originalParentName)} {\n${fieldTypes}\n}`
        };
      },
      visitTupleType(tupleType, { self }) {
        const items = tupleType.items.map((item) => visit(item, self));
        const mergedManifest = mergeManifests(items);
        return {
          ...mergedManifest,
          type: `(${items.map((item) => item.type).join(", ")})`
        };
      },
      visitZeroableOptionType(node) {
        throw new CodamaError(CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, { kind: node.kind, node });
      }
    })
  );
}
function mergeManifests(manifests) {
return {
imports: new ImportMap().mergeWith(...manifests.map((td) => td.imports)),
nestedStructs: manifests.flatMap((m) => m.nestedStructs)
};
}
// src/getRenderMapVisitor.ts
function getRenderMapVisitor(options = {}) {
const linkables = new LinkableDictionary();
const stack = new NodeStack();
let program = null;
const renderParentInstructions = options.renderParentInstructions ?? false;
const dependencyMap = options.dependencyMap ?? {};
const getImportFrom = getImportFromFactory(options.linkOverrides ?? {});
const getTraitsFromNode2 = getTraitsFromNodeFactory(options.traitOptions);
const typeManifestVisitor = getTypeManifestVisitor({
getImportFrom,
getTraitsFromNode: getTraitsFromNode2,
traitOptions: options.traitOptions
});
const anchorTraits = options.anchorTraits ?? true;
return pipe(
staticVisitor(() => createRenderMap(), {
keys: ["rootNode", "programNode", "instructionNode", "accountNode", "definedTypeNode"]
}),
(v) => extendVisitor(v, {
visitAccount(node) {
const typeManifest = visit(node, typeManifestVisitor);
const fields = resolveNestedTypeNode(node.data).fields;
const discriminatorConstants = getDiscriminatorConstants({
discriminatorNodes: node.discriminators ?? [],
fields,
getImportFrom,
prefix: node.name,
typeManifestVisitor
});
const seedsImports = new ImportMap();
const pda = node.pda ? linkables.get([...stack.getPath(), node.pda]) : void 0;
const pdaSeeds = pda?.seeds ?? [];
const seeds = pdaSeeds.map((seed) => {
if (isNode(seed, "variablePdaSeedNode")) {
const seedManifest2 = visit(seed.type, typeManifestVisitor);
seedsImports.mergeWith(seedManifest2.imports);
const resolvedType2 = resolveNestedTypeNode(seed.type);
return { ...seed, resolvedType: resolvedType2, typeManifest: seedManifest2 };
}
if (isNode(seed.value, "programIdValueNode")) {
return seed;
}
const seedManifest = visit(seed.type, typeManifestVisitor);
const valueManifest = renderValueNode(seed.value, getImportFrom, true);
seedsImports.mergeWith(valueManifest.imports);
const resolvedType = resolveNestedTypeNode(seed.type);
return { ...seed, resolvedType, typeManifest: seedManifest, valueManifest };
});
const hasVariableSeeds = pdaSeeds.filter(isNodeFilter("variablePdaSeedNode")).length > 0;
const constantSeeds = seeds.filter(isNodeFilter("constantPdaSeedNode")).filter((seed) => !isNode(seed.value, "programIdValueNode"));
const { imports } = typeManifest;
if (hasVariableSeeds) {
imports.mergeWith(seedsImports);
}
return createRenderMap(`accounts/${snakeCase(node.name)}.rs`, {
content: render("accountsPage.njk", {
account: node,
anchorTraits,
constantSeeds,
discriminatorConstants: discriminatorConstants.render,
hasVariableSeeds,
imports: imports.mergeWith(discriminatorConstants.imports).remove(`generatedAccounts::${pascalCase(node.name)}`).toString(dependencyMap),
pda,
program,
seeds,
typeManifest
})
});
},
visitDefinedType(node) {
const typeManifest = visit(node, typeManifestVisitor);
const imports = new ImportMap().mergeWithManifest(typeManifest);
return createRenderMap(`types/${snakeCase(node.name)}.rs`, {
content: render("definedTypesPage.njk", {
definedType: node,
imports: imports.remove(`generatedTypes::${pascalCase(node.name)}`).toString(dependencyMap),
typeManifest
})
});
},
// Renders `instructions/<snake_name>.rs` for one instruction node using "instructionsPage.njk".
// Collects per-argument rendering data (type, default value, optional/default flags) and the
// instruction data struct manifest, then removes the instruction's own
// `generatedInstructions::<PascalName>` import so the file does not import itself.
visitInstruction(node) {
const imports = new ImportMap();
// Names shared between accounts and arguments: the arguments get an `_arg` suffix below.
const accountsAndArgsConflicts = getConflictsForInstructionAccountsAndArgs(node);
if (accountsAndArgsConflicts.length > 0) {
logWarn(
`[Rust] Accounts and args of instruction [${node.name}] have the following conflicting attributes [${accountsAndArgsConflicts.join(", ")}]. Thus, the conflicting arguments will be suffixed with "_arg". You may want to rename the conflicting attributes.`
);
}
const discriminatorConstants = getDiscriminatorConstants({
discriminatorNodes: node.discriminators ?? [],
fields: node.arguments,
getImportFrom,
prefix: node.name,
typeManifestVisitor
});
// Per-argument data passed to the template, plus two summary flags.
const instructionArgs = [];
let hasArgs = false;
let hasOptional = false;
node.arguments.forEach((argument) => {
// A fresh visitor is built per argument with `nestedStruct: true` and the
// `<PascalName>InstructionData` parent name. NOTE(review): presumably this keeps nested
// type naming independent per argument; confirm before hoisting it out of the loop.
const argumentVisitor = getTypeManifestVisitor({
getImportFrom,
getTraitsFromNode: getTraitsFromNode2,
nestedStruct: true,
parentName: `${pascalCase(node.name)}InstructionData`
});
const manifest = visit(argument.type, argumentVisitor);
imports.mergeWith(manifest.imports);
// For option types, strip the leading "Option<" and the trailing ">" to get the inner type.
// This assumes the rendered type is spelled exactly `Option<...>`.
const innerOptionType = isNode(argument.type, "optionTypeNode") ? manifest.type.slice("Option<".length, -1) : null;
// Only literal value nodes are rendered as defaults; other default kinds are ignored here.
const hasDefaultValue = !!argument.defaultValue && isNode(argument.defaultValue, VALUE_NODES);
let renderValue = null;
if (hasDefaultValue) {
const { imports: argImports, render: value } = renderValueNode(
argument.defaultValue,
getImportFrom
);
imports.mergeWith(argImports);
renderValue = value;
}
// hasArgs: at least one argument is not using the "omitted" default strategy.
hasArgs = hasArgs || argument.defaultValueStrategy !== "omitted";
// hasOptional: at least one argument has a value default and is not "omitted".
hasOptional = hasOptional || hasDefaultValue && argument.defaultValueStrategy !== "omitted";
const name = accountsAndArgsConflicts.includes(argument.name) ? `${argument.name}_arg` : argument.name;
// `default` marks value-defaulted arguments using the "omitted" strategy; `optional` marks
// value-defaulted arguments that are not omitted. How each flag is rendered is decided by
// instructionsPage.njk (not visible here).
instructionArgs.push({
default: hasDefaultValue && argument.defaultValueStrategy === "omitted",
innerOptionType,
name,
optional: hasDefaultValue && argument.defaultValueStrategy !== "omitted",
type: manifest.type,
value: renderValue
});
});
// Manifest for the whole `<PascalName>InstructionData` struct built from all arguments.
const struct = structTypeNodeFromInstructionArgumentNodes(node.arguments);
const structVisitor = getTypeManifestVisitor({
getImportFrom,
getTraitsFromNode: getTraitsFromNode2,
parentName: `${pascalCase(node.name)}InstructionData`
});
const typeManifest = visit(struct, structVisitor);
const dataTraits = getTraitsFromNode2(node);
imports.mergeWith(dataTraits.imports);
return createRenderMap(`instructions/${snakeCase(node.name)}.rs`, {
content: render("instructionsPage.njk", {
dataTraits: dataTraits.render,
discriminatorConstants: discriminatorConstants.render,
hasArgs,
hasOptional,
imports: imports.mergeWith(discriminatorConstants.imports).remove(`generatedInstructions::${pascalCase(node.name)}`).toString(dependencyMap),
instruction: node,
instructionArgs,
program,
typeManifest
})
});
},
visitProgram(node, { self }) {
program = node;
let renders = mergeRenderMaps([
...node.accounts.map((account) => visit(account, self)),
...node.definedTypes.map((type) => visit(type, self)),
...getAllInstructionsWithSubs(node, {
leavesOnly: !renderParentInstructions
}).map((ix) => visit(ix, self))
]);
if (node.errors.length > 0) {
renders = addToRenderMap(renders, `errors/${snakeCase(node.name)}.rs`, {
content: render("errorsPage.njk", {
errors: node.errors,
imports: new ImportMap().toString(dependencyMap),
program: node
})
});
}
program = null;
return renders;
},
visitRoot(node, { self }) {
const programsToExport = getAllPrograms(node);
const accountsToExport = getAllAccounts(node);
const instructionsToExport = getAllInstructionsWithSubs(node, {
leavesOnly: !renderParentInstructions
});
const definedTypesToExport = getAllDefinedTypes(node);
const hasAnythingToExport = programsToExport.length > 0 || accountsToExport.length > 0 || instructionsToExport.length > 0 || definedTypesToExport.length > 0;
const ctx = {
accountsToExport,
definedTypesToExport,
hasAnythingToExport,
instructionsToExport,
programsToExport,
root: node
};
return mergeRenderMaps([
createRenderMap({
["accounts/mod.rs"]: accountsToExport.length > 0 ? { content: render("accountsMod.njk", ctx) } : void 0,
["errors/mod.rs"]: programsToExport.length > 0 ? { content: render("errorsMod.njk", ctx) } : void 0,
["instructions/mod.rs"]: instructionsToExport.length > 0 ? { content: render("instructionsMod.njk", ctx) } : void 0,
["mod.rs"]: { content: render("rootMod.njk", ctx) },
["programs.rs"]: programsToExport.length > 0 ? { content: render("programsMod.njk", ctx) } : void 0,
["shared.rs"]: accountsToExport.length > 0 ? { content: render("sharedPage.njk", ctx) } : void 0,
["types/mod.rs"]: definedTypesToExport.length > 0 ? { content: render("definedTypesMod.njk", ctx) } : void 0
}),
...getAllPrograms(node).map((p) => visit(p, self))
]);
}
}),
(v) => recordNodeStackVisitor(v, stack),
(v) => recordLinkablesOnFirstVisitVisitor(v, linkables)
);
}
/**
 * Collects the names that occur more than once across an instruction's accounts and
 * arguments (accounts are scanned first, then arguments).
 *
 * Each repeated name is reported once, in the order in which its second occurrence is
 * encountered.
 *
 * @param instruction - Node exposing `accounts` and `arguments` arrays whose items have a `name`.
 * @returns The de-duplicated list of repeated names (empty when there are none).
 */
function getConflictsForInstructionAccountsAndArgs(instruction) {
  const seen = new Set();
  const conflicts = new Set();
  for (const { name } of [...instruction.accounts, ...instruction.arguments]) {
    if (seen.has(name)) {
      conflicts.add(name);
    } else {
      seen.add(name);
    }
  }
  return Array.from(conflicts);
}
/**
 * Creates a root-node visitor that writes the generated Rust client to `path`.
 *
 * - Unless `options.deleteFolderBeforeRendering` is explicitly `false`, `path` is deleted first.
 * - When `options.formatCode` is truthy, `cargo [toolchain] fmt --manifest-path
 *   <crateFolder>/Cargo.toml` is run afterwards.
 * - Formatting is skipped with a warning when `options.crateFolder` is not set.
 *
 * @param path - Output directory for the generated files.
 * @param options - Renderer options; also forwarded to `getRenderMapVisitor`.
 */
function renderVisitor(path, options = {}) {
  return rootNodeVisitor((root) => {
    const shouldDeleteFolder = options.deleteFolderBeforeRendering ?? true;
    if (shouldDeleteFolder) deleteDirectory(path);
    visit(root, writeRenderMapVisitor(getRenderMapVisitor(options), path));
    if (!options.formatCode) return;
    if (!options.crateFolder) {
      logWarn("No crate folder specified, skipping formatting.");
      return;
    }
    // The toolchain (e.g. "+nightly") is optional; drop it when unset.
    const cargoArgs = [options.toolchain, "fmt", "--manifest-path", `${options.crateFolder}/Cargo.toml`].filter(Boolean);
    runFormatter("cargo", cargoArgs);
  });
}
function runFormatter(cmd, args) {
const { stdout, stderr, error } = spawnSync(cmd, args);
if (error?.message?.includes("ENOENT")) {
logWarn(`Could not find ${cmd}, skipping formatting.`);
return;
}
if (stdout.length > 0) {
logWarn(`(cargo-fmt) ${stdout ? stdout?.toString() : error}`);
}
if (stderr.length > 0) {
logError(`(cargo-fmt) ${stderr ? stderr.toString() : error}`);
}
}
export { ImportMap, renderVisitor as default, getRenderMapVisitor, getTypeManifestVisitor, renderVisitor };
//# sourceMappingURL=index.node.mjs.map
//# sourceMappingURL=index.node.mjs.map