UNPKG

languagemodel

Version:

A natural language model and cross-language model, for natural language understanding and generation

215 lines (187 loc) 8.54 kB
var LanguageModel = require('./LanguageModel');
var logSumExp = require('./logSumExp');

/**
 * Throw an Error naming the quantity if the given log-probability is NaN or infinite.
 * (`!isFinite(x)` is also true for NaN, so a single test covers both cases.)
 *
 * @param name  name of the quantity, used in the error message.
 * @param value the value to check.
 */
function assertFiniteLogProb(name, value) {
	if (!isFinite(value))
		throw new Error(name+" is "+value);
}

/**
 * This class represents a model for two different languages - input language and output language.
 * Based on:
 *
 * Leuski Anton, Traum David. A Statistical Approach for Text Processing in Virtual Humans tech. rep. University of Southern California, Institute for Creative Technologies 2008.
 * http://www.citeulike.org/user/erelsegal-halevi/article/12540655
 *
 * @author Erel Segal-Halevi
 * @since 2013-08
 *
 * @param opts (optional) options object. It is passed unchanged to both underlying LanguageModel instances. May contain:
 *   smoothingCoefficient - the lambda-factor for smoothing the unigram probabilities (default 0.9).
 */
var CrossLanguageModel = function(opts) {
	var options = opts || {};
	// Stored for callers; no method of this class reads it.
	// Only a missing (null/undefined) value falls back to the default, so an explicit 0 is kept.
	this.smoothingCoefficient = (options.smoothingCoefficient != null) ? options.smoothingCoefficient : 0.9;
	this.inputLanguageModel = new LanguageModel(options);
	this.outputLanguageModel = new LanguageModel(options);
};

CrossLanguageModel.prototype = {
	/**
	 * Online (incremental) training is not supported by this model - use trainBatch instead.
	 *
	 * @throws Error always.
	 */
	trainOnline: function(input, output) {
		throw new Error("CrossLanguageModel currently does not support online training");
	},

	/**
	 * Train the model with all the given documents.
	 *
	 * @param dataset
	 *    an array with objects of the format:
	 *    {input: {feature1:count1, feature2:count2,...}, output: {feature1:count1, feature2:count2,...}}
	 */
	trainBatch: function(dataset) {
		this.inputLanguageModel.trainBatch(dataset.map(function(datum) {return datum.input;}));
		this.outputLanguageModel.trainBatch(dataset.map(function(datum) {return datum.output;}));
		this._updateOutputFeatures();
	},

	/**
	 * Rebuild this.outputFeatures - a copy of the output language model's word counts without the "_total" entry.
	 * Its keys are the OUTPUT-domain features that divergence() sums over.
	 */
	_updateOutputFeatures: function() {
		this.outputFeatures = Object.assign({}, this.outputLanguageModel.getAllWordCounts());
		delete this.outputFeatures["_total"];
	},

	/**
	 * @return this.mapWordToTotalCount.
	 * NOTE(review): no method of this class assigns mapWordToTotalCount, so this returns undefined unless it is set
	 * externally. Confirm the intended semantics (input counts? output counts?) before relying on it.
	 */
	getAllWordCounts: function() {
		return this.mapWordToTotalCount;
	},

	/**
	 * Calculate the Kullback-Leibler divergence between the language models of the given samples.
	 * This can be used as an approximation of the (inverse) semantic similarity between them.
	 * Equation (6): D(P(W)||P(F)) = sum over output features f of P(f|W) * (log P(f|W) - log P(f|F)).
	 *
	 * @param inputSentenceCounts (hash) represents a sentence from the INPUT domain.
	 * @param outputSentenceCounts (hash) represents a sentence from the OUTPUT domain.
	 * @return the divergence; 0 if the model has no output features (e.g. before training).
	 * @throws Error if outputSentenceCounts is not an object, or if a log-probability is NaN/infinite.
	 *
	 * @note divergence is not symmetric - divergence(a,b) != divergence(b,a).
	 */
	divergence: function(inputSentenceCounts, outputSentenceCounts) {
		if (outputSentenceCounts!==Object(outputSentenceCounts))
			throw new Error("expected outputSentenceCounts to be an object, but found "+JSON.stringify(outputSentenceCounts));
		var logProbsGivenInput = this.logProbFeaturesGivenSentence(inputSentenceCounts);
		return this._divergenceGivenInputLogProbs(logProbsGivenInput, outputSentenceCounts);
	},

	/**
	 * Sum of equation (6), given the already-computed log P(f|W) of every output feature f.
	 *
	 * @param logProbsGivenInput object mapping each output feature to log P(f|W), as returned by logProbFeaturesGivenSentence.
	 * @param outputSentenceCounts (hash) a sentence from the OUTPUT domain.
	 */
	_divergenceGivenInputLogProbs: function(logProbsGivenInput, outputSentenceCounts) {
		var features = Object.keys(logProbsGivenInput);
		var sum = 0;
		for (var j = 0; j < features.length; ++j) {
			var feature = features[j];
			var logFeatureGivenInput = logProbsGivenInput[feature];
			var logFeatureGivenOutput = this.outputLanguageModel.logProbWordGivenSentence(feature, outputSentenceCounts);
			sum += Math.exp(logFeatureGivenInput) * (logFeatureGivenInput - logFeatureGivenOutput);
		}
		return sum;
	},

	/**
	 * Calculate the similarity scores between the given input sentence and all output sentences in the corpus,
	 * sorted from high (most similar) to low (least similar).
	 * Note: similarity = - divergence
	 *
	 * @param inputSentenceCounts (hash) represents a sentence from the INPUT domain.
	 * @return an array of {output, similarity}. Each output is a copy of a training output sample without its "_total" entry.
	 *   The array is empty if the output model has no dataset.
	 */
	similarities: function(inputSentenceCounts) {
		var dataset = this.outputLanguageModel.dataset || [];
		if (dataset.length === 0)
			return [];
		// log P(f|W) depends only on the input sentence, so compute it once rather than once per output sentence.
		var logProbsGivenInput = this.logProbFeaturesGivenSentence(inputSentenceCounts);
		var sims = [];
		for (var i = 0; i < dataset.length; ++i) {
			var output = Object.assign({}, dataset[i]);
			delete output['_total'];
			sims.push({
				output: output,
				similarity: -this._divergenceGivenInputLogProbs(logProbsGivenInput, output)
			});
		}
		sims.sort(function(a,b) { return b.similarity-a.similarity; });
		return sims;
	},

	/**
	 * Compute log P(f|W) (equation (5)) for every feature f of the OUTPUT language model at once.
	 * Gives the same values as calling logProbFeatureGivenSentence(f, givenSentenceCounts) for each f.
	 * The parts that depend only on the input sentence are computed once, not once per feature.
	 *
	 * @param givenSentenceCounts a hash that represents a sentence from the INPUT domain.
	 * @return an object mapping each output feature to log P(f|W). It is empty, and the input is not evaluated,
	 *   if the model has no output features.
	 * @throws Error if givenSentenceCounts is falsy, or if a log-probability is NaN/infinite.
	 */
	logProbFeaturesGivenSentence: function(givenSentenceCounts) {
		var result = Object.create(null);
		var features = Object.keys(this.outputFeatures || {});
		if (features.length === 0)
			return result;
		if (!givenSentenceCounts) throw new Error("no givenSentenceCounts");
		var logSentenceGivenEachDatum = this._logProbSentenceGivenEachDatum(givenSentenceCounts);
		var logSentence = this.inputLanguageModel.logProbSentenceGivenDataset(givenSentenceCounts);
		assertFiniteLogProb("logSentence", logSentence);
		for (var j = 0; j < features.length; ++j) {
			var feature = features[j];
			var logSentenceAndFeature = this._logProbSentenceAndFeatureFromDatumLogProbs(feature, logSentenceGivenEachDatum);
			assertFiniteLogProb("logSentenceAndFeature", logSentenceAndFeature);
			result[feature] = logSentenceAndFeature - logSentence;
		}
		return result;
	},

	/**
	 * Equation (5): log P(f|W) = log P(f,W) - log P(W).
	 *
	 * @param feature a single feature (-word) from the OUTPUT domain.
	 * @param givenSentenceCounts a hash that represents a sentence from the INPUT domain.
	 * @throws Error if givenSentenceCounts is falsy, or if either log-probability is NaN/infinite.
	 */
	logProbFeatureGivenSentence: function(feature, givenSentenceCounts) {
		if (!givenSentenceCounts) throw new Error("no givenSentenceCounts");
		var logSentenceAndFeature = this.logProbSentenceAndFeatureGivenDataset(feature, givenSentenceCounts);
		assertFiniteLogProb("logSentenceAndFeature", logSentenceAndFeature);
		var logSentence = this.inputLanguageModel.logProbSentenceGivenDataset(givenSentenceCounts);
		assertFiniteLogProb("logSentence", logSentence);
		return logSentenceAndFeature - logSentence;
	},

	/**
	 * Equation (2'): log P(f, w1...wn).
	 *
	 * @param feature a single feature (-word) from the OUTPUT domain.
	 * @param sentenceCounts a hash that represents a sentence from the INPUT domain.
	 * @return the joint log-probability of the output feature and the input sentence.
	 * @throws Error if sentenceCounts is falsy.
	 */
	logProbSentenceAndFeatureGivenDataset: function(feature, sentenceCounts) {
		if (!sentenceCounts) throw new Error("no sentenceCounts");
		return this._logProbSentenceAndFeatureFromDatumLogProbs(feature, this._logProbSentenceGivenEachDatum(sentenceCounts));
	},

	/**
	 * @param sentenceCounts a hash that represents a sentence from the INPUT domain.
	 * @return an array whose i-th element is inputLanguageModel.logProbSentenceGivenSentence(sentenceCounts, dataset[i]),
	 *   for each index i of the input language model's dataset.
	 */
	_logProbSentenceGivenEachDatum: function(sentenceCounts) {
		var dataset = this.inputLanguageModel.dataset;
		var logProbs = [];
		for (var i = 0; i < dataset.length; ++i)
			logProbs.push(this.inputLanguageModel.logProbSentenceGivenSentence(sentenceCounts, dataset[i]));
		return logProbs;
	},

	/**
	 * Evaluate equation (2') from precomputed per-sample input log-likelihoods.
	 * Sample i of the input dataset is paired with sample i of the output dataset.
	 *
	 * @param feature a single feature (-word) from the OUTPUT domain.
	 * @param logSentenceGivenEachDatum array as returned by _logProbSentenceGivenEachDatum.
	 */
	_logProbSentenceAndFeatureFromDatumLogProbs: function(feature, logSentenceGivenEachDatum) {
		var outputDataset = this.outputLanguageModel.dataset;
		var logProducts = [];
		for (var i = 0; i < logSentenceGivenEachDatum.length; ++i)
			logProducts.push(logSentenceGivenEachDatum[i] + this.outputLanguageModel.logProbWordGivenSentence(feature, outputDataset[i]));
		// The subtraction of log(dataset size) is not needed in practice (see eq. (5)).
		return logSumExp(logProducts) - Math.log(logSentenceGivenEachDatum.length);
	},

	toJSON: function() {
		return {
			inputLanguageModel: this.inputLanguageModel.toJSON(),
			outputLanguageModel: this.outputLanguageModel.toJSON(),
		};
	},

	/**
	 * Restore both language models from the object produced by toJSON.
	 * outputFeatures is not serialized, so it is rebuilt here. Without it, divergence() would sum over nothing.
	 * NOTE(review): assumes LanguageModel.fromJSON restores the counts returned by getAllWordCounts - confirm.
	 */
	fromJSON: function(json) {
		this.inputLanguageModel.fromJSON(json.inputLanguageModel);
		this.outputLanguageModel.fromJSON(json.outputLanguageModel);
		this._updateOutputFeatures();
	}
};

module.exports = CrossLanguageModel;

if (process.argv[1] === __filename) {
	console.log("CrossLanguageModel demo start");
	// NOTE(review): the documented option is "smoothingCoefficient". CrossLanguageModel itself does not read
	// "smoothingFactor" (LanguageModel may - confirm). Either way, the coefficient here defaults to 0.9.
	var model = new CrossLanguageModel({
		smoothingFactor : 0.9,
	});

	var wordcounts = require('./wordcounts');
	model.trainBatch([
		{input: wordcounts("I want aa"), output: wordcounts("a")},
		{input: wordcounts("I want bb"), output: wordcounts("b")},
		{input: wordcounts("I want cc"), output: wordcounts("c")},
	]);

	var show = function(sentence) {
		console.log(sentence+": ");
		console.dir(model.similarities(wordcounts(sentence)));
	};

	show("I want");
	show("I want nothing");
	show("I want aa");
	show("I want bb");
	show("I want aa and bb");
	show("I want aa , bb and cc");
	show("I want aa bb cc");
	console.log("CrossLanguageModel demo end");
}