@openzeppelin/upgrade-safe-transpiler
Version:
Solidity preprocessor used to generate OpenZeppelin Contracts Upgrade Safe.
197 lines (174 loc) • 7.25 kB
text/typescript
import { SourceUnit, VariableDeclaration } from 'solidity-ast';
import { findAll } from 'solidity-ast/utils';
import { getNodeBounds } from '../solc/ast-utils';
import { TransformerTools } from '../transform';
import { Transformation } from './type';
import { formatLines } from './utils/format-lines';
import { isStorageVariable } from './utils/is-storage-variable';
import { erc7201Location } from '../utils/erc7201';
import { contractStartPosition } from './utils/contract-start-position';
import { Node } from 'solidity-ast/node';
import { extractContractStorageSize } from '../utils/natspec';
/**
 * Name of the namespaced storage struct generated for a contract: the contract
 * name with `Storage` appended (e.g. `ERC20` -> `ERC20Storage`).
 *
 * @param contractName Solidity contract name.
 * @returns The struct identifier used in the generated code.
 */
export function getNamespaceStructName(contractName: string): string {
  const suffix = 'Storage';
  return `${contractName}${suffix}`;
}
/**
 * Builds a transformer that moves each contract's storage variables into an
 * ERC-7201 namespaced storage struct.
 *
 * For every contract in a selected source unit, the returned generator:
 * - locates the contiguous block of state variables that starts at the first
 *   variable classified as storage by `isStorageVariable`,
 * - relocates non-storage variables found inside that block to the start of the
 *   block so they are left outside the struct,
 * - strips the `private` keyword from the storage variable declarations,
 * - wraps the block in a `<Contract>Storage` struct annotated with
 *   `@custom:storage-location erc7201:openzeppelin.storage.<Contract>`, followed by
 *   a `<Contract>StorageLocation` slot constant and a private
 *   `_get<Contract>Storage()` accessor,
 * - prefixes storage-variable references inside function bodies with `$.` and,
 *   for non-constructor functions, declares `$` at the top of the body.
 *
 * @param include Predicate on `sourceUnit.absolutePath`. Source units for which it
 *   returns a falsy value are skipped. When `include` is omitted, `include?.(...)`
 *   is `undefined`, so every source unit is skipped.
 * @returns A generator-based transformer yielding `Transformation`s.
 * @throws (from the generator, via `tools.error`) when a contract has a
 *   `@custom:storage-size` annotation; when a variable declaration appears after
 *   the storage block has been interrupted by a non-variable node; when a storage
 *   variable is referenced inside a function's modifier invocations; or when a
 *   function references a storage variable whose scope is not this contract.
 */
export function addNamespaceStruct(include?: (source: string) => boolean) {
return function* (sourceUnit: SourceUnit, tools: TransformerTools): Generator<Transformation> {
if (!include?.(sourceUnit.absolutePath)) {
return;
}
const { error, resolver } = tools;
for (const contract of findAll('ContractDefinition', sourceUnit)) {
// Contracts that declare an explicit @custom:storage-size are rejected outright.
const specifiesStorageSize = extractContractStorageSize(contract) !== undefined;
if (specifiesStorageSize) {
throw tools.error(
contract,
'Cannot combine namespaces with @custom:storage-size annotations',
);
}
// `start` is the candidate offset where the variable block begins. It is moved
// past every node that precedes the first storage variable.
let start = contractStartPosition(contract, tools);
// Set once a non-variable node follows the storage variables.
let finished = false;
// Each non-storage variable found inside the block is paired with the offset
// where its removal range begins.
const nonStorageVars: [number, VariableDeclaration][] = [];
const storageVars: VariableDeclaration[] = [];
// We look for the start of the source code block in the contract
// where variables are written
for (const n of contract.nodes) {
// Once a storage variable has been seen, every variable declaration
// (storage or not) is treated as part of the block.
if (
n.nodeType === 'VariableDeclaration' &&
(storageVars.length > 0 || isStorageVariable(n, resolver))
) {
if (finished) {
throw error(n, 'All variables in the contract must be contiguous');
}
if (!isStorageVariable(n, resolver)) {
// The removal range starts right after the last *storage* variable (including
// its semicolon and trailing line comment).
// NOTE(review): for two consecutive non-storage variables, the second one's
// range therefore also covers the first one's, so the ranges nest instead of
// being disjoint. Confirm the transformation engine handles this nesting as intended.
const varStart = getRealEndIndex(storageVars.at(-1)!, tools) + 1;
nonStorageVars.push([varStart, n]);
} else {
storageVars.push(n);
}
} else if (storageVars.length > 0) {
// We've seen storage variables before and the current node is not a
// variable, so we consider the block to have finished
finished = true;
} else {
// We haven't found storage variables yet. We assume the block of
// variables will start after the current node
start = getRealEndIndex(n, tools) + 1;
}
}
if (storageVars.length > 0) {
// We first move non-storage variables from their location to the beginning of
// the block, so they are excluded from the namespace struct
for (const [s, v] of nonStorageVars) {
const bounds = { start: s, length: getRealEndIndex(v, tools) + 1 - s };
let removed = '';
// The `transform` callback records the removed source text in `removed`.
yield {
kind: 'relocate-nonstorage-var-remove',
...bounds,
transform: source => {
removed = source;
return '';
},
};
// NOTE: `text: removed` is evaluated when this object literal is built, i.e. only
// after the generator is resumed past the previous `yield`. That is correct only if
// the consumer applies the remove transformation (which fills `removed`) before
// requesting the next one. TODO confirm in the transform driver.
yield {
kind: 'relocate-nonstorage-var-reinsert',
start,
length: 0,
text: removed,
};
}
// When variables were relocated, insert a newline at the block start
// (presumably to separate them from the struct that follows).
if (nonStorageVars.length > 0) {
yield {
kind: 'relocate-nonstorage-var-newline',
start,
length: 0,
text: '\n',
};
}
// Remove the `private` keyword (and the whitespace before it) from each storage
// variable declaration. Struct members cannot carry a visibility specifier.
for (const v of storageVars) {
const { start, length } = getNodeBounds(v);
yield {
kind: 'remove-var-modifier',
start,
length,
transform: source => source.replace(/\s*\bprivate\b/g, ''),
};
}
const namespace = getNamespaceStructName(contract.name);
const id = 'openzeppelin.storage.' + contract.name;
// End of the block: just past the last storage variable, including its
// semicolon and any trailing `//` comment on the same line.
const end = getRealEndIndex(storageVars.at(-1)!, tools) + 1;
// Replace the whole variable block [start, end) with the struct, the slot
// constant and the accessor function.
yield {
kind: 'add-namespace-struct',
start,
length: end - start,
transform: source => {
// We extract the newlines at the beginning of the block so we can leave
// them outside of the struct definition
const [, leadingNewlines, rest] = source.match(/^((?:[ \t\v\f]*[\n\r])*)(.*)$/s)!;
return (
leadingNewlines +
formatLines(1, [
`/// @custom:storage-location erc7201:${id}`,
`struct ${namespace} {`,
...rest.split('\n'),
`}`,
``,
`// keccak256(abi.encode(uint256(keccak256("${id}")) - 1)) & ~bytes32(uint256(0xff))`,
`bytes32 private constant ${namespace}Location = ${erc7201Location(id)};`,
``,
`function _get${namespace}() private pure returns (${namespace} storage $) {`,
[`assembly {`, [`$.slot := ${namespace}Location`], `}`],
`}`,
]).trimEnd()
);
},
};
// Rewrite storage-variable references in every function of the contract.
// NOTE(review): only FunctionDefinition nodes are visited. The bodies of `modifier`
// definitions are not scanned here; confirm storage references there are handled elsewhere.
for (const fnDef of findAll('FunctionDefinition', contract)) {
// Storage variables referenced inside a function's modifier invocations are
// reported as unsupported rather than rewritten.
for (const ref of fnDef.modifiers.flatMap(m => [...findAll('Identifier', m)])) {
const varDecl = resolver.tryResolveNode(
'VariableDeclaration',
ref.referencedDeclaration!,
);
if (varDecl && isStorageVariable(varDecl, resolver)) {
throw error(ref, 'Unsupported storage variable found in modifier');
}
}
let foundReferences = false;
if (fnDef.body) {
for (const ref of findAll('Identifier', fnDef.body)) {
const varDecl = resolver.tryResolveNode(
'VariableDeclaration',
ref.referencedDeclaration!,
);
if (varDecl && isStorageVariable(varDecl, resolver)) {
// A storage variable declared outside this contract cannot be reached
// through this contract's `$`, so it is rejected.
if (varDecl.scope !== contract.id) {
throw error(varDecl, 'Namespaces assume all variables are private');
}
foundReferences = true;
// Prefix the identifier so it reads from the namespace struct.
const { start } = getNodeBounds(ref);
yield { kind: 'add-namespace-ref', start, length: 0, text: '$.' };
}
}
if (fnDef.kind !== 'constructor' && foundReferences) {
// The constructor is handled in transformConstructor
// Insert the `$` declaration one character into the body (just after what is
// presumably its opening `{`).
const { start: fnBodyStart } = getNodeBounds(fnDef.body);
yield {
kind: 'add-namespace-base-ref',
start: fnBodyStart + 1,
length: 0,
text: `\n ${namespace} storage $ = _get${namespace}();`,
};
}
}
}
}
}
};
}
/**
 * Returns the offset of the last source character considered part of `node`,
 * extended over a trailing `;` and a same-line `//` comment when present.
 *
 * AST bounds for a VariableDeclaration stop before its semicolon, which is why
 * the text following the node is inspected.
 * Assumes `matchOriginalAfter` reports a match that starts right after the node;
 * if nothing optional matches, the result is `match.start - 1`.
 */
function getRealEndIndex(node: Node, tools: TransformerTools): number {
  // Group 1: optional whitespace + semicolon. Group 2: optional same-line `//` comment.
  // Both groups are optional, so the pattern can match the empty string.
  const trailing = tools.matchOriginalAfter(node, /(\s*;)?([ \t]*\/\/[^\n\r]*)?/)!;
  const lastIndex = trailing.start + trailing.length - 1;
  return lastIndex;
}