UNPKG

nitro-codegen

Version:

The code-generator for react-native-nitro-modules.

160 lines (148 loc) 5.18 kB
import type { Language } from '../../getPlatformSpecs.js'
import { escapeCppName, toReferenceType } from '../helpers.js'
import { Parameter } from '../Parameter.js'
import { type SourceFile, type SourceImport } from '../SourceFile.js'
import { PromiseType } from './PromiseType.js'
import type { NamedType, Type, TypeKind } from './Type.js'

/**
 * Builds the comma-separated C++ parameter list shared by every C++ signature
 * this file emits (`std::function<...>` and raw function pointers).
 *
 * Parameters whose type reports `canBePassedByReference` are converted with
 * `toReferenceType(...)`; all others are emitted by value.
 *
 * @param parameters The named parameters of the function.
 * @param includeNameInfo Whether to append each parameter's name as an inline
 *   C++ block comment after its type.
 * @returns The parameters joined with `', '` (empty string for no parameters).
 */
function getCppParameterList(
  parameters: readonly NamedType[],
  includeNameInfo: boolean
): string {
  return parameters
    .map((p) => {
      const type = p.getCode('c++')
      const code = p.canBePassedByReference ? toReferenceType(type) : type
      return includeNameInfo ? `${code} /* ${p.name} */` : code
    })
    .join(', ')
}

/**
 * A JS function/callback type (`(...params) => ReturnType`) as seen by the
 * code generator.
 *
 * Non-`void` functions that are not declared sync get their return type
 * wrapped in a {@linkcode PromiseType}, because their result has to be awaited
 * from JS. `void` functions and sync functions keep their declared return type.
 */
export class FunctionType implements Type {
  readonly returnType: Type
  readonly parameters: NamedType[]

  /**
   * @param returnType The declared return type of the JS function.
   * @param parameters The named parameters of the JS function.
   * @param isSync Whether the function was declared sync (`Sync<...>`).
   * @throws {Error} If `isSync` is `true` and `returnType` is `void`, because
   *   that combination is ambiguous.
   */
  constructor(returnType: Type, parameters: NamedType[], isSync = false) {
    if (returnType.kind === 'void' || isSync) {
      // void callbacks are async, but we don't care about the result.
      // Sync callbacks return their value directly, so no Promise is needed.
      this.returnType = returnType
    } else {
      // non-void callbacks are async and need to be awaited to get the result from JS.
      this.returnType = new PromiseType(returnType)
    }
    this.parameters = parameters

    // Validated only after both fields are assigned, because the error
    // message below reads `jsName`, which depends on them.
    if (isSync && returnType.kind === 'void') {
      throw new Error(
        `Function \`${this.jsName}\` cannot be sync (\`Sync<...>\`) AND return \`void\`, as this is ambiguous. ` +
          `Either return a value (even if it's just a \`boolean\`) to keep it sync, or make it async.`
      )
    }
  }

  /**
   * A C++-identifier-safe name for this signature: `Func_` followed by the
   * escaped C++ code of the return type and of each parameter, joined by `_`.
   */
  get specializationName(): string {
    return (
      'Func_' +
      [this.returnType, ...this.parameters]
        .map((p) => escapeCppName(p.getCode('c++')))
        .join('_')
    )
  }

  /**
   * A human-readable signature used in messages. Parameters are rendered as
   * `name: kind`; the return type is rendered as its C++ code.
   */
  get jsName(): string {
    const paramsJs = this.parameters
      .map((p) => `${p.name}: ${p.kind}`)
      .join(', ')
    const returnType = this.returnType.getCode('c++')
    return `(${paramsJs}) => ${returnType}`
  }

  get canBePassedByReference(): boolean {
    // It's a function<..>, heavy to copy.
    return true
  }

  get kind(): TypeKind {
    return 'function'
  }

  /**
   * For a function, get the forward recreation of it:
   * If variable is called `func` with a single parameter `a`, this would return:
   * ```cpp
   * [func = std::move(func)](ParamType a) -> ReturnType { return func(std::forward<decltype(a)>(a)); }
   * ```
   *
   * @param variableName The name of the C++ variable holding the function.
   * @param language The target language. Only `'c++'` is supported.
   * @throws {Error} For any language other than `'c++'`. The check runs before
   *   any type code is generated, so this is the error callers see.
   */
  getForwardRecreationCode(variableName: string, language: Language): string {
    switch (language) {
      case 'c++': {
        const returnType = this.returnType.getCode(language)
        const parameters = this.parameters
          .map((p) => new Parameter(p.name, p))
          .map((p) => p.getCode('c++'))
        // Perfect-forward each argument into the captured function.
        const forwardedParameters = this.parameters.map(
          (p) => `std::forward<decltype(${p.name})>(${p.name})`
        )
        const closure = `[${variableName} = std::move(${variableName})]`
        const signature = `(${parameters.join(', ')}) -> ${returnType}`
        const body = `{ return ${variableName}(${forwardedParameters.join(', ')}); }`
        return `${closure} ${signature} ${body}`
      }
      default:
        throw new Error(
          `Language ${language} is not yet supported for function forward recreations!`
        )
    }
  }

  /**
   * Renders this function as a raw C++ function pointer declaration,
   * e.g. `ReturnType(*name)(Params...)`.
   *
   * @param name The declared name of the pointer.
   * @param includeNameInfo Whether to annotate each parameter with its name
   *   as an inline comment.
   */
  getCppFunctionPointerType(name: string, includeNameInfo = true): string {
    const params = getCppParameterList(this.parameters, includeNameInfo)
    const returnType = this.returnType.getCode('c++')
    return `${returnType}(*${name})(${params})`
  }

  /**
   * Renders this function type in the given language.
   *
   * @param language `'c++'` (`std::function<...>`), `'swift'`, or `'kotlin'`.
   * @param includeNameInfo Whether to include parameter names: as inline
   *   comments in C++, and as labels in Swift/Kotlin.
   * @throws {Error} For any other language.
   */
  getCode(language: Language, includeNameInfo = true): string {
    switch (language) {
      case 'c++': {
        const params = getCppParameterList(this.parameters, includeNameInfo)
        const returnType = this.returnType.getCode(language)
        return `std::function<${returnType}(${params})>`
      }
      case 'swift': {
        const params = this.parameters
          .map((p) =>
            includeNameInfo
              ? `_ ${p.escapedName}: ${p.getCode(language)}`
              : p.getCode(language)
          )
          .join(', ')
        const returnType = this.returnType.getCode(language)
        return `(${params}) -> ${returnType}`
      }
      case 'kotlin': {
        const params = this.parameters
          .map((p) =>
            includeNameInfo
              ? `${p.escapedName}: ${p.getCode(language)}`
              : p.getCode(language)
          )
          .join(', ')
        const returnType = this.returnType.getCode(language)
        return `(${params}) -> ${returnType}`
      }
      default:
        throw new Error(
          `Language ${language} is not yet supported for FunctionType!`
        )
    }
  }

  /** Extra files required by the return type and every parameter type. */
  getExtraFiles(): SourceFile[] {
    return [
      ...this.returnType.getExtraFiles(),
      ...this.parameters.flatMap((p) => p.getExtraFiles()),
    ]
  }

  /**
   * Imports required by this type: the C++ `<functional>` system header
   * (for `std::function`), plus those of the return and parameter types.
   */
  getRequiredImports(): SourceImport[] {
    return [
      {
        language: 'c++',
        name: 'functional',
        space: 'system',
      },
      ...this.returnType.getRequiredImports(),
      ...this.parameters.flatMap((p) => p.getRequiredImports()),
    ]
  }
}