UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

119 lines (118 loc) 4.54 kB
import { DrawCommand } from 'regl';
import { DTypeGpu, GPUTensorConstructor, GPUTensorI } from '../../tensor/gpu/interface';
import { GPUMemoryAllocator } from '../../tensor/gpu/memory';
/**
 * Declared with the value 10.
 * NOTE(review): presumably the default for the `maxRank` constructor argument of
 * `Operation` - the implementation is not visible in this declaration file, confirm there.
 */
export declare const defaultMaxRank = 10;
/**
 * Declared with the value 10000000.
 * NOTE(review): where this limit is applied is not visible in this declaration file.
 */
export declare const defaultMaxIterations = 10000000;
/**
 * String-keyed dictionary with arbitrary values; the upper bound for the `Info`
 * type parameter of `Operation`.
 */
declare type DictBase = {
    [name: string]: any;
};
/**
 * Scalar kind of an `Input`.
 *
 * Note: the class `Operation` declares a type parameter that is also called `InputType`.
 * Inside that class the type parameter shadows this alias; this alias is what
 * `Input.type` refers to.
 */
declare type InputType = 'int' | 'float';
/**
 * Description of a single named input, as returned by `Operation.getUniformAttrs`.
 */
export interface Input {
    /** Name of the input. */
    name: string;
    /** Optional length. NOTE(review): presumably set for array-valued inputs - confirm. */
    length?: number;
    /** Optional scalar kind, `'int'` or `'float'`. */
    type?: InputType;
}
/**
 * A GPU operation takes some input of the type given by the `InputType`
 * type parameter and calculates a single GPUTensor.
 *
 * This is done with WebGL, which takes some input information of type Info,
 * which is passed to the shader in the form of uniforms.
 *
 * Type parameters:
 * - `GPUTensor`: the tensor type produced by the operation, constrained to `GPUTensorI`.
 * - `Info`: the compilation info dictionary, constrained to `DictBase`.
 * - `InputType`: the argument type accepted by `calc`, `getOutputShape`,
 *   `getCompilationInfo` and `getInputInfoString`. It is unconstrained and shadows the
 *   module-level `InputType` alias (`'int' | 'float'`) within this class.
 */
export declare abstract class Operation<GPUTensor extends GPUTensorI, Info extends DictBase, InputType> {
    protected dtype: DTypeGpu;
    protected allocator: GPUMemoryAllocator;
    /** NOTE(review): presumably the names registered through `registerStatics` - confirm. */
    protected statics: Set<string>;
    protected gpuTensorConstructor: GPUTensorConstructor<GPUTensor>;
    /**
     * Since WebGL only supports arrays of constant size, we have to represent all arrays
     * with a fixed length. This attribute is that length: the maximum rank a tensor can have.
     */
    protected maxRank: number;
    protected drawCommand?: DrawCommand;
    /** Private member; its type is not emitted in this declaration file. */
    private copyCounter;
    protected fullyStatic: boolean;
    protected outputShape?: readonly number[];
    /**
     * @param tensorConstructor constructor for the tensor type this operation produces
     * @param dtype GPU data type of the operation
     * @param allocator optional GPU memory allocator
     * @param maxRank optional maximum tensor rank (see `maxRank`).
     *   NOTE(review): the defaults used when the optional arguments are omitted are not
     *   visible in this declaration file.
     */
    constructor(tensorConstructor: GPUTensorConstructor<GPUTensor>, dtype: DTypeGpu, allocator?: GPUMemoryAllocator, maxRank?: number);
    registerStatics(info: Info): void;
    /**
     * Gets the variable modifier for the WebGL variable with the given name.
     *
     * @returns either the empty string or `"uniform"`
     */
    getVarModifier(name: string): "" | "uniform";
    /**
     * Pads an array to the specified length, or the maxRank by default
     */
    pad(arr: number[], len?: number): number[];
    /**
     * Variant of `pad` that accepts a readonly array.
     * NOTE(review): the name suggests the input is copied rather than modified - confirm.
     */
    copyPad(arr: readonly number[], len?: number): number[];
    /**
     * Gets the variable declarations for the WebGL shader. Override this if you
     * need extra uniform inputs
     */
    getVariables(): string;
    getVariableDeclarations(): string;
    getVariableInitializations(info: Info): string;
    /**
     * @returns `"float"` or `"int"` for the variable with the given name
     */
    getVarType(name: string): "float" | "int";
    getArrayInit(name: string, values: any[], len?: number, pad?: string): string;
    getUtilFunctions(): string;
    getTextureFunctions(): string;
    getCompleteFragmentShader(info: Info): string;
    getUniforms(info: Info): any;
    // The helpers below take and return strings.
    // NOTE(review): presumably they emit shader source that works on the named shader
    // variables - confirm against the implementation.
    posToIndex(strides: string, result: string, pos: string): string;
    initIndex(index: string, rank?: string): string;
    incrementIndex(index: string, shape: string): string;
    incrementConditional(index: string, shape: string, cond: string): string;
    /**
     * The default main function of the fragment shader.
     * Except in special cases, you will use this and your fragment shader will look something like this:
     *
     * ```
     * float process(int index[maxRank]) {
     *   // Calculate the value of the output at the given index
     * }
     *
     * ${this.getDefaultMain()}
     * ```
     */
    getDefaultMain(): string;
    /**
     * @returns `"mediump"` or `"highp"`
     */
    precisionString(): "mediump" | "highp";
    getDrawCommand(info: Info): DrawCommand;
    /**
     * Compiles the fragment shader with the given compilation info and precision
     *
     * If you need to add extra compilation info, override this method
     */
    compile(info: Info): void;
    /**
     * @param resultShape shape of the result tensor
     * @param inputTensors input tensors keyed by name
     * @param inputs optional additional inputs.
     *   NOTE(review): presumably passed to the shader as uniforms - confirm.
     */
    compute(resultShape: readonly number[], inputTensors: {
        [name: string]: GPUTensorI;
    }, inputs?: any): GPUTensor;
    /**
     * Returns the fragment shader of this operation,
     * should have a method `void main()`. See getDefaultMain
     */
    abstract getFragmentShader(info: Info): string;
    /**
     * Get the names of the tensors in the order they are passed to compute
     */
    abstract getTextureNames(): string[];
    getUniformAttrs(): Input[];
    /**
     * Performs the computation of the operation.
     *
     * Typically you will compute the output shape and then call `compute`
     */
    abstract calc(input: InputType): GPUTensor;
    /**
     * Computes the shape of the output tensor given the inputs
     */
    abstract getOutputShape(input: InputType): readonly number[];
    /**
     * Get all compilation info that can be inferred from the given input
     */
    abstract getCompilationInfo(input: InputType): Info;
    /**
     * Returns a string that is unique for each compilation
     * configuration implied by the given input
     */
    abstract getInputInfoString(input: InputType): string;
}
export {};