@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
119 lines (118 loc) • 4.54 kB
TypeScript
import { DrawCommand } from 'regl';
import { DTypeGpu, GPUTensorConstructor, GPUTensorI } from '../../tensor/gpu/interface';
import { GPUMemoryAllocator } from '../../tensor/gpu/memory';
/**
 * Default maximum tensor rank.
 * NOTE(review): presumably the fallback for the optional `maxRank`
 * constructor argument of `Operation` — confirm against the implementation.
 */
export declare const defaultMaxRank: 10;
/**
 * Default iteration limit.
 * NOTE(review): what this bounds (e.g. shader loop iterations) is not visible
 * from this declaration file — confirm before relying on it.
 */
export declare const defaultMaxIterations: 10000000;
/**
 * Base constraint for the compilation-info type `Info` of `Operation`:
 * any object whose keys are strings.
 *
 * NOTE(review): `any` (rather than `unknown`) keeps interface-typed `Info`
 * objects assignable; switching to `Record<string, unknown>` would reject them.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
declare type DictBase = Record<string, any>;
/** Scalar kind of a shader input: floating point or integer. */
declare type InputType = 'float' | 'int';
/**
 * Describes one named input, as returned by `Operation.getUniformAttrs`.
 */
export interface Input {
    /** Name of the input. */
    name: string;
    /** Scalar kind of the input; optional. */
    type?: InputType;
    /**
     * Array length; optional.
     * NOTE(review): presumably omitted for scalar inputs — confirm.
     */
    length?: number;
}
/**
 * A GPU operation takes an input of type `TInput` and computes a single
 * `GPUTensor` from it.
 *
 * The computation is done with WebGL. Compilation info of type `Info` is
 * passed to the shader in the form of uniforms.
 *
 * @typeParam GPUTensor - The GPU tensor type produced by `compute` and `calc`.
 * @typeParam Info - Compilation info consumed by the shader-building methods.
 * @typeParam TInput - The input accepted by `calc`, `getOutputShape`,
 *   `getCompilationInfo` and `getInputInfoString`. It is named `TInput` so it
 *   does not shadow the unrelated module-level `InputType` alias
 *   (`'int' | 'float'`).
 */
export declare abstract class Operation<GPUTensor extends GPUTensorI, Info extends DictBase, TInput> {
    protected dtype: DTypeGpu;
    protected allocator: GPUMemoryAllocator;
    /**
     * Set of variable names treated as static.
     * NOTE(review): presumably filled by `registerStatics` and consulted by
     * `getVarModifier` — confirm against the implementation.
     */
    protected statics: Set<string>;
    protected gpuTensorConstructor: GPUTensorConstructor<GPUTensor>;
    /**
     * The maximum rank a tensor can have.
     *
     * WebGL only supports arrays of constant size, so all arrays are
     * represented with this fixed length.
     */
    protected maxRank: number;
    /**
     * Optional draw command; may be unset.
     * NOTE(review): presumably built by `compile` / `getDrawCommand` — confirm.
     */
    protected drawCommand?: DrawCommand;
    private copyCounter;
    protected fullyStatic: boolean;
    protected outputShape?: readonly number[];
    constructor(tensorConstructor: GPUTensorConstructor<GPUTensor>, dtype: DTypeGpu, allocator?: GPUMemoryAllocator, maxRank?: number);
    registerStatics(info: Info): void;
    /**
     * Gets the modifier for the WebGL variable with the given name: either
     * `"uniform"` or the empty string.
     */
    getVarModifier(name: string): "" | "uniform";
    /**
     * Pads an array to the specified length, or to `maxRank` by default.
     */
    pad(arr: number[], len?: number): number[];
    /**
     * Like `pad`, but accepts a readonly array.
     * NOTE(review): presumably pads a copy and leaves `arr` unchanged — confirm.
     */
    copyPad(arr: readonly number[], len?: number): number[];
    /**
     * Gets the variable declarations for the WebGL shader. Override this if
     * you need extra uniform inputs.
     */
    getVariables(): string;
    getVariableDeclarations(): string;
    getVariableInitializations(info: Info): string;
    /**
     * Returns the GLSL scalar type (`"float"` or `"int"`) of the variable with
     * the given name.
     */
    getVarType(name: string): "float" | "int";
    getArrayInit(name: string, values: any[], len?: number, pad?: string): string;
    getUtilFunctions(): string;
    getTextureFunctions(): string;
    getCompleteFragmentShader(info: Info): string;
    getUniforms(info: Info): any;
    // The index helpers below return shader source as strings.
    // NOTE(review): their arguments appear to be GLSL variable names — confirm.
    posToIndex(strides: string, result: string, pos: string): string;
    initIndex(index: string, rank?: string): string;
    incrementIndex(index: string, shape: string): string;
    incrementConditional(index: string, shape: string, cond: string): string;
    /**
     * The default main function of the fragment shader.
     * Except in special cases, you will use this, and your fragment shader
     * will look something like this:
     *
     * ```
     * float process(int index[maxRank]) {
     * // Calculate the value of the output at the given index
     * }
     *
     * ${this.getDefaultMain()}
     * ```
     */
    getDefaultMain(): string;
    /** The GLSL float precision qualifier used by the shader. */
    precisionString(): "mediump" | "highp";
    getDrawCommand(info: Info): DrawCommand;
    /**
     * Compiles the fragment shader for the given compilation info.
     *
     * If you need to add extra compilation info, override this method.
     * NOTE(review): precision is not a parameter here; it presumably comes
     * from `precisionString()` — confirm.
     */
    compile(info: Info): void;
    /**
     * Runs the operation and returns the resulting tensor.
     *
     * @param resultShape - Shape of the output tensor.
     * @param inputTensors - Input tensors keyed by name (see `getTextureNames`).
     * @param inputs - Additional, operation-specific inputs.
     *   NOTE(review): presumably forwarded to the shader as uniforms — confirm.
     */
    compute(resultShape: readonly number[], inputTensors: {
        [name: string]: GPUTensorI;
    }, inputs?: any): GPUTensor;
    /**
     * Returns the source of this operation's fragment shader. It must define
     * a `void main()` function; see `getDefaultMain`.
     */
    abstract getFragmentShader(info: Info): string;
    /**
     * Gets the names of the input tensors, in the order they are passed to
     * `compute`.
     */
    abstract getTextureNames(): string[];
    /** Returns descriptions of this operation's uniform inputs. */
    getUniformAttrs(): Input[];
    /**
     * Performs the computation of the operation.
     *
     * Typically you will compute the output shape and then call `compute`.
     */
    abstract calc(input: TInput): GPUTensor;
    /**
     * Computes the shape of the output tensor for the given input.
     */
    abstract getOutputShape(input: TInput): readonly number[];
    /**
     * Gets all compilation info that can be inferred from the given input.
     */
    abstract getCompilationInfo(input: TInput): Info;
    /**
     * Returns a string that is unique for each compilation configuration
     * implied by the given input.
     */
    abstract getInputInfoString(input: TInput): string;
}
// Explicit empty export emitted by tsc: keeps the non-exported declarations in
// this declaration file (e.g. `DictBase`, `InputType`) module-private instead of
// implicitly exported.
export {};