/*
 * gpu.js
 * Version:
 * GPU Accelerated JavaScript
 * 627 lines (570 loc) • 20.5 kB
 * JavaScript
 */
/**
 * @desc This handles all the raw state, converted state, etc. of a single function.
 * [INTERNAL] A collection of functionNodes.
 * @class
 */
class FunctionBuilder {
  /**
   * @desc Create a FunctionBuilder for a kernel. A root node (named 'kernel') is built from
   * `kernel.source`, plus one node per entry of `kernel.functions` and `kernel.subKernels`.
   * Every node receives callbacks that delegate back to the returned builder, so a node can
   * look up / assign return types, argument types and bit ratios of other functions.
   * When `kernel.source` is an object carrying `functionNodes`, the builder is restored with
   * `fromJSON` instead.
   *
   * @param {Kernel} kernel
   * @param {FunctionNode} FunctionNode - function node class to instantiate
   * @param {object} [extraNodeOptions] - extra options merged over the shared node options
   * @returns {FunctionBuilder}
   * @static
   */
  static fromKernel(kernel, FunctionNode, extraNodeOptions) {
    const {
      kernelArguments,
      kernelConstants,
      argumentNames,
      argumentSizes,
      argumentBitRatios,
      constants,
      constantBitRatios,
      debug,
      loopMaxIterations,
      nativeFunctions,
      output,
      optimizeFloatMemory,
      precision,
      plugins,
      source,
      subKernels,
      functions,
      leadingReturnStatement,
      followingReturnStatement,
      dynamicArguments,
      dynamicOutput,
    } = kernel;
    const argumentTypes = kernelArguments.map((kernelArgument) => kernelArgument.type);
    const constantTypes = {};
    for (const kernelConstant of kernelConstants) {
      constantTypes[kernelConstant.name] = kernelConstant.type;
    }

    // These callbacks close over `functionBuilder`, which is declared at the end of this
    // method; they are only safe to invoke once it has been assigned.
    const needsArgumentType = (functionName, index) => {
      return functionBuilder.needsArgumentType(functionName, index);
    };
    const assignArgumentType = (functionName, index, type) => {
      functionBuilder.assignArgumentType(functionName, index, type);
    };
    const lookupReturnType = (functionName, ast, requestingNode) => {
      return functionBuilder.lookupReturnType(functionName, ast, requestingNode);
    };
    const lookupFunctionArgumentTypes = (functionName) => {
      return functionBuilder.lookupFunctionArgumentTypes(functionName);
    };
    const lookupFunctionArgumentName = (functionName, argumentIndex) => {
      return functionBuilder.lookupFunctionArgumentName(functionName, argumentIndex);
    };
    const lookupFunctionArgumentBitRatio = (functionName, argumentName) => {
      return functionBuilder.lookupFunctionArgumentBitRatio(functionName, argumentName);
    };
    const triggerImplyArgumentType = (functionName, i, argumentType, requestingNode) => {
      functionBuilder.assignArgumentType(functionName, i, argumentType, requestingNode);
    };
    const triggerImplyArgumentBitRatio = (functionName, argumentName, calleeFunctionName, argumentIndex) => {
      functionBuilder.assignArgumentBitRatio(functionName, argumentName, calleeFunctionName, argumentIndex);
    };
    const onFunctionCall = (functionName, calleeFunctionName, args) => {
      functionBuilder.trackFunctionCall(functionName, calleeFunctionName, args);
    };
    const onNestedFunction = (ast, source) => {
      const nestedArgumentNames = ast.params.map((param) => param.name);
      const nestedFunction = new FunctionNode(source, Object.assign({}, nodeOptions, {
        returnType: null,
        ast,
        name: ast.id.name,
        argumentNames: nestedArgumentNames,
        lookupReturnType,
        lookupFunctionArgumentTypes,
        lookupFunctionArgumentName,
        lookupFunctionArgumentBitRatio,
        needsArgumentType,
        assignArgumentType,
        triggerImplyArgumentType,
        triggerImplyArgumentBitRatio,
        onFunctionCall,
      }));
      nestedFunction.traceFunctionAST(ast);
      functionBuilder.addFunctionNode(nestedFunction);
    };

    // Options shared by every non-root node; extraNodeOptions may override any of them.
    const nodeOptions = Object.assign({
      isRootKernel: false,
      onNestedFunction,
      lookupReturnType,
      lookupFunctionArgumentTypes,
      lookupFunctionArgumentName,
      lookupFunctionArgumentBitRatio,
      needsArgumentType,
      assignArgumentType,
      triggerImplyArgumentType,
      triggerImplyArgumentBitRatio,
      onFunctionCall,
      optimizeFloatMemory,
      precision,
      constants,
      constantTypes,
      constantBitRatios,
      debug,
      loopMaxIterations,
      output,
      plugins,
      dynamicArguments,
      dynamicOutput,
    }, extraNodeOptions || {});
    const rootNodeOptions = Object.assign({}, nodeOptions, {
      isRootKernel: true,
      name: 'kernel',
      argumentNames,
      argumentTypes,
      argumentSizes,
      argumentBitRatios,
      leadingReturnStatement,
      followingReturnStatement,
    });

    // Serialized kernels carry their function nodes directly.
    if (typeof source === 'object' && source.functionNodes) {
      return new FunctionBuilder().fromJSON(source.functionNodes, FunctionNode);
    }

    const rootNode = new FunctionNode(source, rootNodeOptions);
    let functionNodes = null;
    if (functions) {
      functionNodes = functions.map((fn) => new FunctionNode(fn.source, {
        returnType: fn.returnType,
        argumentTypes: fn.argumentTypes,
        output,
        plugins,
        constants,
        constantTypes,
        constantBitRatios,
        optimizeFloatMemory,
        precision,
        lookupReturnType,
        lookupFunctionArgumentTypes,
        lookupFunctionArgumentName,
        lookupFunctionArgumentBitRatio,
        needsArgumentType,
        assignArgumentType,
        triggerImplyArgumentType,
        triggerImplyArgumentBitRatio,
        onFunctionCall,
        onNestedFunction,
      }));
    }
    let subKernelNodes = null;
    if (subKernels) {
      subKernelNodes = subKernels.map(({ name, source: subKernelSource }) => {
        return new FunctionNode(subKernelSource, Object.assign({}, nodeOptions, {
          name,
          isSubKernel: true,
          isRootKernel: false,
        }));
      });
    }
    const functionBuilder = new FunctionBuilder({
      kernel,
      rootNode,
      functionNodes,
      nativeFunctions,
      subKernelNodes,
    });
    return functionBuilder;
  }

  /**
   * @desc Register the given nodes in `functionMap` (root node under the key 'kernel',
   * other nodes under their `name`) and record the names of the native functions.
   * @param {IFunctionBuilderSettings} [settings]
   */
  constructor(settings) {
    const options = settings || {};
    this.kernel = options.kernel;
    this.rootNode = options.rootNode;
    this.functionNodes = options.functionNodes || [];
    this.subKernelNodes = options.subKernelNodes || [];
    this.nativeFunctions = options.nativeFunctions || [];
    this.functionMap = {};
    this.nativeFunctionNames = [];
    this.lookupChain = [];
    this.functionNodeDependencies = {};
    this.functionCalls = {};
    if (this.rootNode) {
      this.functionMap.kernel = this.rootNode;
    }
    for (const functionNode of this.functionNodes) {
      this.functionMap[functionNode.name] = functionNode;
    }
    for (const subKernelNode of this.subKernelNodes) {
      this.functionMap[subKernelNode.name] = subKernelNode;
    }
    for (const nativeFunction of this.nativeFunctions) {
      this.nativeFunctionNames.push(nativeFunction.name);
    }
  }

  /**
   * @desc Add the function node directly; a node flagged `isRootKernel` also becomes `rootNode`.
   * @param {FunctionNode} functionNode - functionNode to add
   * @throws {Error} when the node has no name
   */
  addFunctionNode(functionNode) {
    if (!functionNode.name) throw new Error('functionNode.name needs set');
    this.functionMap[functionNode.name] = functionNode;
    if (functionNode.isRootKernel) {
      this.rootNode = functionNode;
    }
  }

  /**
   * @desc Trace all the depending functions being called, from a single function
   *
   * This allows 'unneeded' functions to be automatically optimized out.
   * Note that the 0-index is the starting function trace.
   *
   * @param {String} functionName - Function name to trace from, default to 'kernel'
   * @param {String[]} [retList] - Returning list of function names that is traced. Including itself.
   *
   * @returns {String[]} Returning list of function names that is traced. Including itself.
   */
  traceFunctionCalls(functionName, retList) {
    const name = functionName || 'kernel';
    const list = retList || [];
    if (this.nativeFunctionNames.includes(name)) {
      const nativeIndex = list.indexOf(name);
      if (nativeIndex === -1) {
        list.push(name);
      } else {
        // https://github.com/gpujs/gpu.js/issues/207
        // An already-listed dependency is moved to the end of the list, so that after the
        // caller reverses the list it is emitted before every function that depends on it.
        list.push(list.splice(nativeIndex, 1)[0]);
      }
      return list;
    }
    const functionNode = this.functionMap[name];
    if (functionNode) {
      const existingIndex = list.indexOf(name);
      if (existingIndex === -1) {
        list.push(name);
        functionNode.toString(); // ensure JS trace is done, which populates calledFunctions
        for (let i = 0; i < functionNode.calledFunctions.length; ++i) {
          this.traceFunctionCalls(functionNode.calledFunctions[i], list);
        }
      } else {
        // https://github.com/gpujs/gpu.js/issues/207 (see above)
        list.push(list.splice(existingIndex, 1)[0]);
      }
    }
    return list;
  }

  /**
   * @desc Return the string for a function
   * @param {String} functionName - Function name to trace from. If null, it returns the WHOLE builder stack
   * @returns {String} The full string, of all the various functions. Trace optimized if functionName given
   */
  getPrototypeString(functionName) {
    return this.getPrototypes(functionName).join('\n');
  }

  /**
   * @desc Return the prototypes for a function and its dependencies
   * @param {String} [functionName] - Function name to trace from. If null, it returns the WHOLE builder stack
   * @returns {Array} The full string, of all the various functions. Trace optimized if functionName given
   */
  getPrototypes(functionName) {
    if (this.rootNode) {
      this.rootNode.toString();
    }
    if (functionName) {
      return this.getPrototypesFromFunctionNames(this.traceFunctionCalls(functionName, []).reverse());
    }
    return this.getPrototypesFromFunctionNames(Object.keys(this.functionMap));
  }

  /**
   * @desc Get string from function names; names not present in `functionMap` are skipped.
   * @param {String[]} functionList - List of function to build string
   * @returns {String} The string, of all the various functions. Trace optimized if functionName given
   */
  getStringFromFunctionNames(functionList) {
    const ret = [];
    for (const functionName of functionList) {
      const node = this.functionMap[functionName];
      if (node) {
        ret.push(node.toString());
      }
    }
    return ret.join('\n');
  }

  /**
   * @desc Return string of all functions converted. Native functions contribute their
   * `source`; other known functions contribute `toString()`; unknown names are skipped.
   * @param {String[]} functionList - List of function names to build the string.
   * @returns {Array} Prototypes of all functions converted
   */
  getPrototypesFromFunctionNames(functionList) {
    const ret = [];
    for (const functionName of functionList) {
      const nativeIndex = this.nativeFunctionNames.indexOf(functionName);
      if (nativeIndex > -1) {
        ret.push(this.nativeFunctions[nativeIndex].source);
        continue;
      }
      const node = this.functionMap[functionName];
      if (node) {
        ret.push(node.toString());
      }
    }
    return ret;
  }

  /**
   * @desc Serialize every function reachable from the root node, dependencies first.
   * Native functions serialize as `{ name, source }`; other nodes via their own `toJSON()`.
   * @returns {Array<object>}
   * @throws {Error} when a traced name is neither native nor in `functionMap`
   */
  toJSON() {
    return this.traceFunctionCalls(this.rootNode.name).reverse().map((name) => {
      // Look the name up in nativeFunctionNames: nativeFunctions holds objects, not names.
      const nativeIndex = this.nativeFunctionNames.indexOf(name);
      if (nativeIndex > -1) {
        return {
          name,
          source: this.nativeFunctions[nativeIndex].source,
        };
      }
      if (this.functionMap[name]) {
        return this.functionMap[name].toJSON();
      }
      throw new Error(`function ${ name } not found`);
    });
  }

  /**
   * @desc Restore function nodes from the output of `toJSON`. Entries without `settings`
   * (native functions, `{ name, source }`) are registered as native functions; the node
   * flagged `isRootKernel` becomes `rootNode`.
   * @param {Array<object>} jsonFunctionNodes
   * @param {FunctionNode} FunctionNode - function node class to instantiate
   * @returns {FunctionBuilder} this
   */
  fromJSON(jsonFunctionNodes, FunctionNode) {
    this.functionMap = {};
    for (const jsonFunctionNode of jsonFunctionNodes) {
      if (!jsonFunctionNode.settings) {
        const { name, source } = jsonFunctionNode;
        if (!this.nativeFunctionNames.includes(name)) {
          this.nativeFunctions.push({ name, source });
          this.nativeFunctionNames.push(name);
        }
        continue;
      }
      const { ast, settings } = jsonFunctionNode;
      const node = new FunctionNode(ast, settings);
      this.functionMap[settings.name] = node;
      if (node.isRootKernel) {
        this.rootNode = node;
      }
    }
    return this;
  }

  /**
   * @desc Get string for a particular function name
   * @param {String} functionName - Function name to trace from. If null, it returns the WHOLE builder stack
   * @returns {String} settings - The string, of all the various functions. Trace optimized if functionName given
   */
  getString(functionName) {
    if (functionName) {
      return this.getStringFromFunctionNames(this.traceFunctionCalls(functionName).reverse());
    }
    return this.getStringFromFunctionNames(Object.keys(this.functionMap));
  }

  /**
   * @desc Resolve the return type of the function called by `ast`. Native functions report
   * their declared `returnType`; other functions are typed (and cached on the node) by asking
   * the node for the type of its JS AST. When the same call is already being resolved
   * further up `lookupChain` (recursion), the callee's argument types are resolved first
   * from the call site if they are still unknown; otherwise an error is thrown.
   *
   * @param {String} functionName
   * @param {object} ast - must be a 'CallExpression' node
   * @param {FunctionNode} requestingNode
   * @returns {String|null} null when `functionName` is neither native nor a known function
   * @throws {Error} on a non-CallExpression ast, or unresolvable circular logic
   */
  lookupReturnType(functionName, ast, requestingNode) {
    if (ast.type !== 'CallExpression') {
      throw new Error(`expected ast type of "CallExpression", but is ${ ast.type }`);
    }
    if (this._isNativeFunction(functionName)) {
      return this._lookupNativeFunctionReturnType(functionName);
    }
    if (!this._isFunction(functionName)) {
      return null;
    }
    const node = this._getFunction(functionName);
    if (node.returnType) {
      return node.returnType;
    }
    // detect circular logic: this exact call is already being resolved further up the chain
    const isCircular = this.lookupChain.some((link) => link.ast === ast);
    if (isCircular) {
      // detect if arguments have not resolved, preventing a return type
      // if so, go ahead and resolve them, so we can resolve the return type
      if (node.argumentTypes.length === 0 && ast.arguments.length > 0) {
        const args = ast.arguments;
        for (let j = 0; j < args.length; j++) {
          node.argumentTypes[j] = this._withLookupEntry({
            name: requestingNode.name,
            ast: args[j],
            requestingNode,
          }, () => requestingNode.getType(args[j]));
        }
        const resolvedType = node.getType(node.getJsAST());
        node.returnType = resolvedType;
        return resolvedType;
      }
      throw new Error('circlical logic detected!');
    }
    // get ready for a ride!
    const type = this._withLookupEntry({
      name: requestingNode.name,
      ast,
      requestingNode,
    }, () => node.getType(node.getJsAST()));
    node.returnType = type;
    return type;
  }

  /**
   * @desc Run `callback` with `entry` pushed onto `lookupChain`, popping it afterwards even
   * if `callback` throws, so a failed lookup cannot leave a stale entry that a later lookup
   * would mistake for circular logic.
   * @param {object} entry
   * @param {Function} callback
   * @returns {*} whatever `callback` returns
   * @private
   */
  _withLookupEntry(entry, callback) {
    this.lookupChain.push(entry);
    try {
      return callback();
    } finally {
      this.lookupChain.pop();
    }
  }

  /**
   * @param {String} functionName
   * @return {FunctionNode}
   * @throws {Error} when no function of that name is registered
   * @private
   */
  _getFunction(functionName) {
    if (!this._isFunction(functionName)) {
      throw new Error(`Function ${functionName} not found`);
    }
    return this.functionMap[functionName];
  }

  _isFunction(functionName) {
    return Boolean(this.functionMap[functionName]);
  }

  _getNativeFunction(functionName) {
    return this.nativeFunctions.find((nativeFunction) => nativeFunction.name === functionName) || null;
  }

  _isNativeFunction(functionName) {
    return Boolean(this._getNativeFunction(functionName));
  }

  _lookupNativeFunctionReturnType(functionName) {
    const nativeFunction = this._getNativeFunction(functionName);
    if (nativeFunction) {
      return nativeFunction.returnType;
    }
    throw new Error(`Native function ${ functionName } not found`);
  }

  /**
   * @param {String} functionName
   * @returns {Array|null} the argument types of a native or known function, else null
   */
  lookupFunctionArgumentTypes(functionName) {
    if (this._isNativeFunction(functionName)) {
      return this._getNativeFunction(functionName).argumentTypes;
    }
    if (this._isFunction(functionName)) {
      return this._getFunction(functionName).argumentTypes;
    }
    return null;
  }

  /**
   * @param {String} functionName
   * @param {number} argumentIndex
   * @returns {String} name of the argument at `argumentIndex`
   * @throws {Error} when the function is not registered
   */
  lookupFunctionArgumentName(functionName, argumentIndex) {
    return this._getFunction(functionName).argumentNames[argumentIndex];
  }

  /**
   * @desc Look up the bit ratio of a named argument. For the root node, a matching
   * argument's ratio is returned directly (even if it is not a number).
   * @param {string} functionName
   * @param {string} argumentName
   * @return {number}
   * @throws {Error} when the function, the argument, or a numeric bit ratio is not found
   */
  lookupFunctionArgumentBitRatio(functionName, argumentName) {
    if (!this._isFunction(functionName)) {
      throw new Error('function not found');
    }
    if (this.rootNode && this.rootNode.name === functionName) {
      const rootIndex = this.rootNode.argumentNames.indexOf(argumentName);
      if (rootIndex !== -1) {
        return this.rootNode.argumentBitRatios[rootIndex];
      }
    }
    const node = this._getFunction(functionName);
    const argumentIndex = node.argumentNames.indexOf(argumentName);
    if (argumentIndex === -1) {
      throw new Error('argument not found');
    }
    const bitRatio = node.argumentBitRatios[argumentIndex];
    if (typeof bitRatio !== 'number') {
      throw new Error('argument bit ratio not found');
    }
    return bitRatio;
  }

  /**
   * @returns {boolean} true when `functionName` is known and argument `i` has no type yet
   */
  needsArgumentType(functionName, i) {
    if (!this._isFunction(functionName)) return false;
    const fnNode = this._getFunction(functionName);
    return !fnNode.argumentTypes[i];
  }

  /**
   * @desc Set argument `i`'s type on a known function, only if it has no type yet.
   * `requestingNode` is accepted for interface compatibility and is not used.
   */
  assignArgumentType(functionName, i, argumentType, requestingNode) {
    if (!this._isFunction(functionName)) return;
    const fnNode = this._getFunction(functionName);
    if (!fnNode.argumentTypes[i]) {
      fnNode.argumentTypes[i] = argumentType;
    }
  }

  /**
   * @desc Propagate the bit ratio of the caller's argument `argumentName` to the callee's
   * parameter at position `argumentIndex`.
   * @param {string} functionName - caller
   * @param {string} argumentName - caller argument being passed
   * @param {string} calleeFunctionName
   * @param {number} argumentIndex - position of that argument in the call to the callee
   * @return {number|null} the bit ratio, or null when the callee is a native function
   * @throws {Error} when the argument or its ratio is not found, or ratios conflict
   */
  assignArgumentBitRatio(functionName, argumentName, calleeFunctionName, argumentIndex) {
    // Native functions carry no bit ratios; bail out before touching the caller node.
    if (this._isNativeFunction(calleeFunctionName)) return null;
    const node = this._getFunction(functionName);
    const calleeNode = this._getFunction(calleeFunctionName);
    const i = node.argumentNames.indexOf(argumentName);
    if (i === -1) {
      throw new Error(`Argument ${argumentName} not found in arguments from function ${functionName}`);
    }
    const bitRatio = node.argumentBitRatios[i];
    if (typeof bitRatio !== 'number') {
      throw new Error(`Bit ratio for argument ${argumentName} not found in function ${functionName}`);
    }
    if (!calleeNode.argumentBitRatios) {
      calleeNode.argumentBitRatios = new Array(calleeNode.argumentNames.length);
    }
    // The callee's parameter position is the call-site position, not the caller's index `i`.
    const calleeIndex = typeof argumentIndex === 'number' ? argumentIndex : i;
    const calleeBitRatio = calleeNode.argumentBitRatios[calleeIndex];
    if (typeof calleeBitRatio === 'number') {
      if (calleeBitRatio !== bitRatio) {
        throw new Error(`Incompatible bit ratio found at function ${functionName} at argument ${argumentName}`);
      }
      return calleeBitRatio;
    }
    calleeNode.argumentBitRatios[calleeIndex] = bitRatio;
    return bitRatio;
  }

  /**
   * @desc Record that `functionName` calls `calleeFunctionName` with `args`.
   */
  trackFunctionCall(functionName, calleeFunctionName, args) {
    if (!this.functionNodeDependencies[functionName]) {
      this.functionNodeDependencies[functionName] = new Set();
      this.functionCalls[functionName] = [];
    }
    this.functionNodeDependencies[functionName].add(calleeFunctionName);
    this.functionCalls[functionName].push(args);
  }

  getKernelResultType() {
    return this.rootNode.returnType || this.rootNode.getType(this.rootNode.ast);
  }

  /**
   * @param {number} index - index into `subKernelNodes`
   * @returns {String} the sub-kernel's return type
   * @throws {Error} when the root kernel never calls that sub-kernel
   */
  getSubKernelResultType(index) {
    const subKernelNode = this.subKernelNodes[index];
    const called = this.rootNode.functionCalls.some(
      (functionCall) => functionCall.ast.callee.name === subKernelNode.name
    );
    if (!called) {
      throw new Error(`SubKernel ${ subKernelNode.name } never called by kernel`);
    }
    return subKernelNode.returnType || subKernelNode.getType(subKernelNode.getJsAST());
  }

  /**
   * @desc Map every function reachable from the root node to its return type. Known
   * function nodes are typed from their AST; native functions report their declared
   * `returnType`.
   * @returns {Object<string, String>}
   */
  getReturnTypes() {
    const result = {
      [this.rootNode.name]: this.rootNode.getType(this.rootNode.ast),
    };
    const list = this.traceFunctionCalls(this.rootNode.name);
    for (const functionName of list) {
      const functionNode = this.functionMap[functionName];
      if (functionNode) {
        result[functionName] = functionNode.getType(functionNode.ast);
      } else if (this._isNativeFunction(functionName)) {
        result[functionName] = this._lookupNativeFunctionReturnType(functionName);
      }
    }
    return result;
  }
}
/** Public surface of this CommonJS module. */
const functionBuilderExports = { FunctionBuilder };
module.exports = functionBuilderExports;