UNPKG

@modelx/model

Version:

Deep Learning Classification, LSTM Time Series, Regression, and Multi-Layer Perceptrons with TensorFlow

82 lines (78 loc) 3.28 kB
import {
  TensorScriptModelInterface,
  TensorScriptOptions,
  TensorScriptProperties,
  Matrix,
  Vector,
  PredictionOptions,
  InputTextArray,
} from './model_interface';
import * as UniversalSentenceEncoder from '@tensorflow-models/universal-sentence-encoder';

/**
 * Wraps an async loader so the underlying load runs at most once while it is
 * pending or after it has succeeded. Concurrent callers share the same
 * in-flight promise; a rejected load is forgotten so a later call retries it.
 * @param load - function that starts the load
 * @return function that returns the shared load promise
 */
function cachedLoader<T>(load: () => Promise<T>): () => Promise<T> {
  let cached: Promise<T> | undefined;
  return () => {
    if (cached) return cached;
    const pending = load();
    cached = pending;
    // Handled only to reset the cache; callers awaiting `pending` still see the rejection.
    pending.catch(() => {
      cached = undefined;
    });
    return pending;
  };
}

// Module-level so every TextEmbedding instance shares one model and one tokenizer,
// and overlapping train() calls share a single download instead of starting several.
const loadModel = cachedLoader(() => UniversalSentenceEncoder.load());
const loadTokenizer = cachedLoader(() => UniversalSentenceEncoder.loadTokenizer());

/**
 * Text Embedding with Tensorflow Universal Sentence Encoder (USE)
 * @class TextEmbedding
 * @implements {TensorScriptModelInterface}
 */
export class TextEmbedding extends TensorScriptModelInterface {
  /**
   * @param {Object} options - Options for USE (shallow-copied before being passed to the base class)
   * @param {{model:Object,tf:Object,}} properties - extra instance properties
   */
  constructor(options: TensorScriptOptions = {}, properties?: TensorScriptProperties) {
    const config = Object.assign({}, options);
    super(config, properties);
    this.type = 'TextEmbedding';
  }

  /**
   * Asynchronously loads the Universal Sentence Encoder model and tokenizer.
   * Both loads are cached at module level, so only the first successful call
   * downloads them; later calls (and calls made while a load is in flight)
   * reuse the same result.
   * @override
   * @return {Promise<Object>} resolves to the loaded UniversalSentenceEncoder model
   */
  async train() {
    const [model, tokenizer] = await Promise.all([loadModel(), loadTokenizer()]);
    this.model = model;
    this.tokenizer = tokenizer;
    this.trained = true;
    this.compiled = true;
    return this.model;
  }

  /**
   * Calculates sentence embeddings.
   * The caller owns the resolved tensor and should dispose it when done.
   * @override
   * @param {Array<string>} input_array - sentences to embed
   * @param {Object} options - unused; kept for interface compatibility
   * @return {Promise} resolves to the tensor returned by the model's `embed`
   * @throws {Error} if input_array is not an array, or if train() has not loaded the model
   */
  calculate(input_array: InputTextArray, options = {}) {
    if (!input_array || Array.isArray(input_array) === false) throw new Error('invalid input array of sentences');
    if (!this.model) throw new Error('Universal Sentence Encoder is not loaded; call train() first');
    return this.model.embed(input_array);
  }

  /**
   * Returns sentence embeddings from the USE model.
   * @param {Array<string>} input_array - array of sentences to embed
   * @param {Object} [options] - prediction options
   * @param {Boolean} [options.json=true] - return nested arrays instead of the raw typed array
   * @param {Boolean} [options.probability=true] - return real values instead of rounded integers
   * @return {Promise<Array<Array<number>>>} one embedding row per sentence (or the flat typed array when json is false)
   */
  async predict(input_array: InputTextArray, options: PredictionOptions = {}): Promise<Matrix | Vector> {
    const config = Object.assign({ json: true, probability: true }, options);
    const embeddings = await this.calculate(input_array, options);
    try {
      // NOTE(review): data() resolves to a typed array at runtime; the number[]
      // annotation is kept from the original to match the declared return type.
      const predictions: number[] = await embeddings.data();
      if (config.json === false) return predictions;
      // Use the tensor's own shape rather than a hard-coded embedding width.
      const [rows, columns] = embeddings.shape;
      const predictionValues = config.probability === false
        ? Array.from(predictions, (value) => Math.round(value))
        : Array.from(predictions);
      return this.reshape(predictionValues, [rows, columns]);
    } finally {
      // The embedding tensor is only needed to read its values; release its memory
      // even if reading fails.
      embeddings.dispose();
    }
  }
}