UNPKG

@nlpjs/nlu

Version:

Natural Language Understanding

370 lines (342 loc) 11.3 kB
/* * Copyright (c) AXA Group Operations Spain S.A. * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ const { Clonable } = require('@nlpjs/core'); const { Language } = require('@nlpjs/language-min'); const DomainManager = require('./domain-manager'); class NluManager extends Clonable { constructor(settings = {}, container) { super( { settings: {}, container: settings.container || container, }, container ); this.applySettings(this.settings, settings); if (!this.settings.tag) { this.settings.tag = 'nlu-manager'; } this.registerDefault(); this.applySettings( this.settings, this.container.getConfiguration(this.settings.tag) ); if (!this.container.get('Language')) { this.container.register('Language', Language, false); } this.guesser = this.container.get('Language'); this.locales = []; this.languageNames = {}; this.domainManagers = {}; this.intentDomains = {}; if (this.settings.locales) { this.addLanguage(this.settings.locales); } this.applySettings(this, { pipelineTrain: this.getPipeline(`${this.settings.tag}-train`), pipelineProcess: this.getPipeline(`${this.settings.tag}-process`), }); } registerDefault() { this.container.registerConfiguration('nlu-manager', {}, false); this.container.registerPipeline( 'nlu-manager-train', ['.innerTrain'], false ); } describeLanguage(locale, name) { this.languageNames[locale] = { locale, name }; } addLanguage(srcLocales) { if (srcLocales) { const locales = Array.isArray(srcLocales) ? srcLocales : [srcLocales]; for (let i = 0; i < locales.length; i += 1) { const locale = locales[i].substr(0, 2).toLowerCase(); if (!this.locales.includes(locale)) { this.locales.push(locale); } if (!this.domainManagers[locale]) { this.domainManagers[locale] = new DomainManager( { locale, ...this.settings.domain, useNoneFeature: this.settings.useNoneFeature, trainByDomain: this.settings.trainByDomain, }, this.container ); } } } } removeLanguage(locales) { if (Array.isArray(locales)) { locales.forEach((locale) => this.removeLanguage(locale)); } else { delete this.domainManagers[locales]; const index = this.locales.indexOf(locales); if (index !== -1) { this.locales.splice(index, 1); } } } guessLanguage(srcInput) { const input = srcInput; const isString = typeof input === 'string'; if (this.locales.length === 1) { if (isString) { return this.locales[0]; } [input.locale] = this.locales; return input; } if (!input) { return isString ? undefined : input; } if (!isString && input.locale) { return input; } const utterance = isString ? input : input.utterance; if (this.locales.length === 1) { if (isString) { return this.locales[0]; } [input.locale] = this.locales; } const guess = this.guesser.guess(utterance, this.locales, 1); const locale = guess && guess.length > 0 ? guess[0].alpha2 : undefined; if (isString) { return locale; } input.locale = locale; return input; } assignDomain(srcLocale, srcIntent, srcDomain) { const locale = srcDomain ? srcLocale.substr(0, 2).toLowerCase() : undefined; const intent = srcDomain ? srcIntent : srcLocale; const domain = srcDomain || srcIntent; if (locale) { if (!this.intentDomains[locale]) { this.intentDomains[locale] = {}; } this.intentDomains[locale][intent] = domain; } else { for (let i = 0; i < this.locales.length; i += 1) { this.assignDomain(this.locales[i], intent, domain); } } } getIntentDomain(srcLocale, intent) { const locale = srcLocale.substr(0, 2).toLowerCase(); if (!this.intentDomains[locale]) { return 'default'; } return this.intentDomains[locale][intent] || 'default'; } getDomains() { const result = {}; const locales = Object.keys(this.intentDomains); for (let i = 0; i < locales.length; i += 1) { const locale = locales[i]; result[locale] = {}; const intents = Object.keys(this.intentDomains[locale]); for (let j = 0; j < intents.length; j += 1) { const intent = intents[j]; const domain = this.intentDomains[locale][intent]; if (!result[locale][domain]) { result[locale][domain] = []; } result[locale][domain].push(intent); } } return result; } consolidateLocale(srcLocale, utterance) { const locale = srcLocale ? srcLocale.substr(0, 2).toLowerCase() : this.guessLanguage(utterance); if (!locale) { throw new Error('Locale must be defined'); } return locale; } consolidateManager(locale) { const manager = this.domainManagers[locale]; if (!manager) { throw new Error(`Domain Manager not found for locale ${locale}`); } return manager; } add(srcLocale, utterance, intent) { const locale = this.consolidateLocale(srcLocale, utterance); const manager = this.consolidateManager(locale); const domain = this.getIntentDomain(locale, intent); this.guesser.addExtraSentence(locale, utterance); manager.add(domain, utterance, intent); } remove(srcLocale, utterance, intent) { const locale = this.consolidateLocale(srcLocale, utterance); const manager = this.consolidateManager(locale); const domain = this.getIntentDomain(locale, intent); manager.remove(domain, utterance, intent); } async innerTrain(settings) { let locales = settings.locales || this.locales; if (!Array.isArray(locales)) { locales = [locales]; } const promises = locales .filter((locale) => this.domainManagers[locale]) .map((locale) => this.domainManagers[locale].train(settings.settings)); return Promise.all(promises); } async train(settings) { const input = { nluManager: this, settings: this.applySettings(settings, this.settings), }; delete input.settings.tag; return this.runPipeline(input, this.pipelineTrain); } fillLanguage(srcInput) { const input = srcInput; input.languageGuessed = false; if (!input.locale) { input.locale = this.guessLanguage(input.utterance); input.languageGuessed = true; } if (input.locale) { input.localeIso2 = input.locale.substr(0, 2).toLowerCase(); input.language = ( this.languageNames[input.localeIso2] || this.guesser.languagesAlpha2[input.localeIso2] || {} ).name; } return input; } classificationsIsNone(classifications) { if (classifications.length === 1) { return false; } if (classifications.length === 0 || classifications[0].score === 0) { return true; } return classifications[0].score === classifications[1].score; } checkIfIsNone(srcInput) { const input = srcInput; if (this.classificationsIsNone(input.classifications)) { input.intent = 'None'; input.score = 1; } return input; } async innerClassify(srcInput) { const input = srcInput; const domain = this.domainManagers[input.localeIso2]; if (!domain) { input.classifications = []; input.domain = undefined; input.intent = undefined; input.score = undefined; return input; } const classifications = await domain.process(srcInput); input.classifications = classifications.classifications.sort( (a, b) => b.score - a.score ); if (input.classifications.length === 0) { input.classifications.push({ intent: 'None', score: 1 }); } input.intent = input.classifications[0].intent; input.score = input.classifications[0].score; if (input.intent === 'None') { classifications.domain = 'default'; } else if (classifications.domain === 'default') { input.domain = this.getIntentDomain(input.locale, input.intent); } else { input.domain = classifications.domain; } return input; } async defaultPipelineProcess(input) { let output = await this.fillLanguage(input); output = await this.innerClassify(output); output = await this.checkIfIsNone(output); delete output.settings; delete output.classification; return output; } process(locale, utterance, domain, settings) { const input = typeof locale === 'object' ? locale : { locale: utterance === undefined ? undefined : locale, utterance: utterance === undefined ? locale : utterance, domain, settings: settings || this.settings, }; if (this.pipelineProcess) { return this.runPipeline(input, this.pipelineProcess); } return this.defaultPipelineProcess(input); } toJSON() { const result = { settings: this.settings, locales: this.locales, languageNames: this.languageNames, domainManagers: {}, intentDomains: this.intentDomains, extraSentences: this.guesser.extraSentences.slice(0), }; delete result.settings.container; const keys = Object.keys(this.domainManagers); for (let i = 0; i < keys.length; i += 1) { const key = keys[i]; result.domainManagers[key] = this.domainManagers[key].toJSON(); } return result; } fromJSON(json) { this.applySettings(this.settings, json.settings); for (let i = 0; i < json.locales.length; i += 1) { this.addLanguage(json.locales[i]); } this.languageNames = json.languageNames; this.intentDomains = json.intentDomains; const keys = Object.keys(json.domainManagers); for (let i = 0; i < keys.length; i += 1) { const key = keys[i]; this.domainManagers[key].fromJSON(json.domainManagers[key]); } for (let i = 0; i < json.extraSentences.length; i += 1) { const sentence = json.extraSentences[i]; this.guesser.addExtraSentence(sentence[0], sentence[1]); } } } module.exports = NluManager;