transformers-fork
Version:
State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
104 lines (94 loc) • 5.78 kB
JavaScript
import { createInferenceSession } from "../backends/onnx.js";
import { Tensor } from "../utils/tensor.js";
/**
* Asynchronously creates a wrapper function for running an ONNX inference session.
*
* @param {number[]} session_bytes The session data in bytes.
* @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options The options for the ONNX session.
* @template {string | [string] | string[]} T
* @param {T} names The name(s) of the output tensor(s).
*
* @returns {Promise<function(Record<string, Tensor>): Promise<T extends string ? Tensor : T extends string[] ? { [K in keyof T]: Tensor } : never>>}
* The wrapper function for running the ONNX inference session.
*/
const wrap = async (session_bytes, session_options, names) => {
const session = await createInferenceSession(
new Uint8Array(session_bytes), session_options,
);
return /** @type {any} */(async (/** @type {Record<string, Tensor>} */ inputs) => {
const ortFeed = Object.fromEntries(Object.entries(inputs).map(([k, v]) => [k, v.ort_tensor]));
const outputs = await session.run(ortFeed);
if (Array.isArray(names)) {
return names.map((n) => new Tensor(outputs[n]));
} else {
return new Tensor(outputs[/** @type {string} */(names)]);
}
})
}
// In-memory registry of initialized ONNX operators
export class TensorOpRegistry {
static session_options = {
// TODO: Allow for multiple execution providers
// executionProviders: ['webgpu'],
};
static get bilinear_interpolate_4d() {
if (!this._bilinear_interpolate_4d) {
this._bilinear_interpolate_4d = wrap(
[8, 9, 18, 0, 58, 128, 1, 10, 40, 10, 1, 120, 10, 0, 10, 0, 10, 1, 115, 18, 1, 121, 34, 6, 82, 101, 115, 105, 122, 101, 42, 17, 10, 4, 109, 111, 100, 101, 34, 6, 108, 105, 110, 101, 97, 114, 160, 1, 3, 18, 1, 114, 90, 31, 10, 1, 120, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 90, 15, 10, 1, 115, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 4, 98, 31, 10, 1, 121, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 66, 2, 16, 20],
this.session_options,
'y',
);
}
return this._bilinear_interpolate_4d;
}
static get bicubic_interpolate_4d() {
if (!this._bicubic_interpolate_4d) {
this._bicubic_interpolate_4d = wrap(
[8, 9, 18, 0, 58, 127, 10, 39, 10, 1, 120, 10, 0, 10, 0, 10, 1, 115, 18, 1, 121, 34, 6, 82, 101, 115, 105, 122, 101, 42, 16, 10, 4, 109, 111, 100, 101, 34, 5, 99, 117, 98, 105, 99, 160, 1, 3, 18, 1, 114, 90, 31, 10, 1, 120, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 90, 15, 10, 1, 115, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 4, 98, 31, 10, 1, 121, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 66, 2, 16, 20],
this.session_options,
'y',
);
}
return this._bicubic_interpolate_4d;
}
static get matmul() {
if (!this._matmul) {
this._matmul = wrap(
[8, 9, 18, 0, 58, 55, 10, 17, 10, 1, 97, 10, 1, 98, 18, 1, 99, 34, 6, 77, 97, 116, 77, 117, 108, 18, 1, 114, 90, 9, 10, 1, 97, 18, 4, 10, 2, 8, 1, 90, 9, 10, 1, 98, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 99, 18, 4, 10, 2, 8, 1, 66, 2, 16, 20],
this.session_options,
'c',
);
}
return this._matmul;
}
static get stft() {
if (!this._stft) {
this._stft = wrap(
[8, 7, 18, 0, 58, 148, 1, 10, 38, 10, 1, 115, 10, 1, 106, 10, 1, 119, 10, 1, 108, 18, 1, 111, 34, 4, 83, 84, 70, 84, 42, 15, 10, 8, 111, 110, 101, 115, 105, 100, 101, 100, 24, 1, 160, 1, 2, 18, 1, 115, 90, 26, 10, 1, 115, 18, 21, 10, 19, 8, 1, 18, 15, 10, 3, 18, 1, 98, 10, 3, 18, 1, 115, 10, 3, 18, 1, 99, 90, 11, 10, 1, 106, 18, 6, 10, 4, 8, 7, 18, 0, 90, 16, 10, 1, 119, 18, 11, 10, 9, 8, 1, 18, 5, 10, 3, 18, 1, 119, 90, 11, 10, 1, 108, 18, 6, 10, 4, 8, 7, 18, 0, 98, 31, 10, 1, 111, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 102, 10, 3, 18, 1, 100, 10, 3, 18, 1, 99, 66, 2, 16, 17],
this.session_options,
'o',
)
}
return this._stft;
}
static get rfft() {
if (!this._rfft) {
this._rfft = wrap(
[8, 9, 18, 0, 58, 97, 10, 33, 10, 1, 120, 10, 0, 10, 1, 97, 18, 1, 121, 34, 3, 68, 70, 84, 42, 15, 10, 8, 111, 110, 101, 115, 105, 100, 101, 100, 24, 1, 160, 1, 2, 18, 1, 100, 90, 21, 10, 1, 120, 18, 16, 10, 14, 8, 1, 18, 10, 10, 3, 18, 1, 115, 10, 3, 18, 1, 99, 90, 11, 10, 1, 97, 18, 6, 10, 4, 8, 7, 18, 0, 98, 21, 10, 1, 121, 18, 16, 10, 14, 8, 1, 18, 10, 10, 3, 18, 1, 115, 10, 3, 18, 1, 99, 66, 2, 16, 20],
this.session_options,
'y',
)
}
return this._rfft;
}
static get top_k() {
if (!this._top_k) {
this._top_k = wrap(
[8, 10, 18, 0, 58, 73, 10, 18, 10, 1, 120, 10, 1, 107, 18, 1, 118, 18, 1, 105, 34, 4, 84, 111, 112, 75, 18, 1, 116, 90, 9, 10, 1, 120, 18, 4, 10, 2, 8, 1, 90, 15, 10, 1, 107, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 118, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 105, 18, 4, 10, 2, 8, 7, 66, 2, 16, 21],
this.session_options,
[ /* Values */ 'v', /* Indices */ 'i']
)
}
return this._top_k;
}
}