// onnxruntime-web
// Version:
// A JavaScript library for running ONNX models in browsers
// 649 lines (593 loc) • 21 kB
// text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { env } from 'onnxruntime-common';
import * as DataEncoders from './texture-data-encoder';
import { DataEncoder, Encoder, EncoderUsage } from './texture-data-encoder';
import { repeatedTry } from './utils';
/**
 * Handle for a GPU fence inserted into the WebGL command stream.
 */
export interface FenceContext {
// The sync object returned by `fenceSync`, or null when no sync object was created.
query: WebGLSync | null;
// Non-blocking check; returns true once the fence is considered passed.
isFencePassed(): boolean;
}
/**
 * One entry of the polling queue: a completion predicate paired with the callback
 * that is invoked once the predicate reports completion.
 */
type PollItem = {
// Returns true when the awaited work has finished.
isDoneFn: () => boolean;
// Called after `isDoneFn` has returned true, to release the waiter.
resolveFn: () => void;
};
/**
 * Evaluates the predicates in order and returns the index of the last one in the leading
 * run of predicates that returned true.
 *
 * Evaluation stops at the first predicate that returns a falsy value; later predicates are
 * not invoked.
 *
 * @param arr Predicates to evaluate, in order.
 * @returns The index of the last leading `true` predicate, or -1 if the first predicate is
 *     falsy or the array is empty.
 */
export function linearSearchLastTrue(arr: Array<() => boolean>): number {
  // findIndex short-circuits on the first pending predicate, exactly like a manual loop with break.
  const firstPending = arr.findIndex((isDoneFn) => !isDoneFn());
  if (firstPending === -1) {
    // Every predicate passed (or there were none).
    return arr.length - 1;
  }
  return firstPending - 1;
}
/**
 * Abstraction and wrapper around WebGLRenderingContext and its operations.
 *
 * On construction the context loads the extensions relevant to its WebGL version, creates a
 * vertex buffer holding a full-screen quad plus a framebuffer, and probes the float-texture
 * capabilities of the underlying GL implementation.
 */
export class WebGLContext {
  gl: WebGLRenderingContext;
  version: 1 | 2;

  private vertexbuffer: WebGLBuffer;
  private framebuffer: WebGLFramebuffer;

  // WebGL flags and vital parameters
  private isFloatTextureAttachableToFrameBuffer: boolean;
  isFloat32DownloadSupported: boolean;
  isRenderFloat32Supported: boolean;
  isBlendSupported: boolean;
  maxTextureSize: number;
  // private maxCombinedTextureImageUnits: number;
  private maxTextureImageUnits: number;
  // private maxCubeMapTextureSize: number;
  // private shadingLanguageVersion: string;
  // private webglVendor: string;
  // private webglVersion: string;

  // WebGL2 flags and vital parameters
  // private max3DTextureSize: number;
  // private maxArrayTextureLayers: number;
  // private maxColorAttachments: number;
  // private maxDrawBuffers: number;

  // WebGL extensions
  // eslint-disable-next-line camelcase
  textureFloatExtension: OES_texture_float | null;
  // eslint-disable-next-line camelcase
  textureHalfFloatExtension: OES_texture_half_float | null;

  // WebGL2 extensions
  colorBufferFloatExtension: unknown | null;
  // eslint-disable-next-line @typescript-eslint/naming-convention
  disjointTimerQueryWebgl2Extension: { TIME_ELAPSED_EXT: GLenum; GPU_DISJOINT_EXT: GLenum } | null;

  private disposed = false;
  // NOTE(review): nothing in this class ever sets this flag to true, so readTexture() always
  // re-attaches the framebuffer.
  private frameBufferBound = false;

  /**
   * @param gl The rendering context to wrap.
   * @param version The WebGL major version of `gl` (1 or 2).
   * @throws Error if a required GL object cannot be created, or if (WebGL1 only) neither float32
   *     rendering nor the half-float texture extension is available.
   */
  constructor(gl: WebGLRenderingContext, version: 1 | 2) {
    this.gl = gl;
    this.version = version;
    // Extensions must be loaded first: the capability probes below depend on them.
    this.getExtensions();
    this.vertexbuffer = this.createVertexbuffer();
    this.framebuffer = this.createFramebuffer();
    this.queryVitalParameters();
  }

  /**
   * Creates a 2D texture with nearest filtering and clamp-to-edge wrapping, optionally
   * uploading initial data.
   *
   * @param width Texture width in texels.
   * @param height Texture height in texels.
   * @param encoder Supplies the internal format, format and texture type, and encodes `data`.
   * @param data Optional initial contents; when omitted the texture storage is allocated empty.
   * @returns The newly created texture (left bound to TEXTURE_2D).
   * @throws Error if the texture object cannot be created.
   */
  allocateTexture(width: number, height: number, encoder: DataEncoder, data?: Encoder.DataArrayType): WebGLTexture {
    const gl = this.gl;
    // create the texture
    const texture = gl.createTexture();
    if (!texture) {
      // Consistent with createBuffer/createFramebuffer/createShader handling in this class.
      throw new Error('createTexture() returned null');
    }
    // bind the texture so the following methods affect this texture.
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    const buffer = data ? encoder.encode(data, width * height) : null;
    gl.texImage2D(
      gl.TEXTURE_2D,
      0, // Level of detail.
      encoder.internalFormat,
      width,
      height,
      0, // Always 0 in OpenGL ES.
      encoder.format,
      encoder.textureType,
      buffer,
    );
    this.checkError();
    return texture;
  }

  /**
   * Overwrites the full `width` x `height` region of an existing texture with `data`.
   *
   * @param texture The texture to update.
   * @param width Width of the updated region in texels.
   * @param height Height of the updated region in texels.
   * @param encoder Encodes `data` and supplies the format and texture type.
   * @param data The new contents.
   */
  updateTexture(
    texture: WebGLTexture,
    width: number,
    height: number,
    encoder: DataEncoder,
    data: Encoder.DataArrayType,
  ): void {
    const gl = this.gl;
    gl.bindTexture(gl.TEXTURE_2D, texture);
    const buffer = encoder.encode(data, width * height);
    gl.texSubImage2D(
      gl.TEXTURE_2D,
      0, // level
      0, // xoffset
      0, // yoffset
      width,
      height,
      encoder.format,
      encoder.textureType,
      buffer,
    );
    this.checkError();
  }

  /**
   * Binds this context's framebuffer, attaches `texture` as its color attachment and sets the
   * viewport and scissor rectangle to the texture size.
   */
  attachFramebuffer(texture: WebGLTexture, width: number, height: number): void {
    const gl = this.gl;
    // Make it the target for framebuffer operations - including rendering.
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.bindFramebuffer(gl.FRAMEBUFFER, this.framebuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); // 0, we aren't using MIPMAPs
    this.checkError();
    gl.viewport(0, 0, width, height);
    gl.scissor(0, 0, width, height);
  }

  /**
   * Reads the contents of `texture` back through the framebuffer and decodes them.
   *
   * @param texture The texture to read.
   * @param width Texture width in texels.
   * @param height Texture height in texels.
   * @param dataSize Size passed to the encoder's `decode` call.
   * @param dataType Element type used to select the encoder.
   * @param channels Channel count used to select the encoder; a falsy value is treated as 1.
   * @returns The decoded data.
   */
  readTexture(
    texture: WebGLTexture,
    width: number,
    height: number,
    dataSize: number,
    dataType: Encoder.DataType,
    channels: number,
  ): Encoder.DataArrayType {
    const gl = this.gl;
    if (!channels) {
      channels = 1;
    }
    if (!this.frameBufferBound) {
      this.attachFramebuffer(texture, width, height);
    }
    const encoder = this.getEncoder(dataType, channels);
    const buffer = encoder.allocate(width * height);
    // bind texture to framebuffer
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); // 0, we aren't using MIPMAPs
    // TODO: Check if framebuffer is ready
    gl.readPixels(0, 0, width, height, gl.RGBA, encoder.textureType, buffer);
    this.checkError();
    // The framebuffer is intentionally left bound; only the pixel data is decoded here.
    return encoder.decode(buffer, dataSize);
  }

  /** Placeholder; currently always reports the framebuffer as ready. */
  isFramebufferReady(): boolean {
    // TODO: Implement logic to check if the framebuffer is ready
    return true;
  }

  /** Returns the active texture unit as a string of the form `TEXTURE<index>`. */
  getActiveTexture(): string {
    const gl = this.gl;
    const n = gl.getParameter(this.gl.ACTIVE_TEXTURE);
    return `TEXTURE${n - gl.TEXTURE0}`;
  }

  /** Returns the texture currently bound to TEXTURE_2D. */
  getTextureBinding(): WebGLTexture {
    return this.gl.getParameter(this.gl.TEXTURE_BINDING_2D);
  }

  /** Returns the currently bound framebuffer. */
  getFramebufferBinding(): WebGLFramebuffer {
    return this.gl.getParameter(this.gl.FRAMEBUFFER_BINDING);
  }

  /**
   * Configures and enables the vertex attributes for the quad geometry.
   *
   * The geometry is laid out as 5 floats per vertex (x, y, z, s, t), hence a 20-byte stride, with
   * the texture coordinates starting 12 bytes into each vertex.
   *
   * @param positionHandle Attribute location of the vertex position.
   * @param textureCoordHandle Attribute location of the texture coordinate, or -1 to skip it.
   */
  setVertexAttributes(positionHandle: number, textureCoordHandle: number): void {
    const gl = this.gl;
    gl.vertexAttribPointer(positionHandle, 3, gl.FLOAT, false, 20, 0);
    gl.enableVertexAttribArray(positionHandle);
    if (textureCoordHandle !== -1) {
      gl.vertexAttribPointer(textureCoordHandle, 2, gl.FLOAT, false, 20, 12);
      gl.enableVertexAttribArray(textureCoordHandle);
    }
    this.checkError();
  }

  /**
   * Creates a program, attaches both shaders and links it.
   *
   * The link status is not checked here.
   *
   * @throws Error if the program object cannot be created.
   */
  createProgram(vertexShader: WebGLShader, fragShader: WebGLShader): WebGLProgram {
    const gl = this.gl;
    const program = gl.createProgram();
    if (!program) {
      throw new Error('createProgram() returned null');
    }
    // the program consists of our shaders
    gl.attachShader(program, vertexShader);
    gl.attachShader(program, fragShader);
    gl.linkProgram(program);
    return program;
  }

  /**
   * Compiles a shader from source.
   *
   * @param shaderSource GLSL source code.
   * @param shaderType The shader type enum passed to `createShader`.
   * @returns The compiled shader.
   * @throws Error if the shader object cannot be created or compilation fails; the error message
   *     carries the info log and the shader source.
   */
  compileShader(shaderSource: string, shaderType: number): WebGLShader {
    const gl = this.gl;
    const shader = gl.createShader(shaderType);
    if (!shader) {
      throw new Error(`createShader() returned null with type ${shaderType}`);
    }
    gl.shaderSource(shader, shaderSource);
    gl.compileShader(shader);
    if (gl.getShaderParameter(shader, gl.COMPILE_STATUS) === false) {
      // Grab the log before releasing the shader so the failed object is not leaked.
      const infoLog = gl.getShaderInfoLog(shader);
      gl.deleteShader(shader);
      throw new Error(`Failed to compile shader: ${infoLog}\nShader source:\n${shaderSource}`);
    }
    return shader;
  }

  /** Deletes a shader object. */
  deleteShader(shader: WebGLShader): void {
    this.gl.deleteShader(shader);
  }

  /**
   * Binds `texture` to texture unit `position` and points the sampler uniform at that unit.
   */
  bindTextureToUniform(texture: WebGLTexture, position: number, uniformHandle: WebGLUniformLocation): void {
    const gl = this.gl;
    gl.activeTexture(gl.TEXTURE0 + position);
    this.checkError();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    this.checkError();
    gl.uniform1i(uniformHandle, position);
    this.checkError();
  }

  /** Draws the 4-vertex quad as a triangle strip. */
  draw(): void {
    this.gl.drawArrays(this.gl.TRIANGLE_STRIP, 0, 4);
    this.checkError();
  }

  /**
   * When `env.debug` is set, queries `gl.getError()` and throws if an error is pending.
   * Does nothing when debugging is disabled.
   *
   * @throws Error labelled with the name of the GL error.
   */
  checkError(): void {
    if (env.debug) {
      const gl = this.gl;
      const error = gl.getError();
      let label = '';
      switch (error) {
        case gl.NO_ERROR:
          return;
        case gl.INVALID_ENUM:
          label = 'INVALID_ENUM';
          break;
        case gl.INVALID_VALUE:
          label = 'INVALID_VALUE';
          break;
        case gl.INVALID_OPERATION:
          label = 'INVALID_OPERATION';
          break;
        case gl.INVALID_FRAMEBUFFER_OPERATION:
          label = 'INVALID_FRAMEBUFFER_OPERATION';
          break;
        case gl.OUT_OF_MEMORY:
          label = 'OUT_OF_MEMORY';
          break;
        case gl.CONTEXT_LOST_WEBGL:
          label = 'CONTEXT_LOST_WEBGL';
          break;
        default:
          label = `Unknown WebGL Error: ${error.toString(16)}`;
      }
      throw new Error(label);
    }
  }

  /** Deletes a texture object. */
  deleteTexture(texture: WebGLTexture): void {
    this.gl.deleteTexture(texture);
  }

  /** Deletes a program object. */
  deleteProgram(program: WebGLProgram): void {
    this.gl.deleteProgram(program);
  }

  /**
   * Selects the data encoder for the given element type and channel count.
   *
   * On WebGL2 a `RedFloat32DataEncoder` is always returned, regardless of `dataType`. On WebGL1
   * a float encoder falls back to half-float textures when float32 rendering is not supported
   * and the usage is not upload-only.
   *
   * @throws Error for `'int'` (not implemented) and for unknown data types (WebGL1 only).
   */
  getEncoder(dataType: Encoder.DataType, channels: number, usage: EncoderUsage = EncoderUsage.Default): DataEncoder {
    if (this.version === 2) {
      return new DataEncoders.RedFloat32DataEncoder(this.gl as WebGL2RenderingContext, channels);
    }

    switch (dataType) {
      case 'float':
        if (usage === EncoderUsage.UploadOnly || this.isRenderFloat32Supported) {
          return new DataEncoders.RGBAFloatDataEncoder(this.gl, channels);
        } else {
          // The constructor rejects WebGL1 contexts that have neither float32 rendering nor the
          // half-float extension, so the extension is present on this path.
          return new DataEncoders.RGBAFloatDataEncoder(
            this.gl,
            channels,
            this.textureHalfFloatExtension!.HALF_FLOAT_OES,
          );
        }
      case 'int':
        throw new Error('not implemented');
      case 'byte':
        return new DataEncoders.Uint8DataEncoder(this.gl, channels);
      default:
        throw new Error(`Invalid dataType: ${dataType}`);
    }
  }

  /** Unbinds the TEXTURE_2D target on every texture image unit. */
  clearActiveTextures(): void {
    const gl = this.gl;
    for (let unit = 0; unit < this.maxTextureImageUnits; ++unit) {
      gl.activeTexture(gl.TEXTURE0 + unit);
      gl.bindTexture(gl.TEXTURE_2D, null);
    }
  }

  /**
   * Releases the framebuffer and vertex buffer owned by this context. Safe to call repeatedly;
   * calls after the first are no-ops.
   */
  dispose(): void {
    if (this.disposed) {
      return;
    }
    const gl = this.gl;
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteFramebuffer(this.framebuffer);
    gl.bindBuffer(gl.ARRAY_BUFFER, null);
    gl.deleteBuffer(this.vertexbuffer);
    gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null);
    gl.finish();
    this.disposed = true;
  }

  /** Builds the vertex data of a full clip-space quad, ordered for a triangle strip. */
  private createDefaultGeometry(): Float32Array {
    // Sets of x,y,z(=0),s,t coordinates.
    return new Float32Array([
      -1.0,
      1.0,
      0.0,
      0.0,
      1.0, // upper left
      -1.0,
      -1.0,
      0.0,
      0.0,
      0.0, // lower left
      1.0,
      1.0,
      0.0,
      1.0,
      1.0, // upper right
      1.0,
      -1.0,
      0.0,
      1.0,
      0.0, // lower right
    ]);
  }

  /**
   * Creates the vertex buffer and uploads the default quad geometry. The buffer is left bound to
   * ARRAY_BUFFER.
   *
   * @throws Error if the buffer object cannot be created.
   */
  private createVertexbuffer(): WebGLBuffer {
    const gl = this.gl;
    const buffer = gl.createBuffer();
    if (!buffer) {
      throw new Error('createBuffer() returned null');
    }
    const geometry = this.createDefaultGeometry();
    gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
    gl.bufferData(gl.ARRAY_BUFFER, geometry, gl.STATIC_DRAW);
    this.checkError();
    return buffer;
  }

  /**
   * Creates the framebuffer used for rendering and read-back.
   *
   * @throws Error if the framebuffer object cannot be created.
   */
  private createFramebuffer(): WebGLFramebuffer {
    const fb = this.gl.createFramebuffer();
    if (!fb) {
      throw new Error('createFramebuffer returned null');
    }
    return fb;
  }

  /**
   * Probes capabilities and reads implementation limits.
   *
   * @throws Error on WebGL1 when neither float32 rendering nor half-float textures are available.
   */
  private queryVitalParameters(): void {
    const gl = this.gl;

    // The attachability probe must run first: the two checks below read its result.
    this.isFloatTextureAttachableToFrameBuffer = this.checkFloatTextureAttachableToFrameBuffer();
    this.isRenderFloat32Supported = this.checkRenderFloat32();
    this.isFloat32DownloadSupported = this.checkFloat32Download();

    if (this.version === 1 && !this.textureHalfFloatExtension && !this.isRenderFloat32Supported) {
      throw new Error('both float32 and float16 TextureType are not supported');
    }

    // Blending is only probed when float32 rendering is available; otherwise it is assumed to work.
    this.isBlendSupported = !this.isRenderFloat32Supported || this.checkFloat32Blend();

    // this.maxCombinedTextureImageUnits = gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS);
    this.maxTextureSize = gl.getParameter(gl.MAX_TEXTURE_SIZE);
    this.maxTextureImageUnits = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
    // this.maxCubeMapTextureSize = gl.getParameter(gl.MAX_CUBE_MAP_TEXTURE_SIZE);
    // this.shadingLanguageVersion = gl.getParameter(gl.SHADING_LANGUAGE_VERSION);
    // this.webglVendor = gl.getParameter(gl.VENDOR);
    // this.webglVersion = gl.getParameter(gl.VERSION);

    if (this.version === 2) {
      // this.max3DTextureSize = gl.getParameter(WebGL2RenderingContext.MAX_3D_TEXTURE_SIZE);
      // this.maxArrayTextureLayers = gl.getParameter(WebGL2RenderingContext.MAX_ARRAY_TEXTURE_LAYERS);
      // this.maxColorAttachments = gl.getParameter(WebGL2RenderingContext.MAX_COLOR_ATTACHMENTS);
      // this.maxDrawBuffers = gl.getParameter(WebGL2RenderingContext.MAX_DRAW_BUFFERS);
    }
  }

  /**
   * Requests the extensions relevant to the WebGL version. Each extension field is null when
   * the extension is unavailable; fields for the other WebGL version are left unassigned.
   */
  private getExtensions(): void {
    if (this.version === 2) {
      this.colorBufferFloatExtension = this.gl.getExtension('EXT_color_buffer_float');
      this.disjointTimerQueryWebgl2Extension = this.gl.getExtension('EXT_disjoint_timer_query_webgl2');
    } else {
      this.textureFloatExtension = this.gl.getExtension('OES_texture_float');
      this.textureHalfFloatExtension = this.gl.getExtension('OES_texture_half_float');
    }
  }

  /**
   * Tests whether a 1x1 float texture can be used as a complete framebuffer color attachment.
   * All temporary GL objects are unbound and deleted before returning.
   */
  private checkFloatTextureAttachableToFrameBuffer(): boolean {
    // test whether Float32 texture is supported:
    // STEP.1 create a float texture
    const gl = this.gl;
    const texture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    // eslint-disable-next-line @typescript-eslint/naming-convention
    const internalFormat = this.version === 2 ? (gl as unknown as { RGBA32F: number }).RGBA32F : gl.RGBA;
    gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null);
    // STEP.2 bind a frame buffer
    const frameBuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
    // STEP.3 attach texture to framebuffer
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    // STEP.4 test whether framebuffer is complete
    const isComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(texture);
    gl.deleteFramebuffer(frameBuffer);
    return isComplete;
  }

  /**
   * Float32 rendering needs the version-specific float extension and a float texture that is
   * attachable to a framebuffer.
   */
  private checkRenderFloat32(): boolean {
    if (this.version === 2) {
      if (!this.colorBufferFloatExtension) {
        return false;
      }
    } else {
      if (!this.textureFloatExtension) {
        return false;
      }
    }
    return this.isFloatTextureAttachableToFrameBuffer;
  }

  /**
   * Float32 download has the same requirements as float32 rendering, plus (WebGL1 only) the
   * `WEBGL_color_buffer_float` extension.
   */
  private checkFloat32Download(): boolean {
    if (this.version === 2) {
      if (!this.colorBufferFloatExtension) {
        return false;
      }
    } else {
      if (!this.textureFloatExtension) {
        return false;
      }
      if (!this.gl.getExtension('WEBGL_color_buffer_float')) {
        return false;
      }
    }
    return this.isFloatTextureAttachableToFrameBuffer;
  }

  /**
   * Check whether GL_BLEND is supported
   *
   * Draws a single point into a 1x1 float texture with blending enabled and reports whether GL
   * raised an error. All temporary GL objects are released in the `finally` block.
   */
  private checkFloat32Blend(): boolean {
    // it looks like currently (2019-05-08) there is no easy way to detect whether BLEND is supported
    // https://github.com/microsoft/onnxjs/issues/145

    const gl = this.gl;

    let texture: WebGLTexture | null | undefined;
    let frameBuffer: WebGLFramebuffer | null | undefined;
    let vertexShader: WebGLShader | null | undefined;
    let fragmentShader: WebGLShader | null | undefined;
    let program: WebGLProgram | null | undefined;

    try {
      texture = gl.createTexture();
      frameBuffer = gl.createFramebuffer();
      gl.bindTexture(gl.TEXTURE_2D, texture);

      // eslint-disable-next-line @typescript-eslint/naming-convention
      const internalFormat = this.version === 2 ? (gl as unknown as { RGBA32F: number }).RGBA32F : gl.RGBA;
      gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null);

      gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
      gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);

      gl.enable(gl.BLEND);

      vertexShader = gl.createShader(gl.VERTEX_SHADER);
      if (!vertexShader) {
        return false;
      }
      gl.shaderSource(vertexShader, 'void main(){}');
      gl.compileShader(vertexShader);

      fragmentShader = gl.createShader(gl.FRAGMENT_SHADER);
      if (!fragmentShader) {
        return false;
      }
      gl.shaderSource(fragmentShader, 'precision highp float;void main(){gl_FragColor=vec4(0.5);}');
      gl.compileShader(fragmentShader);

      program = gl.createProgram();
      if (!program) {
        return false;
      }
      gl.attachShader(program, vertexShader);
      gl.attachShader(program, fragmentShader);
      gl.linkProgram(program);
      gl.useProgram(program);

      gl.drawArrays(gl.POINTS, 0, 1);
      return gl.getError() === gl.NO_ERROR;
    } finally {
      gl.disable(gl.BLEND);

      if (program) {
        // Make the probe program non-current first; a program that is still in use is only
        // flagged for deletion and would otherwise linger as the current program.
        gl.useProgram(null);
        gl.deleteProgram(program);
      }
      if (vertexShader) {
        gl.deleteShader(vertexShader);
      }
      if (fragmentShader) {
        gl.deleteShader(fragmentShader);
      }
      if (frameBuffer) {
        gl.bindFramebuffer(gl.FRAMEBUFFER, null);
        gl.deleteFramebuffer(frameBuffer);
      }
      if (texture) {
        gl.bindTexture(gl.TEXTURE_2D, null);
        gl.deleteTexture(texture);
      }
    }
  }

  /**
   * Starts a TIME_ELAPSED timer query.
   *
   * @returns The query object; pass it to `isTimerResultAvailable` / `getTimerResult`.
   * @throws Error unless this is a WebGL2 context with the disjoint timer query extension.
   */
  beginTimer(): WebGLQuery {
    if (this.version === 2 && this.disjointTimerQueryWebgl2Extension) {
      const gl2 = this.gl as WebGL2RenderingContext;
      const ext = this.disjointTimerQueryWebgl2Extension;

      const query = gl2.createQuery() as WebGLQuery;
      gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);
      return query;
    } else {
      // TODO: add webgl 1 handling.
      throw new Error('WebGL1 profiling currently not supported.');
    }
  }

  /**
   * Ends the active TIME_ELAPSED timer query.
   *
   * @throws Error unless this is a WebGL2 context with the disjoint timer query extension.
   */
  endTimer(): void {
    if (this.version === 2 && this.disjointTimerQueryWebgl2Extension) {
      const gl2 = this.gl as WebGL2RenderingContext;
      const ext = this.disjointTimerQueryWebgl2Extension;
      gl2.endQuery(ext.TIME_ELAPSED_EXT);
      return;
    } else {
      // TODO: add webgl 1 handling.
      throw new Error('WebGL1 profiling currently not supported');
    }
  }

  /**
   * Reports whether the result of `query` is available and the GPU was not in a disjoint state.
   *
   * @throws Error unless this is a WebGL2 context with the disjoint timer query extension.
   */
  isTimerResultAvailable(query: WebGLQuery): boolean {
    let available = false,
      disjoint = false;
    if (this.version === 2 && this.disjointTimerQueryWebgl2Extension) {
      const gl2 = this.gl as WebGL2RenderingContext;
      const ext = this.disjointTimerQueryWebgl2Extension;

      available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
      disjoint = gl2.getParameter(ext.GPU_DISJOINT_EXT);
    } else {
      // TODO: add webgl 1 handling.
      throw new Error('WebGL1 profiling currently not supported');
    }

    return available && !disjoint;
  }

  /**
   * Reads the result of a finished timer query and deletes the query object.
   *
   * @returns The query result divided by 1,000,000 (i.e. milliseconds, assuming the query result
   *     is reported in nanoseconds).
   * @throws Error on a WebGL1 context.
   */
  getTimerResult(query: WebGLQuery): number {
    let timeElapsed = 0;
    if (this.version === 2) {
      const gl2 = this.gl as WebGL2RenderingContext;
      timeElapsed = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
      gl2.deleteQuery(query);
    } else {
      // TODO: add webgl 1 handling.
      throw new Error('WebGL1 profiling currently not supported');
    }
    // return milliseconds
    return timeElapsed / 1000000;
  }

  /** Waits until the result of `query` is available, then returns it via `getTimerResult`. */
  async waitForQueryAndGetTime(query: WebGLQuery): Promise<number> {
    await repeatedTry(() => this.isTimerResultAvailable(query));
    return this.getTimerResult(query);
  }

  /** Inserts a fence into the command stream and resolves once it has passed. */
  public async createAndWaitForFence(): Promise<void> {
    const fenceContext = this.createFence(this.gl);
    return this.pollFence(fenceContext);
  }

  /**
   * Inserts a sync object into the command stream and flushes.
   *
   * The returned `isFencePassed` polls without blocking. Once the fence is observed as signaled
   * the sync object is deleted and the result is cached, so later calls keep returning true
   * without touching the (now deleted) sync object.
   *
   * NOTE: relies on the WebGL2 sync API (`fenceSync` / `clientWaitSync`).
   */
  private createFence(gl: WebGLRenderingContext): FenceContext {
    const gl2 = gl as WebGL2RenderingContext;
    const query = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
    gl.flush();
    if (query === null) {
      // No sync object was created; treat the fence as already passed.
      return { query, isFencePassed: () => true };
    }

    let passed = false;
    const isFencePassed = (): boolean => {
      if (!passed) {
        const status = gl2.clientWaitSync(query, 0, 0);
        passed = status === gl2.ALREADY_SIGNALED || status === gl2.CONDITION_SATISFIED;
        if (passed) {
          // Release the sync object as soon as it is no longer needed.
          gl2.deleteSync(query);
        }
      }
      return passed;
    };
    return { query, isFencePassed };
  }

  /** Returns a promise that resolves once `fenceContext.isFencePassed()` reports true. */
  async pollFence(fenceContext: FenceContext): Promise<void> {
    return new Promise<void>((resolve) => {
      void this.addItemToPoll(
        () => fenceContext.isFencePassed(),
        () => resolve(),
      );
    });
  }

  // Pending items in submission order; drained from the front by pollItems().
  private itemsToPoll: PollItem[] = [];

  /**
   * Resolves every item in the leading run of finished items and removes them from the queue.
   * Items behind an unfinished item are left untouched even if they are already done.
   */
  pollItems(): void {
    // Find the last query that has finished.
    const index = linearSearchLastTrue(this.itemsToPoll.map((x) => x.isDoneFn));
    for (let i = 0; i <= index; ++i) {
      const { resolveFn } = this.itemsToPoll[i];
      resolveFn();
    }
    this.itemsToPoll = this.itemsToPoll.slice(index + 1);
  }

  /**
   * Queues an item and, if the queue was empty, starts the polling loop that runs until the
   * queue has been drained.
   */
  private async addItemToPoll(isDoneFn: () => boolean, resolveFn: () => void): Promise<void> {
    this.itemsToPoll.push({ isDoneFn, resolveFn });
    if (this.itemsToPoll.length > 1) {
      // We already have a running loop that polls.
      return;
    }
    // Start a new loop that polls.
    await repeatedTry(() => {
      this.pollItems();
      // End the loop if no more items to poll.
      return this.itemsToPoll.length === 0;
    });
  }
}