/*
 * @vowpalwabbit/vowpalwabbit
 * Version:
 * WebAssembly (wasm) bindings for Vowpal Wabbit
 * 447 lines (446 loc) • 21.8 kB
 * JavaScript
 */
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
const { v4: uuidv4 } = require('uuid');
const VWWasmModule = require('./vw-wasm.js');
// internals
/**
 * Internal identifiers for the kinds of wasm model a workspace can wrap:
 * `All` selects moduleInstance.VWModel and `CB` selects moduleInstance.VWCBModel.
 * Frozen so that the shared constants cannot be changed at runtime.
 */
const ProblemType = Object.freeze({
    All: 'All',
    CB: 'Cb',
});
// exported
/**
 * Error thrown by the workspace wrappers when a call into the VW wasm module fails.
 * The value raised by the wasm layer is kept as the standard `cause`. When that value
 * carries a string stack, it is appended to this error's own stack.
 */
class VWError extends Error {
    /**
     * @param {string} message - human-readable description of the failure
     * @param {*} [originalError] - the value caught from the wasm call
     */
    constructor(message, originalError) {
        super(message, originalError === undefined ? undefined : { cause: originalError });
        this.name = 'VWError';
        // Keep `stack` a string: append the original trace instead of replacing ours with the error object.
        if (originalError !== null && typeof originalError === 'object' && typeof originalError.stack === 'string') {
            this.stack = `${this.stack}\nCaused by: ${originalError.stack}`;
        }
    }
}
exports.default = new Promise((resolve) => {
VWWasmModule().then((moduleInstance) => {
class WorkspaceBase {
    /**
     * Creates the wasm-side VW model that the methods of this class delegate to.
     * The model is built from at most one source: a model file, or a model that already lives in
     * wasm memory. With neither, only the argument string is used.
     *
     * @param {string} type - ProblemType.All (uses moduleInstance.VWModel) or ProblemType.CB (uses moduleInstance.VWCBModel)
     * @param {Function} readSync - reads a file synchronously and returns its contents as a buffer
     * @param {Function} writeSync - writes a buffer to a file synchronously
     * @param {object} [options]
     * @param {string} [options.args_str] - VW command line arguments; "" when omitted
     * @param {string} [options.model_file] - path of a model file to read with readSync
     * @param {Array<number>} [options.model_array] - [pointer, length] of a model already in wasm memory;
     *   only used when the pointer is non-null and the length is positive
     * @throws {Error} if `type` is unknown, if none of args_str/model_file/model_array is usable,
     *   or if both model_file and a usable model_array are supplied
     */
    constructor(type, readSync, writeSync, { args_str, model_file, model_array } = {}) {
        let vwModelConstructor;
        if (type === ProblemType.All) {
            vwModelConstructor = moduleInstance.VWModel;
        }
        else if (type === ProblemType.CB) {
            vwModelConstructor = moduleInstance.VWCBModel;
        }
        else {
            throw new Error("Unknown model type");
        }
        this._readSync = readSync;
        this._writeSync = writeSync;
        let model_array_ptr;
        let model_array_len;
        if (model_array !== undefined) {
            [model_array_ptr, model_array_len] = model_array;
        }
        const model_array_defined = model_array_ptr !== undefined && model_array_len !== undefined
            && model_array_ptr !== null && model_array_len > 0;
        if (args_str === undefined && model_file === undefined && !model_array_defined) {
            throw new Error("Can not initialize vw object without args_str or a model_file or a model_array");
        }
        if (model_file !== undefined && model_array_defined) {
            throw new Error("Can not initialize vw object with both model_file and model_array");
        }
        this._args_str = args_str === undefined ? "" : args_str;
        if (model_file !== undefined) {
            const modelBuffer = readSync(model_file);
            // The heap copy is released once construction finishes, even if the constructor throws.
            this._instance = this._withBufferInWasmHeap(modelBuffer, (ptr, len) => new vwModelConstructor(this._args_str, ptr, len));
        }
        else if (model_array_defined) {
            this._instance = new vwModelConstructor(this._args_str, model_array_ptr, model_array_len);
        }
        else {
            this._instance = new vwModelConstructor(this._args_str);
        }
    }
    /**
     * Copies `buffer` into a newly allocated block of wasm memory and calls `callback(ptr, len)`.
     * The block is freed afterwards, also when `callback` throws, so a failed load does not
     * leak wasm heap memory.
     *
     * @private
     * @param {ArrayBuffer|Uint8Array} buffer - bytes to copy (e.g. the value returned by readSync)
     * @param {Function} callback - receives the wasm pointer and the byte length of the copy
     * @returns whatever `callback` returns
     */
    _withBufferInWasmHeap(buffer, callback) {
        const len = buffer.byteLength;
        const ptr = moduleInstance._malloc(len);
        try {
            // HEAPU8 is read after _malloc, in case the allocation replaced the underlying memory buffer.
            new Uint8Array(moduleInstance.HEAPU8.buffer, ptr, len).set(new Uint8Array(buffer));
            return callback(ptr, len);
        }
        finally {
            moduleInstance._free(ptr);
        }
    }
    /**
     * Returns the enum value of the prediction type corresponding to the problem type of the model
     * @returns enum value of prediction type
     */
    predictionType() {
        return this._instance.predictionType();
    }
    /**
     * The current total sum of the progressive validation loss
     *
     * @returns {number} the sum of all losses accumulated by the model
     */
    sumLoss() {
        return this._instance.sumLoss();
    }
    /**
     * Takes a file location and stores the VW model in binary format in the file.
     * The native model copy is released before writeSync runs, so a failing write does not leak it.
     *
     * @param {string} model_file the path to the file where the model will be saved
     */
    saveModelToFile(model_file) {
        this._writeSync(model_file, Buffer.from(this.getModelAsArray()));
    }
    /**
     * Gets the VW model in binary format as a Uint8Array that can be saved to a file.
     * There is no need to delete or free the array returned by this function.
     * To re-load the model into VW from this array, the bytes must first be copied into wasm memory (see loadModelFromArray).
     *
     * @returns {Uint8Array} the VW model in binary format
     */
    getModelAsArray() {
        const char_vector = this._instance.getModel();
        try {
            const size = char_vector.size();
            const uint8Array = new Uint8Array(size);
            for (let i = 0; i < size; ++i) {
                uint8Array[i] = char_vector.get(i);
            }
            return uint8Array;
        }
        finally {
            // The vector returned by getModel() is a native object and must be deleted explicitly.
            char_vector.delete();
        }
    }
    /**
     * Takes a file location and loads the VW model from the file.
     * The temporary wasm copy of the file is always freed, even if loading fails.
     *
     * @param {string} model_file the path to the file where the model will be loaded from
     */
    loadModelFromFile(model_file) {
        const modelBuffer = this._readSync(model_file);
        this._withBufferInWasmHeap(modelBuffer, (ptr, len) => this._instance.loadModelFromBuffer(ptr, len));
    }
    /**
     * Takes a model in an array binary format and loads it into the VW instance.
     *
     * @param {number} model_array_ptr the pre-loaded model's array pointer.
     * The memory must be allocated via the WebAssembly module's _malloc function and should later be freed via the _free function.
     * @param {number} model_array_len the pre-loaded model's array length
     */
    loadModelFromArray(model_array_ptr, model_array_len) {
        this._instance.loadModelFromBuffer(model_array_ptr, model_array_len);
    }
    /**
     * Deletes the underlying VW instance. This function should be called when the instance is no longer needed.
     */
    delete() {
        this._instance.delete();
    }
}
/**
 * A wrapper around the Vowpal Wabbit C++ library.
 * @class
 * @private
 * @extends WorkspaceBase
 */
class Workspace extends WorkspaceBase {
    /**
     * Creates a new Vowpal Wabbit workspace.
     * Can accept string arguments, a model file, or both.
     *
     * @constructor
     * @param {Function} readSync - A function that reads a file synchronously and returns a buffer
     * @param {Function} writeSync - A function that writes a buffer to a file synchronously
     * @param {string} [args_str] - The arguments that are used to initialize Vowpal Wabbit (optional)
     * @param {string} [model_file] - The path to the file where the model will be loaded from (optional)
     * @param {Array<number>} [model_array] - The pre-loaded model's array pointer and length, as a two-element array (optional).
     * The memory must be allocated via the WebAssembly module's _malloc function and should later be freed via the _free function.
     * @throws {Error} Throws an error if:
     * - no argument is provided
     * - both string arguments and a model file are provided, and the string arguments and arguments defined in the model clash
     * - both string arguments and a model array are provided, and the string arguments and arguments defined in the model clash
     * - both a model file and a model array are provided
     */
    constructor(readSync, writeSync, { args_str, model_file, model_array } = {}) {
        super(ProblemType.All, readSync, writeSync, { args_str, model_file, model_array });
    }
    /**
     * Parse a line of text into a VW example.
     * The example can then be used for prediction or learning.
     * When the example is no longer needed, call finishExample() on it and then delete().
     *
     * @param {string} line a single example in text Vowpal Wabbit format
     * @returns a parsed vw example that can be used for prediction or learning
     */
    parse(line) {
        return this._instance.parse(line);
    }
    /**
     * Creates a new example from dense arrays of feature values, keyed by namespace.
     *
     * @example
     * let example = model.createExampleFromDense({
     *     my_namespace: [0.3, 0.2, 0.1, 0.3, 0.5, 0.9]
     * });
     * @param {Object<string, number[]>} features namespace name -> dense feature values (a plain object, as in the example)
     * @param {string} label Empty label by default
     * @returns a parsed vw example that can be used for prediction or learning
     */
    createExampleFromDense(features, label = "") {
        return this._instance.createExampleFromDense(features, label);
    }
    /**
     * Calls vw predict on the example and returns the prediction.
     *
     * @param {object} example returned from parse() or createExampleFromDense()
     * @returns the prediction with a type corresponding to the reduction that was used
     * @throws {VWError} Throws an error if the example is not well defined
     */
    predict(example) {
        try {
            return this._instance.predict(example);
        }
        catch (e) {
            throw new VWError(e.message, e);
        }
    }
    /**
     * Calls vw learn on the example and updates the model
     *
     * @param {object} example returned from parse() or createExampleFromDense()
     * @throws {VWError} Throws an error if the example is not well defined
     */
    learn(example) {
        try {
            return this._instance.learn(example);
        }
        catch (e) {
            throw new VWError(e.message, e);
        }
    }
    /**
     * Cleans the example and returns it to the pool of available examples. delete() must also be called on the example object
     *
     * @param {object} example returned from parse() or createExampleFromDense()
     */
    finishExample(example) {
        return this._instance.finishExample(example);
    }
}
/**
 * A wrapper around the Vowpal Wabbit C++ library for Contextual Bandit exploration algorithms.
 * @class
 * @private
 * @extends WorkspaceBase
 */
class CbWorkspace extends WorkspaceBase {
    /**
     * Creates a new contextual bandit workspace backed by the wasm VWCBModel.
     *
     * @constructor
     * @param {Function} readSync - A function that reads a file synchronously and returns a buffer
     * @param {Function} writeSync - A function that writes a buffer to a file synchronously
     * @param {string} [args_str] - The arguments that are used to initialize Vowpal Wabbit (optional)
     * @param {string} [model_file] - The path to the file where the model will be loaded from (optional)
     * @param {Array<number>} [model_array] - The pre-loaded model's array pointer and length, as a two-element array (optional)
     * @throws {Error} if no usable model source is given, or if both model_file and model_array are given
     */
    constructor(readSync, writeSync, { args_str, model_file, model_array } = {}) {
        super(ProblemType.CB, readSync, writeSync, { args_str, model_file, model_array });
        // Text of the multiline example currently being accumulated by addLine.
        this._ex = "";
    }
    /**
     * Takes a CB example and returns an array of (action, score) pairs, representing the probability mass function over the available actions
     * The returned pmf can be used with samplePmf to sample an action
     *
     * Example must have the following properties:
     * - text_context: a string representing the context
     *
     * @param {object} example the example object that will be used for prediction
     * @returns {array} probability mass function, an array of action,score pairs
     * @throws {VWError} Throws an error if the example text_context is missing from the example
     */
    predict(example) {
        try {
            return this._instance.predict(example);
        }
        catch (e) {
            throw new VWError(e.message, e);
        }
    }
    /**
     * Takes a CB example and uses it to update the model
     *
     * Example must have the following properties:
     * - text_context: a string representing the context
     * - labels: an array of label objects (usually one), each label object must have the following properties:
     *  - action: the action index
     *  - cost: the cost of the action
     *  - probability: the probability of the action
     *
     * The labels array should contain more than one label only if a reduction that accepts multiple labels is used (e.g. graph_feedback)
     *
     * @param {object} example the example object that will be used to update the model
     * @throws {VWError} Throws an error if the example does not have the required properties to learn
     */
    learn(example) {
        try {
            return this._instance.learn(example);
        }
        catch (e) {
            throw new VWError(e.message, e);
        }
    }
    /**
     * Accepts a CB example in text format one line at a time. Non-blank lines are buffered. A blank
     * (or whitespace-only) line ends the example, and the buffered text is passed to learnFromString.
     * Blank lines with nothing buffered, such as consecutive blank lines, are ignored. The last example
     * is only learned once a blank line follows it.
     * This is intended to be used with files that have CB examples, that were logged using logCBExampleToStream and are being read line by line.
     *
     * @param {string} line a string representing a line from a CB example in text Vowpal Wabbit format
     * @throws {VWError} if VW fails to learn from the completed example. The buffer is cleared first, so later examples are unaffected.
     */
    addLine(line) {
        if (line.trim() !== '') {
            this._ex = `${this._ex}${line}\n`;
            return;
        }
        if (this._ex === "") {
            return;
        }
        // Reset before learning so a failing example is not glued onto the next one.
        const completed = this._ex;
        this._ex = "";
        this.learnFromString(completed);
    }
    /**
     * Takes a full multiline CB example in text format and uses it to update the model. This is intended to be used with examples that are logged to a file using logCBExampleToStream.
     *
     * @param {string} example a string representing the CB example in text Vowpal Wabbit format
     * @throws {Error} Throws an error if the example is an object with a labels and/or a text_context property
     * @throws {VWError} Throws an error if VW fails to learn from the example
     */
    learnFromString(example) {
        // Call hasOwnProperty through the prototype so objects that lack or shadow it are still checked.
        const hasOwn = (key) => Object.prototype.hasOwnProperty.call(example, key);
        if (hasOwn("labels") || hasOwn("text_context")) {
            throw new Error("Example should not have a label or a text_context when using learnFromString, the label and context should just be in the string");
        }
        try {
            return this._instance.learnFromString({ text_context: example });
        }
        catch (e) {
            throw new VWError(e.message, e);
        }
    }
    /**
     * Takes an exploration prediction (array of action, score pairs) and returns a single action and score.
     * It also returns a freshly generated unique id that was used to seed the sampling. The id can be used to
     * track and reproduce the sampling.
     *
     * @param {array} pmf probability mass function, an array of action,score pairs that was returned by predict
     * @returns {object} an object with the following properties:
     * - action: the action index that was sampled
     * - score: the score of the action that was sampled
     * - uuid: the generated uuid that seeded the sampling
     * @throws {VWError} Throws an error if the input is not an array of action,score pairs
     */
    samplePmf(pmf) {
        return this.samplePmfWithUUID(pmf, uuidv4());
    }
    /**
     * Takes an exploration prediction (array of action, score pairs) and a unique id that is used to seed the sampling.
     * Returns a single action index and the corresponding score.
     *
     * @param {array} pmf probability mass function, an array of action,score pairs that was returned by predict
     * @param {string} uuid a unique id that is used to seed the sampling
     * @returns {object} an object with the following properties:
     * - action: the action index that was sampled
     * - score: the score of the action that was sampled
     * - uuid: the uuid that was passed in
     * @throws {VWError} Throws an error if the input is not an array of action,score pairs
     */
    samplePmfWithUUID(pmf, uuid) {
        try {
            const ret = this._instance._samplePmf(pmf, uuid);
            ret.uuid = uuid;
            return ret;
        }
        catch (e) {
            throw new VWError(e.message, e);
        }
    }
    /**
     * Takes an example with a text_context field and calls predict. The prediction (a probability mass function over the available actions)
     * is then sampled from. Only the chosen action index and its score are returned, together with a freshly
     * generated unique id that was used to seed the sampling and that can be used to track and reproduce it.
     *
     * @param {object} example an example object containing the context to be used during prediction
     * @returns {object} an object with the following properties:
     * - action: the action index that was sampled
     * - score: the score of the action that was sampled
     * - uuid: the generated uuid that seeded the sampling
     * @throws {VWError} if there is no text_context field in the example
     */
    predictAndSample(example) {
        return this.predictAndSampleWithUUID(example, uuidv4());
    }
    /**
     * Takes an example with a text_context field and a unique id that is used to seed the sampling, and calls predict.
     * The prediction (a probability mass function over the available actions) is then sampled from. Only the chosen
     * action index and its score are returned, together with the uuid.
     *
     * @param {object} example an example object containing the context to be used during prediction
     * @param {string} uuid a unique id that is used to seed the sampling
     * @returns {object} an object with the following properties:
     * - action: the action index that was sampled
     * - score: the score of the action that was sampled
     * - uuid: the uuid that was passed in
     * @throws {VWError} if there is no text_context field in the example
     */
    predictAndSampleWithUUID(example, uuid) {
        try {
            const ret = this._instance._predictAndSample(example, uuid);
            ret.uuid = uuid;
            return ret;
        }
        catch (e) {
            throw new VWError(e.message, e);
        }
    }
}
/**
 * Forwards `exception` to the wasm module's own getExceptionMessage helper and returns its result unchanged.
 * NOTE(review): presumably meant for decoding C++ exceptions thrown out of the wasm module. Confirm against the build.
 *
 * @param {*} exception - the value caught from a failing wasm call
 * @returns whatever moduleInstance.getExceptionMessage returns for `exception`
 */
function getExceptionMessage(exception) {
    const decoded = moduleInstance.getExceptionMessage(exception);
    return decoded;
}
/**
 * Groups the prediction-related constants exposed by this package.
 */
class Prediction {
}
// Friendly names for the values of the wasm module's PredictionType enum. Compare them with the
// value returned by a workspace's predictionType().
// NOTE(review): every entry reads a lowercase/snake_case key except ActionPdfValue and
// ActiveMultiClass, which read camelCase keys. Confirm those names exist on
// moduleInstance.PredictionType; otherwise these two entries are undefined.
Prediction.Type = ((predictionTypes) => ({
    Scalar: predictionTypes.scalar,
    Scalars: predictionTypes.scalars,
    ActionScores: predictionTypes.action_scores,
    Pdf: predictionTypes.pdf,
    ActionProbs: predictionTypes.action_probs,
    MultiClass: predictionTypes.multiclass,
    MultiLabels: predictionTypes.multilabels,
    Prob: predictionTypes.prob,
    MultiClassProb: predictionTypes.multiclassprob,
    DecisionProbs: predictionTypes.decision_probs,
    ActionPdfValue: predictionTypes.ActionPdfValue,
    ActiveMultiClass: predictionTypes.activeMultiClass,
}))(moduleInstance.PredictionType);
// Fulfil the exported promise with the public API now that the wasm module is initialized.
resolve({
    Workspace,
    CbWorkspace,
    Prediction,
    getExceptionMessage,
    wasmModule: moduleInstance,
});
});
});