nitro-codegen
Version:
The code-generator for react-native-nitro-modules.
144 lines (143 loc) • 5.61 kB
JavaScript
import { escapeCppName, toReferenceType } from '../helpers.js';
import { Parameter } from '../Parameter.js';
import {} from '../SourceFile.js';
import { PromiseType } from './PromiseType.js';
/**
 * Renders a C++ parameter list (without the surrounding parentheses) for the
 * given named parameter types. It is shared by the `std::function<...>` and
 * raw function-pointer renderings so both always agree.
 *
 * Parameters whose type reports `canBePassedByReference` are rendered through
 * `toReferenceType(...)`.
 *
 * @param {Array} parameters - The named parameter types of a FunctionType.
 * @param {boolean} includeNameInfo - If true, each entry is followed by the
 *   parameter name as an inline C++ block comment.
 * @returns {string} The comma-separated parameter list.
 */
function renderCppParameterList(parameters, includeNameInfo) {
  return parameters
    .map((p) => {
      const type = p.getCode('c++');
      const code = p.canBePassedByReference ? toReferenceType(type) : type;
      return includeNameInfo ? `${code} /* ${p.name} */` : code;
    })
    .join(', ');
}

/**
 * A function/callback type as seen by the code generator.
 *
 * Unless the function is declared sync, a non-void return type is wrapped in a
 * `PromiseType`, because the result has to be awaited from JS.
 */
export class FunctionType {
  /** The return type, wrapped in a `PromiseType` for non-void async functions. */
  returnType;
  /** The named parameter types, in declaration order. */
  parameters;

  /**
   * @param returnType - The declared return type of the function.
   * @param {Array} parameters - The named parameter types.
   * @param {boolean} [isSync=false] - Whether the function was declared `Sync<...>`.
   * @throws {Error} If `isSync` is true and `returnType` is `void`.
   */
  constructor(returnType, parameters, isSync = false) {
    if (returnType.kind === 'void' || isSync) {
      // void callbacks are async, but we don't care about their result;
      // sync callbacks return their value directly. Neither needs a Promise.
      this.returnType = returnType;
    } else {
      // non-void async callbacks need to be awaited to get the result from JS.
      this.returnType = new PromiseType(returnType);
    }
    this.parameters = parameters;
    // Validated after both fields are assigned because the message uses `jsName`.
    if (isSync && returnType.kind === 'void') {
      throw new Error(`Function \`${this.jsName}\` cannot be sync (\`Sync<...>\`) AND return \`void\`, as this is ambiguous. ` +
        `Either return a value (even if it's just a \`boolean\`) to keep it sync, or make it async.`);
    }
  }

  /**
   * A C++-identifier-safe name for this signature, e.g. `Func_void_double`.
   * It is built from the escaped C++ code of the return type and every parameter.
   */
  get specializationName() {
    return ('Func_' +
      [this.returnType, ...this.parameters]
        .map((p) => escapeCppName(p.getCode('c++')))
        .join('_'));
  }

  /**
   * A human-readable signature used in error messages.
   * Parameters show their `kind`; the return type is rendered as C++ code.
   */
  get jsName() {
    const paramsJs = this.parameters
      .map((p) => `${p.name}: ${p.kind}`)
      .join(', ');
    const returnType = this.returnType.getCode('c++');
    return `(${paramsJs}) => ${returnType}`;
  }

  get canBePassedByReference() {
    // It's a function<..>, heavy to copy.
    return true;
  }

  get kind() {
    return 'function';
  }

  /**
   * For a function, get the forward recreation of it:
   * If variable is called `func`, this would return:
   * ```cpp
   * [func = std::move(func)](Params... params) -> ReturnType {
   *   return func(params...);
   * }
   * ```
   *
   * @param {string} variableName - The name of the function variable to capture.
   * @param {string} language - Target language; only `'c++'` is supported.
   * @returns {string} A C++ lambda expression forwarding to `variableName`.
   * @throws {Error} If `language` is not `'c++'`. The check runs before any
   *   type code is generated, so the caller always gets this error.
   */
  getForwardRecreationCode(variableName, language) {
    if (language !== 'c++') {
      throw new Error(`Language ${language} is not yet supported for function forward recreations!`);
    }
    const returnType = this.returnType.getCode(language);
    const parameters = this.parameters
      .map((p) => new Parameter(p.name, p).getCode('c++'))
      .join(', ');
    const forwardedParameters = this.parameters
      .map((p) => `std::forward<decltype(${p.name})>(${p.name})`)
      .join(', ');
    const closure = `[${variableName} = std::move(${variableName})]`;
    const signature = `(${parameters}) -> ${returnType}`;
    const body = `{ return ${variableName}(${forwardedParameters}); }`;
    return `${closure} ${signature} ${body}`;
  }

  /**
   * Renders this signature as a C++ raw function-pointer declarator,
   * e.g. `void(*name)(double)`.
   *
   * @param {string} name - The declarator name placed inside `(*...)`.
   * @param {boolean} [includeNameInfo=true] - Whether to annotate each
   *   parameter with its name in an inline comment.
   * @returns {string} The C++ function-pointer type.
   */
  getCppFunctionPointerType(name, includeNameInfo = true) {
    const params = renderCppParameterList(this.parameters, includeNameInfo);
    const returnType = this.returnType.getCode('c++');
    return `${returnType}(*${name})(${params})`;
  }

  /**
   * Renders this function type in the given language.
   *
   * @param {string} language - One of `'c++'`, `'swift'`, `'kotlin'`.
   * @param {boolean} [includeNameInfo=true] - Whether parameter names are
   *   included (as comments in C++, as labels in Swift/Kotlin).
   * @returns {string} The type's code.
   * @throws {Error} For any other language.
   */
  getCode(language, includeNameInfo = true) {
    switch (language) {
      case 'c++': {
        const params = renderCppParameterList(this.parameters, includeNameInfo);
        const returnType = this.returnType.getCode(language);
        return `std::function<${returnType}(${params})>`;
      }
      case 'swift': {
        const params = this.parameters
          .map((p) => (includeNameInfo
            ? `_ ${p.escapedName}: ${p.getCode(language)}`
            : p.getCode(language)))
          .join(', ');
        const returnType = this.returnType.getCode(language);
        return `(${params}) -> ${returnType}`;
      }
      case 'kotlin': {
        const params = this.parameters
          .map((p) => (includeNameInfo
            ? `${p.escapedName}: ${p.getCode(language)}`
            : p.getCode(language)))
          .join(', ');
        const returnType = this.returnType.getCode(language);
        return `(${params}) -> ${returnType}`;
      }
      default:
        throw new Error(`Language ${language} is not yet supported for FunctionType!`);
    }
  }

  /** Extra generated files required by the return type and all parameters. */
  getExtraFiles() {
    return [
      ...this.returnType.getExtraFiles(),
      ...this.parameters.flatMap((p) => p.getExtraFiles()),
    ];
  }

  /**
   * Imports required by this type: the C++ system header `<functional>`
   * (for `std::function`), plus whatever the return and parameter types need.
   */
  getRequiredImports() {
    return [
      {
        language: 'c++',
        name: 'functional',
        space: 'system',
      },
      ...this.returnType.getRequiredImports(),
      ...this.parameters.flatMap((p) => p.getRequiredImports()),
    ];
  }
}