UNPKG

@internetarchive/bergamot-translator

Version:

Cross-platform C++ library focused on optimized machine translation on consumer-grade devices.

884 lines (750 loc) 30.8 kB
/**
 * @typedef {Object} TranslationRequest
 * @property {String} from
 * @property {String} to
 * @property {String} text
 * @property {Boolean} html
 * @property {Integer?} priority
 */

/**
 * @typedef {Object} TranslationResponse
 * @property {TranslationRequest} request
 * @property {{text: string}} target
 */

/**
 * NodeJS compatibility, a thin WebWorker layer around node:worker_threads.
 */
if (!(typeof window !== 'undefined' && window.Worker)) {
    globalThis.Worker = class {
        #worker;

        constructor(url) {
            this.#worker = new Promise(async (accept) => {
                const {Worker} = await import(/* webpackIgnore: true */ 'node:worker_threads');
                accept(new Worker(url));
            });
        }

        addEventListener(eventName, callback) {
            this.#worker.then(worker => worker.on(eventName, (data) => callback({data})));
        }

        postMessage(message) {
            this.#worker.then(worker => worker.postMessage(message));
        }

        terminate() {
            this.#worker.then(worker => worker.terminate());
        }
    }
}

/**
 * Thrown when a pending translation is replaced by another newer pending
 * translation.
 */
export class SupersededError extends Error {}

/**
 * Thrown when a translation was removed from the queue.
 */
export class CancelledError extends Error {}

/**
 * Wrapper around bergamot-translator loading and model management.
 */
export class TranslatorBacking {
    /**
     * @param {{
     *   cacheSize?: number,
     *   useNativeIntGemm?: boolean,
     *   downloadTimeout?: number,
     *   registryUrl?: string
     *   pivotLanguage?: string?
     *   onerror?: (err: Error)
     * }} options
     */
    constructor(options) {
        this.options = options || {};

        this.registryUrl = this.options.registryUrl || 'https://bergamot.s3.amazonaws.com/models/index.json';

        this.downloadTimeout = 'downloadTimeout' in this.options ? parseInt(this.options.downloadTimeout) : 60000;

        /**
         * registry of all available models and their urls
         * @type {Promise<Model[]>}
         */
        this.registry = this.loadModelRegistery();

        /**
         * Map of downloaded model data files as buffers per model.
         * @type {Map<{from:string,to:string}, Promise<Map<string,ArrayBuffer>>>}
         */
        this.buffers = new Map();

        /**
         * @type {string?}
         */
        this.pivotLanguage = 'pivotLanguage' in this.options ? options.pivotLanguage : 'en';

        /**
         * A map of language-pairs to a list of models you need for it.
         * @type {Map<{from:string,to:string}, Promise<{from:string,to:string}[]>>}
         */
        this.models = new Map();

        /**
         * Error handler for all errors that are async, not tied to a specific
         * call and that are unrecoverable.
         * @type {(error: Error)}
         */
        this.onerror = this.options.onerror || (err => console.error('WASM Translation Worker error:', err));
    }

    /**
     * Loads a worker thread, and wraps it in a message passing proxy. I.e. it
     * exposes the entire interface of TranslationWorker here, and all calls
     * to it are async. Do note that you can only pass arguments that survive
     * being copied into a message.
     * @return {Promise<{worker:Worker, exports:Proxy<TranslationWorker>}>}
     */
    async loadWorker() {
        const worker = (
            this.options.workerUrl
            ? new Worker(this.options.workerUrl)
            : new Worker(new URL('./worker/translator-worker.js', import.meta.url))
        );

        /**
         * Incremental counter to derive request/response ids from.
         */
        let serial = 0;

        /**
         * Map of pending requests
         * @type {Map<number,{accept:(any), reject:(Error)}>}
         */
        const pending = new Map();

        // Function to send requests
        const call = (name, ...args) => new Promise((accept, reject) => {
            const id = ++serial;
            pending.set(id, {
                accept,
                reject,
                callsite: { // for debugging which call caused the error
                    message: `${name}(${args.map(arg => String(arg)).join(', ')})`,
                    stack: new Error().stack
                }
            });
            worker.postMessage({id, name, args});
        });

        // … receive responses
        worker.addEventListener('message', ({data}) => {
            const {id, result, error} = data;
            if (!pending.has(id)) {
                console.debug('Received message with unknown id:', data);
                throw new Error(`BergamotTranslator received response from worker to unknown call '${id}'`);
            }

            const {accept, reject, callsite} = pending.get(id);
            pending.delete(id);

            if (error !== undefined)
                reject(Object.assign(new Error(), error, {
                    message: error.message + ` (response to ${callsite.message})`,
                    stack: error.stack ? `${error.stack}\n${callsite.stack}` : callsite.stack
                }));
            else
                accept(result);
        });

        // … and general errors
        worker.addEventListener('error', this.onerror.bind(this));

        // Await initialisation. This will also nicely error out if the WASM
        // runtime fails to load.
        await call('initialize', this.options);

        /**
         * Little wrapper around the message passing api of Worker to make it
         * easy to await a response to a sent message. This wraps the worker in
         * a Proxy so you can treat it as if it is an instance of the
         * TranslationWorker class that lives inside the worker. All function
         * calls to it are transparently passed through the message passing
         * channel.
         */
        return {
            worker,
            exports: new Proxy({}, {
                get(target, name, receiver) {
                    // Prevent this object from being marked "then-able"
                    if (name !== 'then')
                        return (...args) => call(name, ...args);
                }
            })
        };
    }

    /**
     * Loads the model registry. Uses the registry shipped with this extension,
     * but formatted a bit easier to use, and future-proofed to be swapped out
     * with a TranslateLocally type registry.
     * (NOTE: "Registery" is a typo kept for backward compatibility.)
     * @return {Promise<{
     *   from: string,
     *   to: string,
     *   files: {
     *     [part:string]: {
     *       name: string,
     *       size: number,
     *       expectedSha256Hash: string
     *     }
     *   }[]
     * }>}
     */
    async loadModelRegistery() {
        const response = await fetch(this.registryUrl, {credentials: 'omit'});
        const registry = await response.json();

        // Add 'from' and 'to' keys for each model.
        return Array.from(Object.entries(registry), ([key, files]) => {
            const languages = key.match(/^([a-z]{2}(-[a-z]+)?)([a-z]{2}(-[a-z]+)?)/i);
            return {
                from: languages[1],
                to: languages[3],
                files
            }
        });
    }

    /**
     * Gets or loads translation model data. Caching wrapper around
     * `loadTranslationModel()`.
     * @param {{from:string, to:string}}
     * @return {Promise<{
     *   model: ArrayBuffer,
     *   vocab: ArrayBuffer,
     *   shortlist: ArrayBuffer,
     *   qualityModel: ArrayBuffer?
     * }>}
     */
    getTranslationModel({from, to}, options) {
        const key = JSON.stringify({from, to});

        if (!this.buffers.has(key)) {
            const promise = this.loadTranslationModel({from, to}, options);

            // set the promise so we return the same promise when its still pending
            this.buffers.set(key, promise);

            // But if loading fails, remove the promise again so we can try again later
            promise.catch(() => this.buffers.delete(key))
        }

        return this.buffers.get(key);
    }

    /**
     * Downloads a translation model and returns a set of
     * ArrayBuffers. These can then be passed to a TranslationWorker thread
     * to instantiate a TranslationModel inside the WASM vm.
     * @param {{from:string, to:string}}
     * @param {{signal:AbortSignal?}?}
     * @return {Promise<{
     *   model: ArrayBuffer,
     *   vocab: ArrayBuffer,
     *   shortlist: ArrayBuffer,
     *   qualityModel: ArrayBuffer?
     *   config: string?
     * }>}
     */
    async loadTranslationModel({from, to}, options) {
        performance.mark(`loadTranslationModule.${JSON.stringify({from, to})}`);

        // Find that model in the registry which will tell us about its files
        const entries = (await this.registry).filter(model => model.from == from && model.to == to);

        // FIX: was `if (!entries)` — filter() always returns an array, which
        // is truthy, so a missing model crashed with a TypeError further down
        // instead of this descriptive error.
        if (!entries.length)
            throw new Error(`No model for '${from}' -> '${to}'`);

        const files = entries[0].files;

        // Promise that resolves (or rejects really) when the abort signal hits.
        // FIX: `abort` previously closed over a `reject` that only existed
        // inside the Promise executor's scope, causing a ReferenceError when
        // the signal actually fired. Hoist the rejector out of the executor.
        let rejectEscape;
        const escape = new Promise((accept, reject) => {
            rejectEscape = reject;
        });
        const abort = () => rejectEscape(new CancelledError('abort signal'));

        if (options?.signal)
            options.signal.addEventListener('abort', abort);

        // Download all files mentioned in the registry entry. Race the promise
        // of all fetch requests, and a promise that rejects on the abort signal
        const buffers = Object.fromEntries(await Promise.race([
            Promise.all(Object.entries(files).map(async ([part, file]) => {
                // Special case where qualityModel is not part of the model, and this
                // should also catch the `config` case.
                if (file === undefined || file.name === undefined)
                    return [part, null];

                try {
                    return [part, await this.fetch(file.name, file.expectedSha256Hash, options)];
                } catch (cause) {
                    throw new Error(`Could not fetch ${file.name} for ${from}->${to} model`, {cause});
                }
            })),
            escape
        ]));

        // Nothing to abort now, clean up abort promise
        if (options?.signal)
            options.signal.removeEventListener('abort', abort);

        performance.measure('loadTranslationModel', `loadTranslationModule.${JSON.stringify({from, to})}`);

        let vocabs = [];
        if (buffers.vocab)
            vocabs = [buffers.vocab]
        else if (buffers.trgvocab && buffers.srcvocab)
            vocabs = [buffers.srcvocab, buffers.trgvocab]
        else
            throw new Error(`Could not identify vocab files for ${from}->${to} model among: ${Array.from(Object.keys(files)).join(' ')}`);

        let config = {};

        // For the Ukrainian models we need to override the gemm-precision
        if (files.model.name.endsWith('intgemm8.bin'))
            config['gemm-precision'] = 'int8shiftAll';

        // If quality estimation is used, we need to turn off skip-cost. Turning
        // this off causes quite the slowdown.
        if (files.qualityModel)
            config['skip-cost'] = false;

        // Allow the registry to also specify marian configuration parameters
        if (files.config)
            Object.assign(config, files.config);

        // Translate to generic bergamot-translator format that also supports
        // separate vocabularies for input & output language, and calls 'lex'
        // a more descriptive 'shortlist'.
        return {
            model: buffers.model,
            shortlist: buffers.lex,
            vocabs,
            qualityModel: buffers.qualityModel,
            config
        };
    }

    /**
     * Helper to download file from the web. Verifies the checksum.
     * @param {string} url
     * @param {string?} checksum sha256 checksum as hexadecimal string
     * @param {{signal:AbortSignal}?} extra fetch options
     * @returns {Promise<ArrayBuffer>}
     */
    async fetch(url, checksum, extra) {
        // Rig up a timeout cancel signal for our fetch
        const controller = new AbortController();
        const abort = () => controller.abort();

        const timeout = this.downloadTimeout ? setTimeout(abort, this.downloadTimeout) : null;

        try {
            // Also maintain the original abort signal
            if (extra?.signal)
                extra.signal.addEventListener('abort', abort);

            const options = {
                credentials: 'omit',
                signal: controller.signal,
            };

            if (checksum)
                options['integrity'] = `sha256-${this.hexToBase64(checksum)}`;

            // Disable the integrity check for NodeJS because of
            // https://github.com/nodejs/undici/issues/1594
            if (typeof window === 'undefined')
                delete options['integrity'];

            // Start downloading the url, using the hex checksum to ask
            // `fetch()` to verify the download using subresource integrity
            const response = await fetch(url, options);

            // Finish downloading (or crash due to timeout)
            return await response.arrayBuffer();
        } finally {
            if (timeout)
                clearTimeout(timeout);

            if (extra?.signal)
                extra.signal.removeEventListener('abort', abort);
        }
    }

    /**
     * Converts the hexadecimal hashes from the registry to something we can use with
     * the fetch() method.
     */
    hexToBase64(hexstring) {
        return btoa(hexstring.match(/\w{2}/g).map(function(a) {
            return String.fromCharCode(parseInt(a, 16));
        }).join(""));
    }

    /**
     * Crappy named method that gives you a list of models to translate from
     * one language into the other. Generally this will be the same as you
     * just put in if there is a direct model, but it could return a list of
     * two models if you need to pivot through a third language.
     * Returns just [{from:str,to:str}...]. To be used something like this:
     * ```
     * const models = await this.getModels(from, to);
     * models.forEach(({from, to}) => {
     *   const buffers = await this.loadTranslationModel({from,to});
     *   [TranslationWorker].loadTranslationModel({from,to}, buffers)
     * });
     * ```
     * @returns {Promise<TranslationModel[]>}
     */
    getModels({from, to}) {
        const key = JSON.stringify({from, to});

        // Note that the `this.models` map stores Promises. This so that
        // multiple calls to `getModels` that ask for the same model will
        // return the same promise, and the actual lookup is only done once.
        // The lookup is async because we need to await `this.registry`
        if (!this.models.has(key))
            this.models.set(key, this.findModels(from, to));

        return this.models.get(key);
    }

    /**
     * Find model (or model pair) to translate from `from` to `to`.
     * @param {string} from
     * @param {string} to
     * @returns {Promise<TranslationModel[]>}
     */
    async findModels(from, to) {
        const registry = await this.registry;

        let direct = [], outbound = [], inbound = [];

        registry.forEach(model => {
            if (model.from === from && model.to === to)
                direct.push(model);
            else if (model.from === from && model.to === this.pivotLanguage)
                outbound.push(model);
            else if (model.to === to && model.from === this.pivotLanguage)
                inbound.push(model);
        });

        if (direct.length)
            return [direct[0]];

        if (outbound.length && inbound.length)
            return [outbound[0], inbound[0]];

        throw new Error(`No model available to translate from '${from}' to '${to}'`);
    }
}

/**
 * Translator balancing between throughput and latency. Can use multiple worker
 * threads.
 */
export class BatchTranslator {
    /**
     * @param {{
     *   cacheSize?: number,
     *   useNativeIntGemm?: boolean,
     *   workers?: number,
     *   batchSize?: number,
     *   downloadTimeout?: number,
     *   workerUrl?: string | URL,
     *   registryUrl?: string
     *   pivotLanguage?: string?
     * }} options
     */
    constructor(options, backing) {
        if (!backing)
            backing = new TranslatorBacking(options);

        this.backing = backing;

        /**
         * @type {Array<{idle:Boolean, worker:Proxy}>} List of active workers
         * (and a flag to mark them idle or not)
         */
        this.workers = [];

        /**
         * Maximum number of workers
         * @type {number}
         */
        this.workerLimit = Math.max(options?.workers || 0, 1);

        /**
         * List of batches we push() to & shift() from using `enqueue`.
         * @type {{
         *   id: number,
         *   key: string,
         *   priority: number,
         *   models: TranslationModel[],
         *   requests: Array<{
         *     request: TranslationRequest,
         *     resolve: (response: TranslationResponse),
         *     reject: (error: Error)
         *   }>
         * }}
         */
        this.queue = [];

        /**
         * batch serial to help keep track of batches when debugging
         * @type {Number}
         */
        this.batchSerial = 0;

        /**
         * Number of requests in a batch before it is ready to be translated in
         * a single call. Bigger is better for throughput (better matrix packing)
         * but worse for latency since you'll have to wait for the entire batch
         * to be translated.
         * @type {Number}
         */
        this.batchSize = Math.max(options?.batchSize || 8, 1);

        this.onerror = options?.onerror || (err => console.error('WASM Translation Worker error:', err));
    }

    /**
     * Destructor that stops and cleans up.
     */
    async delete() {
        // Empty the queue
        this.remove(() => true);

        // Terminate the workers
        this.workers.forEach(({worker}) => worker.terminate());
    }

    /**
     * Makes sure queued work gets send to a worker. Will delay it till `idle`
     * to make sure the batches have been filled to some degree. Will keep
     * calling itself as long as there is work in the queue, but it does not
     * hurt to call it multiple times. This function always returns immediately.
     */
    notify() {
        setTimeout(async () => {
            // Is there work to be done?
            if (!this.queue.length)
                return;

            // Find an idle worker
            let worker = this.workers.find(worker => worker.idle);

            // No worker free, but space for more?
            if (!worker && this.workers.length < this.workerLimit) {
                // Claim a place in the workers array (but mark it busy so
                // it doesn't get used by any other `notify()` calls).
                const placeholder = {idle: false};
                this.workers.push(placeholder);

                try {
                    // adds `worker` and `exports` props
                    Object.assign(placeholder, await this.backing.loadWorker());
                    // At this point we know our new worker will be usable.
                    worker = placeholder;
                } catch (e) {
                    // FIX: previously the dead placeholder remained in
                    // `this.workers` marked busy forever, permanently eating a
                    // worker slot (a deadlock with workers:1). Remove it so a
                    // later notify() can retry.
                    const index = this.workers.indexOf(placeholder);
                    if (index !== -1)
                        this.workers.splice(index, 1);
                    this.onerror(new Error(`Could not initialise translation worker: ${e.message}`));
                }
            }

            // If no worker, that's the end of it.
            if (!worker)
                return;

            // Up to this point, this function has not used await, so no
            // chance that another call stole our batch since we did the check
            // at the beginning of this function and JavaScript is only
            // cooperatively parallel.
            const batch = this.queue.shift();

            // Put this worker to work, marking as busy
            worker.idle = false;
            try {
                await this.consumeBatch(batch, worker.exports);
            } catch (e) {
                batch.requests.forEach(({reject}) => reject(e));
            }
            worker.idle = true;

            // Is there more work to be done? Do another idleRequest
            if (this.queue.length)
                this.notify();
        });
    }

    /**
     * The only real public call you need!
     * ```
     * const {target: {text:string}} = await this.translate({
     *   from: 'de',
     *   to: 'en',
     *   text: 'Hallo Welt!',
     *   html: false, // optional
     *   priority: 0 // optional, like `nice` lower numbers are translated first
     * })
     * ```
     * @param {TranslationRequest} request
     * @returns {Promise<TranslationResponse>}
     */
    translate(request) {
        const {from, to, priority} = request;

        return new Promise(async (resolve, reject) => {
            try {
                // Batching key: only requests with the same key can be batched
                // together. Think same translation model, same options.
                const key = JSON.stringify({from, to});

                // (Fetching models first because if we would do it between looking
                // for a batch and making a new one, we end up with a race condition.)
                const models = await this.backing.getModels(request);

                // Put the request and its callbacks into a fitting batch
                this.enqueue({key, models, request, resolve, reject, priority});

                // Tell a worker to pick up the work at some point.
                this.notify();
            } catch (e) {
                reject(e);
            }
        });
    }

    /**
     * Prune pending requests by testing each one of them to whether they're
     * still relevant. Used to prune translation requests from tabs that got
     * closed.
     * @param {(request:TranslationRequest) => boolean} filter evaluates to true if request should be removed
     */
    remove(filter) {
        const queue = this.queue;

        this.queue = [];

        queue.forEach(batch => {
            batch.requests.forEach(({request, resolve, reject}) => {
                if (filter(request)) {
                    // Add error.request property to match response.request for
                    // a resolve() callback. Pretty useful if you don't want to
                    // do all kinds of Function.bind() dances.
                    reject(Object.assign(new CancelledError('removed by filter'), {request}));
                    return;
                }

                this.enqueue({
                    key: batch.key,
                    priority: batch.priority,
                    models: batch.models,
                    request,
                    resolve,
                    reject
                });
            });
        });
    }

    /**
     * Internal function used to put a request in a batch that still has space.
     * Also responsible for keeping the batches in order of priority. Called by
     * `translate()` but also used when filtering pending requests.
     * @param {{request:TranslateRequest, models:TranslationModel[], key:String, priority:Number?, resolve:(TranslateResponse)=>any, reject:(Error)=>any}}
     */
    enqueue({key, models, request, resolve, reject, priority}) {
        if (priority === undefined)
            priority = 0;

        // Find a batch in the queue that we can add to
        // (TODO: can we search backwards? that would speed things up)
        let batch = this.queue.find(batch => {
            return batch.key === key
                && batch.priority === priority
                && batch.requests.length < this.batchSize
        });

        // No batch or full batch? Queue up a new one
        if (!batch) {
            batch = {id: ++this.batchSerial, key, priority, models, requests: []};
            this.queue.push(batch);
            this.queue.sort((a, b) => a.priority - b.priority);
        }

        batch.requests.push({request, resolve, reject});
    }

    /**
     * Internal method that uses a worker thread to process a batch. You can
     * wait for the batch to be done by awaiting this call. You should only
     * then reuse the worker otherwise you'll just clog up its message queue.
     */
    async consumeBatch(batch, worker) {
        performance.mark('BergamotBatchTranslator.start');

        // Make sure the worker has all necessary models loaded. If not, tell it
        // first to load them.
        await Promise.all(batch.models.map(async ({from, to}) => {
            if (!await worker.hasTranslationModel({from, to})) {
                const buffers = await this.backing.getTranslationModel({from, to});
                await worker.loadTranslationModel({from, to}, buffers);
            }
        }));

        // Call the worker to translate. Only sending the actually necessary
        // parts of the batch to avoid trying to send things that don't survive
        // the message passing API between this thread and the worker thread.
        const responses = await worker.translate({
            models: batch.models.map(({from, to}) => ({from, to})),
            texts: batch.requests.map(({request: {text, html, qualityScores}}) => ({
                text: text.toString(),
                html: !!html,
                qualityScores: !!qualityScores
            }))
        });

        // Responses are in! Connect them back to their requests and call their
        // callbacks.
        batch.requests.forEach(({request, resolve, reject}, i) => {
            // TODO: look at response.ok and reject() if it is false
            resolve({
                request, // Include request for easy reference? Will allow you
                         // to specify custom properties and use that to link
                         // request & response back to each other.
                ...responses[i] // {target: {text: String}}
            });
        });

        performance.measure('BergamotBatchTranslator', 'BergamotBatchTranslator.start');
    }
}

/**
 * Translator optimised for interactive use.
 */
export class LatencyOptimisedTranslator {
    /**
     * @type {TranslatorBacking}
     */
    backing;

    /**
     * @type {Promise<{idle:boolean, worker:Worker, exports:Proxy<TranslationWorker>}>}
     */
    worker;

    /**
     * @type {{request: TranslationRequest, accept:(TranslationResponse), reject:(Error)} | null}
     */
    pending;

    /**
     * @param {{
     *   cacheSize?: number,
     *   useNativeIntGemm?: boolean,
     *   downloadTimeout?: number,
     *   workerUrl?: string | URL,
     *   registryUrl?: string
     *   pivotLanguage?: string?
     * }} options
     */
    constructor(options, backing) {
        if (!backing)
            backing = new TranslatorBacking(options);

        this.backing = backing;

        // Exposing the this.loadWorker() returned promise through this.worker
        // so that you can use that to catch any errors that happened during
        // loading.
        this.worker = this.backing.loadWorker().then(worker => ({...worker, idle: true}));
    }

    /**
     * Destructor that stops and cleans up.
     */
    async delete() {
        // Cancel pending translation
        if (this.pending) {
            this.pending.reject(new CancelledError('translator got deleted'));
            this.pending = null;
        }

        // Terminate the worker (I don't care if this fails)
        try {
            const {worker} = await this.worker;
            worker.terminate();
        } finally {
            this.worker = null;
        }
    }

    /**
     * Sets `request` as the next translation to process. If there was already
     * a translation waiting to be processed, their promise is rejected with a
     * SupersededError.
     * @param {TranslationRequest} request
     * @return {Promise<TranslationResponse>}
     */
    translate(request, options) {
        if (this.pending)
            this.pending.reject(new SupersededError());

        return new Promise((accept, reject) => {
            const pending = {request, accept, reject, options};

            if (options?.signal) {
                options.signal.addEventListener('abort', e => {
                    reject(new CancelledError('abort signal'));
                    if (this.pending === pending)
                        this.pending = null;
                });
            }

            this.pending = pending;
            this.notify();
        });
    }

    notify() {
        setTimeout(async () => {
            if (!this.pending)
                return;

            // Catch errors such as the worker not working
            try {
                // Possibly wait for the worker to finish loading. After it loaded
                // these calls are pretty much instantaneous.
                const worker = await this.worker;

                // Is another notify() call hogging the worker? Then stop.
                if (!worker.idle)
                    return;

                // Claim the pending translation request.
                const {request, accept, reject, options} = this.pending;
                this.pending = null;

                // Mark the worker as occupied
                worker.idle = false;

                try {
                    const models = await this.backing.getModels(request)

                    await Promise.all(models.map(async ({from, to}) => {
                        if (!await worker.exports.hasTranslationModel({from, to})) {
                            const buffers = await this.backing.getTranslationModel({from, to}, {signal: options?.signal});
                            await worker.exports.loadTranslationModel({from, to}, buffers);
                        }
                    }));

                    const {text, html, qualityScores} = request;
                    const responses = await worker.exports.translate({
                        models: models.map(({from, to}) => ({from, to})),
                        texts: [{text, html, qualityScores}]
                    });

                    accept({request, ...responses[0]});
                } catch (e) {
                    reject(e);
                }

                worker.idle = true;

                // Is there more work to be done? Do another idleRequest
                if (this.pending)
                    this.notify();
            } catch (e) {
                this.backing.onerror(e);
            }
        });
    }
}