react-refresh-typescript
Version:
React Refresh transformer for TypeScript
680 lines (656 loc) • 30.7 kB
text/typescript
import type {
ArrowFunction,
Block,
CallExpression,
CaseClause,
ConciseBody,
DefaultClause,
Expression,
FunctionDeclaration,
FunctionExpression,
Identifier,
ModuleBlock,
Node,
NodeArray,
SourceFile,
Statement,
TransformerFactory,
VisitResult,
} from 'typescript'
/**
 * Create a ReactRefresh transformer for TypeScript.
 *
 * This transformer should run in the `before` stage.
 *
 * `opts.ts` is mandatory: every AST helper used below comes from that TypeScript instance, and
 * the factory throws when it is missing.
 *
 * The version guard only rejects TypeScript below 4.0.
 * NOTE(review): several factory calls below (e.g. `updateFunctionDeclaration`,
 * `updateExportAssignment`) use the overloads without a `decorators` parameter, which appear to
 * need TypeScript 4.8+ — confirm the real minimum version before relying on 4.0-4.7.
 */
export default function (opts: Options = {}): TransformerFactory<SourceFile> {
  // The non-null assertion keeps `ts` non-optional inside the hoisted helper functions declared
  // below, where the narrowing from the `if (!ts)` check does not carry over.
  const ts = opts.ts!
  if (!ts) throw new Error('Please provide typescript by options.ts')
  {
    const [major] = ts.version.split('.')
    if (parseInt(major, 10) < 4) throw new Error('TypeScript should be at least 4.0')
  }
  return (context) => {
    const { factory } = context
    const refreshReg = factory.createIdentifier(opts.refreshReg || '$RefreshReg$')
    const refreshSig = factory.createIdentifier(opts.refreshSig || '$RefreshSig$')
    return (file) => {
      if (file.isDeclarationFile) return file
      // Cheap pre-filter: skip files that are neither JSX nor mention `use` anywhere in their text.
      const containHooksLikeOrJSX = file.languageVariant === ts.LanguageVariant.JSX || file.text.includes('use')
      if (!containHooksLikeOrJSX) return file
      // TODO: change to scan comment?
      const globalRequireForceRefresh = file.text.includes('@refresh reset')
      const topLevelDeclaredName = new Set<string>()
      // Collect top level local declarations
      for (const node of file.statements) {
        if (ts.isFunctionDeclaration(node) && node.name) topLevelDeclaredName.add(node.name.text)
        if (ts.isVariableStatement(node)) {
          for (const decl of node.declarationList.declarations) {
            if (ts.isIdentifier(decl.name)) {
              topLevelDeclaredName.add(decl.name.text)
            }
            // ? skip for deconstructing pattern
          }
        }
      }
      // track all JSX usage and transform non-top level hooks
      const { nextFile, usedAsJSXElement, hooksSignatureMap } = visitDeep(
        file,
        topLevelDeclaredName,
        globalRequireForceRefresh
      )
      file = nextFile
      return updateStatements(file, (statements) =>
        ts.visitLexicalEnvironment(
          statements,
          (node) => visitTopLevel(usedAsJSXElement, hooksSignatureMap, node),
          context
        )
      )
    }
    /**
     * Visit one top-level statement and append component registrations / hook signature calls
     * after it.
     *
     * @param usedAsJSXElement Top-level names seen as a JSX tag or as the first argument of a JSX factory call.
     * @param hooksSignatureMap Hook signature calls keyed by the (already transformed) function they describe.
     * @param node The top-level statement.
     */
    function visitTopLevel(
      usedAsJSXElement: ReadonlySet<string>,
      hooksSignatureMap: Map<HandledFunction, CallExpression>,
      node: Node
    ): VisitResult<Node> {
      if (ts.isFunctionDeclaration(node)) {
        // The hook signature of a function declaration was already appended by visitDeep.
        if (!node.name || !node.body) return node
        return [node, ...registerComponent(node.name)]
      } else if (ts.isVariableStatement(node)) {
        const deferredStatements: Statement[] = []
        const nextDeclarationList = ts.visitEachChild(
          node.declarationList,
          (declaration) => {
            if (!ts.isVariableDeclaration(declaration)) return declaration
            const init = declaration.initializer
            const name = declaration.name
            // Not handle complex declaration. e.g. [a, b] = [() => ..., () => ...]
            // or declaration without initializer
            if (!ts.isIdentifier(name) || !init) return declaration
            // HOCs are checked first: `const Foo = memo(() => ...)` must register the inner function
            // and emit its hook signature even when `<Foo />` also appears in this file. Checking the
            // JSX branch first would register only `Foo` and silently drop both.
            if (isHigherOrderComponentLike(init)) {
              const { registers, call } = registerHigherOrderComponent(hooksSignatureMap, init, name.text)
              deferredStatements.push(...registers, ...registerComponent(name))
              return factory.updateVariableDeclaration(
                declaration,
                name,
                declaration.exclamationToken,
                declaration.type,
                call
              )
            }
            if (!usedAsJSXElement.has(name.text) && !isFunctionExpressionLikeOrFunctionDeclaration(init)) {
              return declaration
            }
            if (!unwantedComponentLikeDefinition(init)) {
              deferredStatements.push(...registerComponent(name))
            }
            const signature = isFunctionExpressionLikeOrFunctionDeclaration(init)
              ? hooksSignatureMap.get(init)
              : undefined
            if (signature) {
              /**
               * const Comp = () => <Comp />
               * const Comp2 = function () { return <Comp /> }
               *
               * The function stays inside the declaration so it keeps its inferred name; the
               * signature call is emitted afterwards with the declared binding in place of the
               * function (argument 0 of the call built by createHooksRegisterCall).
               *
               * See https://tc39.es/ecma262/multipage/ecmascript-language-expressions.html#sec-assignment-operators-runtime-semantics-evaluation
               * and https://github.com/Jack-Works/react-refresh-transformer/issues/8
               */
              const sig = factory.updateCallExpression(signature, signature.expression, signature.typeArguments, [
                name,
                ...signature.arguments.slice(1),
              ])
              deferredStatements.push(factory.createExpressionStatement(sig))
            }
            return declaration
          },
          context
        )
        return [
          factory.updateVariableStatement(node, node.modifiers, nextDeclarationList),
          ...deferredStatements,
        ]
      } else if (ts.isExportAssignment(node)) {
        if (isHigherOrderComponentLike(node.expression)) {
          // export default hoc(...)  ->  export default _tmp = hoc(...); $RefreshReg$(_tmp, "%default%")
          const { registers, call } = registerHigherOrderComponent(
            hooksSignatureMap,
            node.expression,
            '%default%'
          )
          const temp = createTempVariable()
          return [
            factory.updateExportAssignment(node, node.modifiers, factory.createAssignment(temp, call)),
            createComponentRegisterCall(temp, '%default%'),
            ...registers,
          ]
        } else if (isFunctionExpressionLikeOrFunctionDeclaration(node.expression)) {
          // export default () => { useX() }  ->  export default _s(() => ..., "signature")
          const expr = hooksSignatureMap.get(node.expression)
          if (expr) {
            return factory.updateExportAssignment(node, node.modifiers, expr)
          }
        }
      }
      return node
    }
    /**
     * For a name that does not start with a lowercase character, return
     * `_tmp = Name` and `$RefreshReg$(_tmp, "Name")`; otherwise return no statements.
     */
    function registerComponent(name: Identifier) {
      if (!startsWithLowerCase(name.text)) {
        const temp = createTempVariable()
        // uniq = name
        const assignment = factory.createAssignment(temp, name)
        // $reg$(uniq, "name")
        return [factory.createExpressionStatement(assignment), createComponentRegisterCall(temp, name.text)]
      }
      return []
    }
    /**
     * Rewrite `hoc(inner, ...rest)` into `hoc(_tmp = inner, ...rest)` and return a
     * `$RefreshReg$(_tmp, "<nameHint>$<callee source>")` registration for it. Nested HOC calls in
     * the first argument are handled recursively, one registration per level. A function argument
     * is replaced by its hook signature call when one exists; an identifier argument is left as is.
     *
     * Callers must check `isHigherOrderComponentLike` first; otherwise this throws.
     */
    function registerHigherOrderComponent(
      hooksSignatureMap: ReadonlyMap<HandledFunction, CallExpression>,
      callExpr: CallExpression,
      nameHint: string
    ): { call: CallExpression; registers: Statement[] } {
      // Recursive case, if it is x(y(...)), recursive with y(...) to get inner expr
      const arg0 = callExpr.arguments[0]
      if (ts.isCallExpression(arg0)) {
        const tempVar = createTempVariable()
        const nextNameHint = nameHint + '$' + printNode(callExpr.expression)
        const { registers, call: innerResult } = registerHigherOrderComponent(
          hooksSignatureMap,
          arg0,
          nextNameHint
        )
        return {
          call: factory.updateCallExpression(callExpr, callExpr.expression, void 0, [
            factory.createAssignment(tempVar, innerResult),
            ...callExpr.arguments.slice(1),
          ]),
          registers: registers.concat(createComponentRegisterCall(tempVar, nextNameHint)),
        }
      }
      // Base case, it is x(function () {...}) or x(() => ...) or x(Identifier)
      if (!isFunctionExpressionLikeOrFunctionDeclaration(arg0) && !ts.isIdentifier(arg0)) {
        throw new Error(
          'This is an error of react-refresh/typescript. Please report this problem: Call isHOC before register it'
        )
      }
      if (ts.isIdentifier(arg0)) return { call: callExpr, registers: [] }
      const tempVar = createTempVariable()
      return {
        call: factory.updateCallExpression(callExpr, callExpr.expression, void 0, [
          factory.createAssignment(tempVar, hooksSignatureMap.get(arg0) || arg0),
          ...callExpr.arguments.slice(1),
        ]),
        registers: [createComponentRegisterCall(tempVar, nameHint + '$' + printNode(callExpr.expression))],
      }
    }
    /** Create a unique `_react_refresh_temp` identifier hoisted into the current lexical environment. */
    function createTempVariable() {
      const tempVariable = factory.createUniqueName('_react_refresh_temp')
      context.hoistVariableDeclaration(tempVariable)
      return tempVariable
    }
    /**
     * Walk the whole file once to:
     * - record which top-level names are used as JSX components, and
     * - instrument every function that calls hooks: prepend `_s()` to its body, register
     *   `_s = $RefreshSig$()` as an initialization statement of the current lexical environment,
     *   and build the `_s(fn, "signature", ...)` call.
     *
     * Signature calls of named function declarations are appended right after the declaration.
     * Function expressions nested in another function are replaced by their signature call in
     * place; other function expressions keep their place and their signature call is looked up
     * later through the returned `hooksSignatureMap`.
     */
    function visitDeep(
      file: SourceFile,
      topLevelDeclaredName: ReadonlySet<string>,
      globalRequireForceRefresh: boolean
    ) {
      const usedAsJSXElement = new Set<string>()
      // Keyed by the ORIGINAL (pre-transform) function node.
      const containingHooksOldMap = new Map<HandledFunction, CallExpression[]>()
      // Keyed by the TRANSFORMED function node.
      const hooksSignatureMap = new Map<HandledFunction, CallExpression>()
      function trackHooks(comp: HandledFunction, call: CallExpression) {
        const arr = containingHooksOldMap.get(comp) || []
        arr.push(call)
        containingHooksOldMap.set(comp, arr)
      }
      function visitor(node: Node) {
        // Collect JSX create info
        // <abc /> or <abc>
        if (ts.isJsxOpeningLikeElement(node)) {
          const tag = node.tagName
          if (ts.isIdentifier(tag) && !isIntrinsicElement(tag)) {
            const name = tag.text
            if (topLevelDeclaredName.has(name)) usedAsJSXElement.add(name)
          }
          // Not tracking other kinds of tagNames like <A.B /> or <A:B />
        } else if (isJSXConstructingCallExpr(node)) {
          const arg0 = node.arguments[0]
          if (arg0 && ts.isIdentifier(arg0)) {
            const name = arg0.text
            if (topLevelDeclaredName.has(name)) usedAsJSXElement.add(name)
          }
        }
        if (isReactHooksCall(node)) {
          // Attribute the hook call to its nearest enclosing function.
          const parent = findAncestor(node, isFunctionExpressionLikeOrFunctionDeclaration) as HandledFunction
          if (parent) trackHooks(parent, node)
        }
        const oldNode = node as HandledFunction
        // Collect hooks
        node = ts.visitEachChild(node, visitor, context)
        const hooksCalls = containingHooksOldMap.get(oldNode)
        if (hooksCalls && isFunctionExpressionLikeOrFunctionDeclaration(node) && node.body) {
          const hooksTracker = createTempVariable()
          // _s = $RefreshSig$()
          const createHooksTracker = factory.createExpressionStatement(
            factory.createAssignment(hooksTracker, factory.createCallExpression(refreshSig, undefined, []))
          )
          // @ts-ignore This is a private API.
          context.addInitializationStatement(createHooksTracker)
          const callTracker = factory.createCallExpression(hooksTracker, void 0, [])
          // Block body: prepend `_s();`. Concise arrow body: `(_s(), expr)`.
          const nextBody = ts.isBlock(node.body)
            ? updateStatements(node.body, (r) => [factory.createExpressionStatement(callTracker), ...r])
            : factory.createComma(callTracker, node.body)
          const newFunction = updateBody(node, nextBody)
          const hooksSignature = hooksCallsToSignature(hooksCalls)
          const { force: forceRefresh, hooks: hooksArray } = needForceRefresh(hooksCalls)
          const requireForceRefresh = forceRefresh || globalRequireForceRefresh
          if (ts.isFunctionDeclaration(newFunction)) {
            if (newFunction.name) {
              hooksSignatureMap.set(
                newFunction,
                createHooksRegisterCall(
                  hooksTracker,
                  newFunction.name,
                  hooksSignature,
                  requireForceRefresh,
                  hooksArray
                )
              )
            }
            node = newFunction
          } else {
            const wrapped = createHooksRegisterCall(
              hooksTracker,
              newFunction,
              hooksSignature,
              requireForceRefresh,
              hooksArray
            )
            hooksSignatureMap.set(newFunction, wrapped)
            node = newFunction
            // if it is an inner decl, we can update it safely
            if (findAncestor(oldNode.parent, ts.isFunctionLike)) node = wrapped
          }
        }
        return updateStatements(node, addSignatureReport)
      }
      /** Insert the signature call right after every function declaration that has one. */
      function addSignatureReport(statements: ReadonlyArray<Statement>) {
        const next: Statement[] = []
        for (const statement of statements) {
          next.push(statement)
          // Only function declarations can be both a statement and a key of hooksSignatureMap.
          const signatureReport = ts.isFunctionDeclaration(statement) ? hooksSignatureMap.get(statement) : undefined
          if (signatureReport) next.push(factory.createExpressionStatement(signatureReport))
        }
        return next
      }
      const nextFile = updateStatements(ts.visitEachChild(file, visitor, context), addSignatureReport)
      return {
        nextFile,
        usedAsJSXElement,
        hooksSignatureMap,
      }
    }
    /** Source text of `node`, or '' when `getText()` throws (it can for nodes without source positions). */
    function printNode(node: Node) {
      try {
        return node.getText()
      } catch {
        return ''
      }
    }
    /**
     * Build the hook signature: one `hookName{assignTarget(capturedArgs)}` entry per call, joined
     * by '\n'. `useState` captures its 1st argument and `useReducer` its 2nd. When
     * `emitFullSignatures` is not `true` and `hashSignature` is set, the result is hashed.
     */
    function hooksCallsToSignature(calls: CallExpression[]) {
      const signature = calls
        .map((x) => {
          let assignTarget = ''
          if (x.parent && ts.isVariableDeclaration(x.parent)) {
            assignTarget = printNode(x.parent.name)
          }
          let hooksName = printNode(x.expression)
          let shouldCaptureArgs = 0 // bit-wise parameter position
          if (ts.isPropertyAccessExpression(x.expression)) {
            const left = x.expression.expression
            if (ts.isIdentifier(left) && left.text === 'React') {
              hooksName = printNode(x.expression.name)
            }
          }
          if (hooksName === 'useState') shouldCaptureArgs = 1 << 0
          else if (hooksName === 'useReducer') shouldCaptureArgs = 1 << 1
          const args = x.arguments.reduce((last, val, index) => {
            if ((1 << index) & shouldCaptureArgs) {
              if (last) last += ','
              last += printNode(val)
            }
            return last
          }, '')
          return `${hooksName}{${assignTarget}${args ? `(${args})` : ''}}`
        })
        .join('\n')
      if (opts.emitFullSignatures !== true && opts.hashSignature) {
        try {
          return opts.hashSignature(signature)
        } catch {
          // Hashing is best-effort: fall back to the full (unhashed) signature.
        }
      }
      return signature
    }
    /**
     * Decide whether the component must be force-refreshed. A hook call forces it unless its
     * callee is a builtin hook, a `React.*` member, or a name declared (variable, import or
     * function) in an enclosing block or the file. Declared custom hooks are collected into
     * `hooks` and later emitted as the `() => [...]` argument of the signature call.
     */
    function needForceRefresh(calls: CallExpression[]) {
      const externalHooks: Expression[] = []
      return {
        hooks: externalHooks,
        force: calls.some((x) => {
          const ownerFunction = findAncestor(x, isFunctionExpressionLikeOrFunctionDeclaration)
          const callee = x.expression
          if (!ownerFunction) return true
          if (ts.isPropertyAccessExpression(callee)) {
            const left = callee.expression
            if (ts.isIdentifier(left)) {
              if (left.text === 'React') return false
              const hasDecl = hasDeclarationInScope(ownerFunction, left.text)
              if (hasDecl) externalHooks.push(callee)
              return !hasDecl
            }
            return true
          } else if (ts.isIdentifier(callee)) {
            if (isBuiltinHook(callee.text)) return false
            const hasDecl = hasDeclarationInScope(ownerFunction, callee.text)
            if (hasDecl) externalHooks.push(callee)
            return !hasDecl
          }
          return true
        }),
      }
    }
    /**
     * Build `instance(component, signature[, forceRefresh[, () => [trackers]]])`.
     *
     * @param instance The identifier of the sig instance
     * @param component The binding component
     * @param signature The signature of the function
     * @param forceRefresh Does forceRefresh enabled?
     * @param trackers A list of custom hooks references
     */
    function createHooksRegisterCall(
      instance: Identifier,
      component: Expression,
      signature: string,
      forceRefresh: boolean,
      trackers: Expression[]
    ) {
      const args = [component]
      args.push(createSignatureLiteral(signature))
      if (forceRefresh || trackers.length) args.push(forceRefresh ? factory.createTrue() : factory.createFalse())
      if (trackers.length)
        args.push(
          factory.createArrowFunction(
            void 0,
            void 0,
            [],
            void 0,
            factory.createToken(ts.SyntaxKind.EqualsGreaterThanToken),
            factory.createArrayLiteralExpression(trackers)
          )
        )
      return factory.createCallExpression(instance, void 0, args)
    }
    /**
     * Literal for a hook signature. Multi-line signatures are emitted as a template literal with the
     * signature as raw text (readable output), but only when that text is safe verbatim: the
     * signature embeds source text of hook arguments, and a backtick, backslash, `${` or carriage
     * return there would produce invalid code or change the string's value. Everything else is
     * emitted as a string literal, which TypeScript escapes.
     */
    function createSignatureLiteral(signature: string): Expression {
      if (signature.includes('\n') && !/[`\\\r]|\$\{/.test(signature)) {
        return factory.createNoSubstitutionTemplateLiteral(signature, signature)
      }
      return factory.createStringLiteral(signature)
    }
    /** `$RefreshReg$(id, "name");` */
    function createComponentRegisterCall(id: Identifier, name: string) {
      return factory.createExpressionStatement(
        factory.createCallExpression(refreshReg, void 0, [id, factory.createStringLiteral(name)])
      )
    }
    /**
     * Replace the statement list of a statement container (SourceFile, case/default clause, module
     * block, block) with `f(statements)`. Any other node is returned unchanged.
     */
    function updateStatements<T extends Node>(node: T, f: (s: NodeArray<Statement>) => readonly Statement[]): T {
      if (ts.isSourceFile(node)) {
        const sf = factory.updateSourceFile(
          node,
          f(node.statements),
          node.isDeclarationFile,
          node.referencedFiles,
          node.typeReferenceDirectives,
          node.hasNoDefaultLib,
          node.libReferenceDirectives
        )
        return sf as T & SourceFile
      }
      if (ts.isCaseClause(node)) {
        const caseClause = factory.updateCaseClause(node, node.expression, f(node.statements))
        return caseClause as T & CaseClause
      }
      if (ts.isDefaultClause(node)) {
        const defaultClause = factory.updateDefaultClause(node, f(node.statements))
        return defaultClause as T & DefaultClause
      }
      if (ts.isModuleBlock(node)) {
        const modBlock = factory.updateModuleBlock(node, f(node.statements))
        return modBlock as T & ModuleBlock
      }
      if (ts.isBlock(node)) {
        const block = factory.updateBlock(node, f(node.statements))
        return block as T & Block
      }
      return node
    }
    /**
     * Return a copy of `node` with its body replaced.
     *
     * @throws TypeError when a function declaration/expression is given a non-block body.
     */
    function updateBody(node: HandledFunction, nextBody: Block | ConciseBody): HandledFunction {
      if (ts.isFunctionDeclaration(node)) {
        if (!ts.isBlock(nextBody)) throw new TypeError('Function declaration body must be a Block')
        return factory.updateFunctionDeclaration(
          node,
          node.modifiers,
          node.asteriskToken,
          node.name,
          node.typeParameters,
          node.parameters,
          node.type,
          nextBody
        )
      } else if (ts.isFunctionExpression(node)) {
        if (!ts.isBlock(nextBody)) throw new TypeError('Function expression body must be a Block')
        return factory.updateFunctionExpression(
          node,
          node.modifiers,
          node.asteriskToken,
          node.name,
          node.typeParameters,
          node.parameters,
          node.type,
          nextBody
        )
      } else if (ts.isArrowFunction(node)) {
        return factory.updateArrowFunction(
          node,
          node.modifiers,
          node.typeParameters,
          node.parameters,
          node.type,
          node.equalsGreaterThanToken,
          nextBody
        )
      }
      return node
    }
  }
  /** Hooks exported by React itself; calling them never forces a refresh. */
  function isBuiltinHook(hookName: string) {
    switch (hookName) {
      case 'useState':
      case 'useReducer':
      case 'useEffect':
      case 'useLayoutEffect':
      case 'useMemo':
      case 'useCallback':
      case 'useRef':
      case 'useContext':
      case 'useImperativeHandle':
      case 'useDebugValue':
      case 'useId':
      case 'useDeferredValue':
      case 'useTransition':
      case 'useInsertionEffect':
      case 'useSyncExternalStore':
      case 'useFormState':
      case 'useActionState':
      case 'useOptimistic':
        return true
      default:
        return false
    }
  }
  /** Walk up from `node`, checking the statements of every enclosing Block and the SourceFile. */
  function hasDeclarationInScope(node: Node, name: string) {
    while (node) {
      if (ts.isSourceFile(node) && hasDeclaration(node.statements, name)) return true
      if (ts.isBlock(node) && hasDeclaration(node.statements, name)) return true
      node = node.parent
    }
    return false
  }
  // This function does not consider uncommon and unrecommended practice like declare use var in a inner scope
  /** Whether `name` is declared by a variable statement, import, or function (with body) in `nodes`. */
  function hasDeclaration(nodes: readonly Statement[], name: string) {
    for (const node of nodes) {
      if (ts.isVariableStatement(node)) {
        for (const decl of node.declarationList.declarations) {
          // binding pattern not checked
          if (ts.isIdentifier(decl.name) && decl.name.text === name) return true
        }
      } else if (ts.isImportDeclaration(node)) {
        const clause = node.importClause
        const defaultImport = clause && clause.name
        const namedImport = clause && clause.namedBindings
        if (defaultImport && defaultImport.text === name) return true
        if (namedImport && ts.isNamespaceImport(namedImport)) {
          if (namedImport.name.text === name) return true
        } else if (namedImport && ts.isNamedImports(namedImport)) {
          const hasBinding = namedImport.elements.some((x) => x.name.text === name)
          if (hasBinding) return true
        }
      } else if (ts.isFunctionDeclaration(node)) {
        if (!node.body) continue
        if (node.name && node.name.text === name) return true
      }
    }
    return false
  }
  /** Tag names that are not user components: contain '-' or ':', or start lowercase. */
  function isIntrinsicElement(id: Identifier) {
    return id.text.includes('-') || startsWithLowerCase(id.text) || id.text.includes(':')
  }
  /** `import(...)`, or a call whose identifier callee contains "require". */
  function isImportOrRequireLike(expr: Expression) {
    if (!ts.isCallExpression(expr)) return false
    const callee = expr.expression
    if (callee.kind === ts.SyntaxKind.ImportKeyword) return true
    if (ts.isIdentifier(callee) && callee.text.includes('require')) return true
    return false
  }
  /** `useX(...)` or `obj.useX(...)`, where the called name passes `isHookName`. */
  function isReactHooksCall(expr: Node): expr is CallExpression {
    if (!ts.isCallExpression(expr)) return false
    const callee = expr.expression
    if (ts.isIdentifier(callee)) return isHookName(callee.text)
    if (ts.isPropertyAccessExpression(callee)) return isHookName(callee.name.text)
    return false
  }
  /**
   * `use` itself, or `use` followed by an uppercase letter or digit (`useState`, `use3D`).
   * Plain functions such as `user()`, `username()` or `useful()` are not hooks.
   */
  function isHookName(name: string) {
    return name === 'use' || /^use[A-Z0-9]/.test(name)
  }
  /**
   * Nearest node (starting at `node` itself) for which `callback` is truthy; stops early and
   * returns undefined when `callback` returns 'quit'.
   */
  function findAncestor(node: Node, callback: (element: Node) => boolean | 'quit') {
    while (node) {
      const result = callback(node)
      if (result === 'quit') {
        return undefined
      } else if (result) {
        return node
      }
      node = node.parent
    }
    return undefined
  }
  /**
   * If it return true, don't track it even it is used as JSX component
   */
  function unwantedComponentLikeDefinition(expr: Expression): boolean {
    if (isImportOrRequireLike(expr)) return true
    // `const A = B.X` or `const A = X`
    if (ts.isIdentifier(expr) || ts.isPropertyAccessExpression(expr)) return true
    if (ts.isConditionalExpression(expr))
      return (
        unwantedComponentLikeDefinition(expr.condition) ||
        unwantedComponentLikeDefinition(expr.whenFalse) ||
        unwantedComponentLikeDefinition(expr.whenTrue)
      )
    return false
  }
  /**
   * `x(...)` / `x.y(...)` calls (not import/require), possibly nested through their first
   * argument, whose innermost first argument is a function or a capitalized identifier.
   */
  function isHigherOrderComponentLike(outExpr: Expression): outExpr is CallExpression {
    let expr = outExpr
    if (!ts.isCallExpression(outExpr)) return false
    while (ts.isCallExpression(expr) && !isImportOrRequireLike(expr)) {
      const callee = expr.expression
      // x.y() or x()
      const isValidCallee = ts.isPropertyAccessExpression(callee) || ts.isIdentifier(callee)
      if (isValidCallee) {
        expr = expr.arguments[0] // check if arg is also a HOC
        if (!expr) return false
      } else return false
    }
    const isValidHOCArg =
      isFunctionExpressionLikeOrFunctionDeclaration(expr) ||
      (ts.isIdentifier(expr) && !startsWithLowerCase(expr.text))
    return isValidHOCArg
  }
  function isFunctionExpressionLikeOrFunctionDeclaration(node: Node): node is HandledFunction {
    if (ts.isFunctionDeclaration(node)) return true
    if (ts.isArrowFunction(node)) return true
    if (ts.isFunctionExpression(node)) return true
    return false
  }
  /**
   * If the call expression seems like "jsx(...)" or "xyz.jsx(...)"
   */
  function isJSXConstructingCallExpr(call: Node): call is CallExpression {
    if (!ts.isCallExpression(call)) return false
    const callee = call.expression
    let f = ''
    if (ts.isIdentifier(callee)) f = callee.text
    if (ts.isPropertyAccessExpression(callee)) f = callee.name.text
    if (['createElement', 'jsx', 'jsxs', 'jsxDEV'].includes(f)) return true
    return false
  }
}
/**
 * Whether the first character of `str` is unchanged by lower-casing.
 *
 * Callers use this as "not component-ish": only names whose first character has a distinct
 * lowercase form (e.g. `Foo`) return `false`. Characters without case (`_`, `$`, digits) and the
 * empty string count as lowercase, so they are never treated as components.
 */
function startsWithLowerCase(str: string) {
  // charAt(0) yields '' for an empty string instead of undefined (which would throw on .toLowerCase()).
  const first = str.charAt(0)
  return first.toLowerCase() === first
}
/** Options accepted by the React Refresh transformer factory. */
export type Options = {
  /**
   * Name of the global function called as `$RefreshReg$(component, "name")` to register components.
   * @default "$RefreshReg$"
   */
  readonly refreshReg?: string
  /**
   * Name of the global function called as `$RefreshSig$()` to create hook signature trackers.
   * @default "$RefreshSig$"
   */
  readonly refreshSig?: string
  /**
   * When `true`, `hashSignature` is not applied and the full hook signature text is emitted.
   * @default false
   */
  readonly emitFullSignatures?: boolean
  /**
   * The TypeScript instance used to inspect and build the AST.
   * Optional in the type, but the transformer factory throws when it is not provided.
   */
  readonly ts?: typeof import('typescript')
  /**
   * Hash function applied to hook signatures when `emitFullSignatures` is not `true`.
   * If it throws, the full signature is emitted instead.
   */
  readonly hashSignature?: (signature: string) => string
}
/** Function nodes whose bodies the transformer instruments with hook signature tracking. */
type HandledFunction = ArrowFunction | FunctionExpression | FunctionDeclaration