UNPKG

ml5-save

Version:
180 lines (155 loc) 4.99 kB
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

/*
Word2Vec
*/

import * as tf from '@tensorflow/tfjs';
import callCallback from '../utils/callcallback';

/**
 * Return the entry stored under `word` in `model`, looking at own properties only.
 * A plain `model[word]` read also matches members inherited from Object.prototype
 * (e.g. 'constructor', 'toString'), which would then be mistaken for word vectors.
 * @param {Object} model - Map of word -> vector.
 * @param {String} word - Key to look up.
 * @returns {*} The stored entry, or undefined when `word` is not an own key of `model`.
 */
const lookupVector = (model, word) => (
  Object.prototype.hasOwnProperty.call(model, word) ? model[word] : undefined
);

class Word2Vec {
  /**
   * Create Word2Vec model
   * @param {String} modelPath - path to pre-trained word vector model in .json e.g data/wordvecs1000.json
   * @param {function} callback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise
   *    that will be resolved once the model has loaded.
   */
  constructor(modelPath, callback) {
    this.model = {};
    this.modelPath = modelPath;
    this.modelSize = 0;
    this.modelLoaded = false;
    // Loading starts immediately; every query method awaits `this.ready` before reading the model.
    this.ready = callCallback(this.loadModel(), callback);
    // TODO: Add support to Promise
    // this.then = this.ready.then.bind(this.ready);
  }

  /**
   * Fetch the JSON file at `this.modelPath` and store one 1-D tensor per key of its
   * `vectors` property in `this.model`. Also sets `modelSize` to the number of stored
   * words and `modelLoaded` to true.
   * @returns {Promise<Word2Vec>} Resolves with this instance once every vector is stored.
   */
  async loadModel() {
    const response = await fetch(this.modelPath);
    const json = await response.json();
    Object.keys(json.vectors).forEach((word) => {
      this.model[word] = tf.tensor1d(json.vectors[word]);
    });
    this.modelSize = Object.keys(this.model).length;
    this.modelLoaded = true;
    return this;
  }

  /**
   * Call `dispose()` on every vector stored in `this.model`.
   * The entries themselves are not removed from `this.model`.
   * @param {function} callback - Optional. Called with no arguments after disposing.
   */
  dispose(callback) {
    Object.values(this.model).forEach(x => x.dispose());
    if (callback) {
      callback();
    }
  }

  /**
   * Sum the vectors of `inputs` and return the words closest to that sum.
   * The first `inputs.length` closest entries are skipped (presumably to leave out
   * the input words themselves) and up to `max` following entries are returned.
   * @param {Array<String>} inputs - Words to add; at least two, all present in the model.
   * @param {Number|function} maxOrCb - Optional. Maximum number of results (default 10), or the callback.
   * @param {function} cb - Optional. Called as `cb(undefined, result)` on success.
   * @returns {Promise<Array<{word: String, distance: Number}>>} Results ordered by ascending distance.
   *    Rejects with the Error thrown by `Word2Vec.addOrSubtract` for invalid input.
   */
  async add(inputs, maxOrCb, cb) {
    const { max, callback } = Word2Vec.parser(maxOrCb, cb, 10);

    await this.ready;
    return tf.tidy(() => {
      const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
      const result = Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
      if (callback) {
        callback(undefined, result);
      }
      return result;
    });
  }

  /**
   * Subtract the vectors of `inputs[1..]` from the vector of `inputs[0]` and return
   * the words closest to the difference. The first `inputs.length` closest entries
   * are skipped and up to `max` following entries are returned.
   * @param {Array<String>} inputs - Words to combine; at least two, all present in the model.
   * @param {Number|function} maxOrCb - Optional. Maximum number of results (default 10), or the callback.
   * @param {function} cb - Optional. Called as `cb(undefined, result)` on success.
   * @returns {Promise<Array<{word: String, distance: Number}>>} Results ordered by ascending distance.
   *    Rejects with the Error thrown by `Word2Vec.addOrSubtract` for invalid input.
   */
  async subtract(inputs, maxOrCb, cb) {
    const { max, callback } = Word2Vec.parser(maxOrCb, cb, 10);

    await this.ready;
    return tf.tidy(() => {
      const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
      const result = Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
      if (callback) {
        callback(undefined, result);
      }
      return result;
    });
  }

  /**
   * Average the vectors of `inputs` (their sum divided by `inputs.length`) and return
   * the words closest to that average. The first `inputs.length` closest entries are
   * skipped and up to `max` following entries are returned.
   * @param {Array<String>} inputs - Words to average; at least two, all present in the model.
   * @param {Number|function} maxOrCb - Optional. Maximum number of results (default 10), or the callback.
   * @param {function} cb - Optional. Called as `cb(undefined, result)` on success.
   * @returns {Promise<Array<{word: String, distance: Number}>>} Results ordered by ascending distance.
   *    Rejects with the Error thrown by `Word2Vec.addOrSubtract` for invalid input.
   */
  async average(inputs, maxOrCb, cb) {
    const { max, callback } = Word2Vec.parser(maxOrCb, cb, 10);

    await this.ready;
    return tf.tidy(() => {
      const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
      const avg = tf.div(sum, tf.tensor(inputs.length));
      const result = Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
      if (callback) {
        callback(undefined, result);
      }
      return result;
    });
  }

  /**
   * Return the words closest to `input`.
   * The single closest entry is skipped (presumably the word itself) and up to `max`
   * following entries are returned.
   * @param {String} input - Word to look up.
   * @param {Number|function} maxOrCb - Optional. Maximum number of results (default 10), or the callback.
   * @param {function} cb - Optional. Called as `cb(undefined, result)`.
   * @returns {Promise<Array<{word: String, distance: Number}>|null>} Results ordered by
   *    ascending distance, or null when `input` has no vector in the model.
   */
  async nearest(input, maxOrCb, cb) {
    const { max, callback } = Word2Vec.parser(maxOrCb, cb, 10);

    await this.ready;
    // Own-property lookup: words such as 'constructor' must not resolve to Object.prototype members.
    const vector = lookupVector(this.model, input);
    let result;
    if (vector) {
      result = Word2Vec.nearest(this.model, vector, 1, max + 1);
    } else {
      result = null;
    }

    if (callback) {
      callback(undefined, result);
    }
    return result;
  }

  /**
   * Pick one word of the model at random (using Math.random).
   * @param {function} callback - Optional. Called as `callback(undefined, word)`.
   * @returns {Promise<String>} The chosen word; undefined when the model holds no words.
   */
  async getRandomWord(callback) {
    await this.ready;
    const words = Object.keys(this.model);
    const result = words[Math.floor(Math.random() * words.length)];
    if (callback) {
      callback(undefined, result);
    }
    return result;
  }

  /**
   * Resolve the overloaded `(maxOrCallback, cb)` argument pair used by the query methods.
   * A function in first position is taken as the callback; a number is taken as `max`;
   * any other value is ignored.
   * @param {Number|function} maxOrCallback - Either the result limit or the callback.
   * @param {function} cb - Callback used unless `maxOrCallback` is itself a function.
   * @param {Number} defaultMax - Limit used when `maxOrCallback` is not a number.
   * @returns {{max: Number, callback: function}} The resolved limit and callback.
   */
  static parser(maxOrCallback, cb, defaultMax) {
    let max = defaultMax;
    let callback = cb;

    if (typeof maxOrCallback === 'function') {
      callback = maxOrCallback;
    } else if (typeof maxOrCallback === 'number') {
      max = maxOrCallback;
    }
    return { max, callback };
  }

  /**
   * Combine the vectors of `values` from left to right, starting with the first one.
   * @param {Object} model - Map of word -> tensor.
   * @param {Array<String>} values - Words whose vectors are combined; at least two.
   * @param {String} operation - 'ADD' to add; any other value subtracts.
   * @returns {tf.Tensor} The combined vector.
   * @throws {Error} When fewer than two values are given, or when any value has no vector in `model`.
   */
  static addOrSubtract(model, values, operation) {
    return tf.tidy(() => {
      const vectors = [];
      const notFound = [];
      if (values.length < 2) {
        throw new Error('Invalid input, must be passed more than 1 value');
      }
      values.forEach((value) => {
        // Own-property lookup so inherited keys ('constructor', 'toString', ...) count as not found.
        const vector = lookupVector(model, value);
        if (!vector) {
          notFound.push(value);
        } else {
          vectors.push(vector);
        }
      });

      if (notFound.length > 0) {
        throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
      }

      const combine = operation === 'ADD'
        ? (a, b) => tf.add(a, b)
        : (a, b) => tf.sub(a, b);
      const [first, ...rest] = vectors;
      return rest.reduce((acc, vector) => combine(acc, vector), first);
    });
  }

  /**
   * Rank every word of `model` by `tf.util.distSquared` between its vector and `input`,
   * then return the slice `[start, max)` of that ranking.
   * @param {Object} model - Map of word -> tensor.
   * @param {tf.Tensor} input - Query vector.
   * @param {Number} start - Index of the first ranked entry to return.
   * @param {Number} max - Index one past the last ranked entry to return.
   * @returns {Array<{word: String, distance: Number}>} Entries ordered by ascending distance.
   */
  static nearest(model, input, start, max) {
    // Read the query values once; they do not change between words.
    const inputData = input.dataSync();
    const nearestVectors = Object.keys(model).map(word => ({
      word,
      distance: tf.util.distSquared(inputData, model[word].dataSync()),
    }));
    nearestVectors.sort((a, b) => a.distance - b.distance);
    return nearestVectors.slice(start, max);
  }
}

const word2vec = (model, cb) => new Word2Vec(model, cb);

export default word2vec;