wingbot
Version:
Enterprise Messaging Bot Conversation Engine
746 lines (616 loc) • 24.1 kB
JavaScript
/**
* @author David Menger
*/
;
const { replaceDiacritics } = require('../utils');
const { iterateThroughWords } = require('../utils/ai');
/**
* @typedef {object} DetectedEntity
* @prop {number} [start]
* @prop {string} [entity]
* @prop {number} [end]
* @prop {number} [score]
* @prop {string|number|boolean} [value]
* @prop {string} [text]
*/
/**
* @callback EntityDetector
* @param {string} text - part of text
* @param {DetectedEntity[]} entities - dependent entities
* @param {boolean} [searchWithinWords] - optional ability to search within words
* @returns {DetectedEntity[]|DetectedEntity|Promise<DetectedEntity>|Promise<DetectedEntity[]>}
*/
/**
* @callback ValueExtractor
* @param {string[]} match - regexp result
* @param {DetectedEntity[]} entities - dependent entities
* @returns {*}
*/
/**
* @typedef {object} Entity
* @prop {string} entity
* @prop {string} value
* @prop {number} score
*/
/**
* @typedef {object} Intent
* @prop {string} intent
* @prop {number} score
* @prop {Entity[]} [entities]
*/
/**
* @typedef {object} Result
* @prop {string} [text]
* @prop {Entity[]} entities
* @prop {Intent[]} intents
*/
/**
* @typedef {object} DetectorOptions
* @prop {boolean} [anonymize] - if true, value will not be sent to NLP
* @prop {Function|string} [extractValue] - entity extractor
* @prop {boolean} [matchWholeWords] - match whole words at regular expression
* @prop {boolean} [replaceDiacritics] - replace diacritics when matching regexp
* @prop {boolean} [caseSensitiveRegex] - make regex case sensitive
* @prop {string[]} [dependencies] - array of dependent entities
* @prop {boolean} [clearOverlaps] - let longer entities from NLP replace this entity
*/
/**
* @callback WordEntityDetector
* @param {string} text
* @param {DetectedEntity[]} [entities]
* @param {number} [startIndex]
* @param {string} [prefix]
* @returns {DetectedEntity[]}
*/
/**
* @typedef {object} Phrases
* @prop {Map<string,string[]>} phrases
*/
/** @typedef {import('../Request')} Request */
/**
 * Wraps `content` into a regexp group, taking the parentheses captured
 * around an `@ENTITY` placeholder into account.
 *
 * When both or neither of the parentheses were present, a single group is
 * produced. When only one side was present, that unmatched parenthesis is
 * kept next to the new group, so the surrounding pattern stays balanced.
 *
 * @param {string} [l] - captured opening parenthesis (if any)
 * @param {string} [r] - captured closing parenthesis (if any)
 * @param {string} content - group content
 * @returns {string}
 */
function optionalWrap (l, r, content) {
    const group = `(${content})`;
    const hasLeft = Boolean(l);
    const hasRight = Boolean(r);
    if (hasLeft === hasRight) {
        return group;
    }
    return [l || '', group, r || ''].join('');
}
// Matches an optional part of a RegExp source, i.e. something followed by `?`:
// either a parenthesized group without nested parentheses (whose opening and
// closing parentheses are not escaped), or an `@ENTITY` placeholder.
// _extractRegExpDependencies strips these to find the required dependencies.
const MULTI_ENTITY_CLEANER = /((?<!\\)\([^()]*[^()\\]\)|@[A-Z0-9-]+)\?/g;
/**
 * Detects entities in a text using locally registered detectors
 * (functions or regular expressions), optionally depending on each other.
 */
class CustomEntityDetectionModel {

    /**
     * @param {object} options
     * @param {string} [options.prefix] - passed to the word entity detector
     * @param {boolean} [options.verbose] - log diagnostics of overlap resolution
     * @param {{ warn: Function, error: Function, log: Function }} [log]
     */
    constructor (options, log = console) {
        this._options = options;
        this._log = log;
        this.prefix = options.prefix;
        this._entityDetectors = new Map();

        /**
         * @type {number}
         */
        this.phrasesCacheTime = 0;

        /**
         * Passed to iterateThroughWords() when the word entity detector is used
         *
         * @type {number}
         */
        this.maxWordCount = 0;

        /**
         * @type {WordEntityDetector}
         */
        this.wordEntityDetector = null;
    }

    /**
     * Validates raw detector output and converts it to entities with
     * positions within `originalText`.
     *
     * A hit is described either by `text` (searched case-insensitively
     * within `text`) or by `start`+`end` (relative to `text`).
     * Falsy items and empty hits are dropped. Score is clamped to 0..1 (default 1).
     *
     * @param {DetectedEntity[]} entities - raw detector results
     * @param {string} entity - entity name to assign
     * @param {string} text - the text the detector was called with
     * @param {number} offset - position of `text` within `originalText`
     * @param {string} originalText - the whole input text
     * @returns {DetectedEntity[]}
     * @throws {Error} when a hit has no usable hint or is out of bounds
     */
    _normalizeResult (entities, entity, text, offset, originalText) {
        return entities
            .map((e) => {
                if (!e) {
                    return null;
                }
                const score = typeof e.score !== 'number' ? 1 : Math.max(Math.min(e.score, 1), 0);

                if (typeof e.text !== 'string'
                    && (typeof e.start !== 'number' || typeof e.end !== 'number')) {
                    throw new Error(`Entity matcher '${entity}' should return 'text' or 'start'+'end' hint`);
                }

                if (typeof e.text === 'string') {
                    if (!e.text) {
                        return null;
                    }
                    // check the relative index BEFORE adding the offset,
                    // otherwise a miss (-1) would go unnoticed for offset > 0
                    const index = text.toLocaleLowerCase()
                        .indexOf(e.text.toLocaleLowerCase());

                    if (index === -1) {
                        throw new Error(`Entity matcher '${entity}' returned string, which cannot be found in query`);
                    }

                    const start = offset + index;
                    const end = start + e.text.length;

                    return {
                        ...e,
                        text: originalText.substring(start, end),
                        entity,
                        start,
                        end,
                        score
                    };
                }

                if (e.start < 0 || e.start >= text.length) {
                    throw new Error(`Entity matcher '${entity}' returned start out of bounds: ${e.start} (string length was ${text.length})`);
                }

                if (e.start === e.end) {
                    return null;
                }

                if (e.end < e.start || e.end > text.length) {
                    throw new Error(`Entity matcher '${entity}' returned end out of bounds: ${e.end} (start: ${e.start}, length: ${text.length})`);
                }

                return {
                    ...e,
                    entity,
                    text: text.substring(e.start, e.end),
                    start: offset + e.start,
                    end: offset + e.end,
                    score
                };
            })
            .filter((e) => e !== null);
    }

    /**
     * Detects an entity both normally and within words. Returns the regular
     * hits and appends sub-word hits (not covered by a regular hit) to `subWord`.
     *
     * @param {string} entity
     * @param {string} text
     * @param {DetectedEntity[]} entities
     * @param {DetectedEntity[]} subWord - gets mutated
     * @returns {Promise<DetectedEntity[]>}
     */
    async _detectAllEntities (entity, text, entities, subWord) {
        const [regularResults, subWordResults] = await Promise.all([
            this._detectEntities(entity, text, entities, subWord, false),
            this._detectEntities(entity, text, entities, subWord, true)
        ]);

        // drop sub-word hits lying completely within a regular hit
        // (hits are never empty, see _normalizeResult)
        const cleanSubWordResults = subWordResults
            .filter((sub) => !regularResults
                .some((e) => e.start <= sub.start && e.end >= sub.end));

        subWord.push(...cleanSubWordResults);

        return regularResults;
    }

    /**
     * Runs the entity detector over the text. A detector returning an array
     * is called once; a detector returning a single hit is called again on the
     * rest of the text after each hit, until it returns nothing.
     *
     * Dependent entities are passed to the detector with positions relative
     * to the (remaining) text it receives; already consumed ones are skipped.
     *
     * Errors are logged and the hits collected so far are returned.
     *
     * @param {string} entity
     * @param {string} text
     * @param {DetectedEntity[]} entities
     * @param {DetectedEntity[]} subWord
     * @param {boolean} detectSubWords
     * @returns {Promise<DetectedEntity[]>}
     */
    async _detectEntities (entity, text, entities, subWord, detectSubWords) {
        const { entityDetector, dependencies } = this._entityDetectors.get(entity);

        // detectors not accepting the third argument can't search within words
        if (detectSubWords && entityDetector.length < 3) {
            return [];
        }

        const collected = [];
        let offset = 0;
        let rest = text;

        try {
            const isDependency = (e) => dependencies.includes(`@${e.entity.toUpperCase()}`);
            // dependencies are resolved in earlier rounds, so these lists
            // don't gain any matching entity while this detection runs
            const dependentEntities = [
                ...subWord.filter(isDependency),
                ...entities.filter(isDependency)
            ];

            // every single hit consumes at least one character,
            // so text.length + 1 rounds are always enough
            for (let i = 0; i <= text.length; i++) {
                if (!rest) {
                    return collected;
                }

                const relativeEntities = offset === 0
                    ? dependentEntities
                    : dependentEntities
                        .filter((e) => typeof e.start !== 'number' || e.start >= offset)
                        .map((e) => (typeof e.start !== 'number'
                            ? e
                            : { ...e, start: e.start - offset, end: e.end - offset }));

                const res = await Promise.resolve(
                    entityDetector(rest, relativeEntities, detectSubWords)
                );

                const resWasArray = Array.isArray(res);
                const normalized = this._normalizeResult(
                    resWasArray ? res : [res],
                    entity,
                    rest,
                    offset,
                    text
                );

                // an array (or no hit) is the final answer
                if (resWasArray || normalized.length === 0) {
                    return [...collected, ...normalized];
                }

                const [hit] = normalized;
                collected.push(hit);

                // continue right after the hit, skipping whitespace
                const remaining = text.substring(hit.end);
                const [whitespaces] = remaining.match(/^\s*/);
                rest = remaining.substring(whitespaces.length);
                offset = hit.end + whitespaces.length;
            }
            this._log.error(`Entity '${entity}' detection reached iteration limit`);
            return collected;
        } catch (e) {
            this._log.error(`Entity '${entity}' detection failed`, e);
            return collected;
        }
    }

    /**
     * Escapes regexp special characters in a string.
     *
     * @param {string} string
     * @param {boolean} [shouldReplaceDiacritics] - also replace diacritics
     * @returns {string}
     */
    _escapeRegex (string, shouldReplaceDiacritics) {
        const ret = string.replace(/[-/\\^$*+?.()|[\]{}]/g, '\\$&');

        if (!shouldReplaceDiacritics) {
            return ret;
        }

        return replaceDiacritics(ret);
    }

    /**
     * Return only entities without overlap
     *
     * Longer entities win; expected entities may replace overlapping
     * unexpected ones (which are put back when possible).
     * The input array is not modified.
     *
     * @param {DetectedEntity[]} entities
     * @param {string[]} [expectedEntities]
     * @param {boolean} [justDuplicates] - resolve only entities with the same span
     * @returns {DetectedEntity[]} - sorted by start
     */
    nonOverlapping (entities, expectedEntities = [], justDuplicates = false) {
        // longest first, equally long ones by position
        const sorted = [...entities].sort((a, b) => {
            const aLen = a.end - a.start;
            const bLen = b.end - b.start;

            if (aLen === bLen) {
                return a.start - b.start;
            }

            return bLen - aLen;
        });

        if (this._options.verbose) this._log.log('#NLP [nonOverlapping]', { entities: sorted, expectedEntities, justDuplicates });

        let res = [];

        for (let i = 0; i < sorted.length; i++) {
            const entity = sorted[i];
            const isExpected = expectedEntities.includes(entity.entity);

            const duplicate = res
                .find((e) => e.start === entity.start && e.end === entity.end);

            let overlapping = justDuplicates
                ? !!duplicate
                : res.some((e) => e.start < entity.end && e.end > entity.start);

            if (overlapping) {
                if (duplicate) {
                    overlapping = !isExpected && expectedEntities.includes(duplicate.entity);
                }

                if (duplicate && duplicate.entity === entity.entity) {
                    overlapping = true;
                } else if (isExpected) {
                    overlapping = false;

                    // keep only expected entities overlapping this one
                    res = res.filter((e) => {
                        const isOverlapping = e.start < entity.end && e.end > entity.start;
                        if (!isOverlapping) {
                            return true;
                        }
                        const expectedEntity = expectedEntities.includes(e.entity);
                        if (expectedEntity && !duplicate) {
                            overlapping = true;
                        }
                        return expectedEntity;
                    });

                    // try to put back previously ignored entities
                    for (let k = 0; k < i; k++) {
                        const putback = sorted[k];

                        const currentConflict = putback.start < entity.end
                            && putback.end > entity.start;

                        const othersConflict = res.some((e) => putback === e
                            || (e.start < putback.end && e.end > putback.start));

                        if (this._options.verbose) {
                            this._log.log(`#NLP (${i}|${k} [putBack: ${entity.entity}|${putback.entity}] (${i}|${k})`, {
                                putback, entity, currentConflict, othersConflict
                            });
                        }

                        if (!currentConflict && !othersConflict) {
                            res.push(putback);
                        }
                    }
                }
            }

            if (this._options.verbose) {
                this._log.log(`#NLP (${i}) [nonOverlapping| ${entity.entity}:${entity.value}]`, {
                    willRemoveEntity: overlapping, overlapping, duplicate, isExpected, entity
                });
            }

            if (!overlapping) {
                res.push(entity);
            }
        }

        res.sort(({ start: a }, { start: z }) => a - z);

        return res;
    }

    /**
     * Lists dependencies of all registered detectors.
     *
     * @param {boolean} [known] - true: only registered entities (original name),
     *  false: only unregistered ones (lowercased, without '@'), null: both
     * @returns {string[]} -
     */
    getDependentEntities (known = null) {
        const entities = new Set();

        const knownEntities = new Map(
            Array.from(this._entityDetectors.keys())
                .map((key) => [key.toLowerCase(), key])
        );

        for (const { dependencies } of this._entityDetectors.values()) {
            for (const entity of dependencies) {
                const lowerCase = entity.replace(/^@/, '').toLowerCase();
                const isKnown = knownEntities.has(lowerCase);

                if (isKnown && known !== false) {
                    entities.add(knownEntities.get(lowerCase));
                } else if (!isKnown && known !== true) {
                    entities.add(lowerCase);
                }
            }
        }

        return Array.from(entities.values());
    }

    /**
     * Detects entities in rounds, so each detector runs after its dependencies.
     * Dependencies not registered here are treated as already resolved.
     *
     * @param {string} text
     * @param {string} [singleEntity] - resolve just this entity (and its dependencies)
     * @param {string[]} [expected]
     * @param {DetectedEntity[]} [prevEnts] - previously detected entities to include
     * @param {DetectedEntity[]} [subWord] - previously detected entities within words
     * @returns {Promise<DetectedEntity[]>}
     */
    async resolveEntities (text, singleEntity = null, expected = [], prevEnts = [], subWord = []) {
        let entities = prevEnts.slice();

        if (this.wordEntityDetector) {
            for (const [s, startIndex] of iterateThroughWords(text, this.maxWordCount)) {
                const ents = this.wordEntityDetector(s, prevEnts, startIndex, this.prefix);

                const byEntity = new Map();
                for (const entity of ents) {
                    let list;
                    if (byEntity.has(entity.entity)) {
                        list = byEntity.get(entity.entity);
                    } else {
                        list = [];
                        byEntity.set(entity.entity, list);
                    }
                    list.push({
                        text: s,
                        ...entity
                    });
                }

                const normalized = Array.from(byEntity.entries())
                    .flatMap(([e, list]) => this._normalizeResult(list, e, s, startIndex, text));

                entities.push(...normalized);
            }
        }

        // mark unknown dependencies as resolved
        const resolved = new Set(
            this.getDependentEntities(false)
                .map((entity) => `@${entity.toUpperCase()}`)
        );

        let missing = Array.from(this._entityDetectors.keys());

        entities = entities.map((e) => {
            if (typeof e.text === 'string') {
                return e;
            }
            return {
                ...e,
                text: text.substring(e.start, e.end)
            };
        });

        while (missing.length !== 0) {
            let detect = [];

            // pick detectors with all dependencies resolved
            missing = missing.filter((e) => {
                const { dependencies } = this._entityDetectors.get(e);
                if (dependencies.every((d) => resolved.has(d))) {
                    detect.push(e);
                    return false;
                }
                return true;
            });

            if (detect.length === 0 && missing.length !== 0) {
                this._log.warn(`Ignoring entities because of dependency cycle: ${missing.join(', ')}`);
                break;
            }

            if (singleEntity && detect.includes(singleEntity)) {
                detect = [singleEntity];
                missing = [];
            }

            const results = await Promise.all(
                detect.map((entity) => this._detectAllEntities(entity, text, entities, subWord))
            );

            detect.forEach((entity) => resolved.add(`@${entity.toUpperCase()}`));
            results.forEach((res) => entities.push(...res));
        }

        const clean = this.nonOverlapping(entities, expected);

        if (!singleEntity) {
            return clean;
        }

        const entity = clean.find((e) => e.entity === singleEntity);

        if (!entity) {
            return [];
        }

        return [entity];
    }

    /**
     * Resolves a value of a single entity within the text.
     * Returns the text itself, when there is no such detector.
     *
     * @param {string} entity
     * @param {string} text
     * @returns {Promise<*>} - value, or null when not found
     */
    async resolveEntityValue (entity, text) {
        if (!this._entityDetectors.has(entity)) {
            return text;
        }
        const [res = null] = await this.resolveEntities(text, entity);
        return res ? res.value : null;
    }

    /**
     * Detects entities and returns a lowercased text, where anonymized
     * entities are replaced with `@ENTITY` placeholders.
     *
     * @param {string} text - the user input
     * @param {Request} [req]
     * @returns {Promise<Result>}
     */
    async resolve (text, req) {
        let cleanText = text
            .replace(/[\r\n]+/g, ' ')
            .trim();

        const expectedEntities = req ? req.expectedEntities() : [];

        const entities = await this.resolveEntities(cleanText, null, expectedEntities);

        cleanText = cleanText.toLocaleLowerCase();

        // entities are sorted by start - replace from the end,
        // so positions of the earlier ones stay valid
        for (let i = entities.length - 1; i >= 0; i--) {
            const entity = entities[i];
            if (!this._entityDetectors.has(entity.entity)) continue;
            const { anonymize } = this._entityDetectors.get(entity.entity);
            if (anonymize) {
                const before = cleanText.substring(0, entity.start);
                const after = cleanText.substring(entity.end);

                cleanText = `${before}@${entity.entity.toUpperCase()}${after}`;
            }
        }

        return {
            text: cleanText,
            intents: [],
            // @ts-ignore
            entities
        };
    }

    /**
     * Lists `@ENTITY` placeholders of a regexp. Required ones (those not
     * inside an optional part) come first.
     *
     * @param {RegExp|string} regexp
     * @returns {string[]}
     */
    _extractRegExpDependencies (regexp) {
        let str = typeof regexp === 'string' ? regexp : regexp.source;
        const matches = str.match(/@[A-Z0-9-]+/g);
        const known = Array.from(new Set(matches));

        if (known.length <= 1 || !str.match(MULTI_ENTITY_CLEANER)) {
            return known;
        }

        str = str.replace(MULTI_ENTITY_CLEANER, '');

        const cleanDeps = this._extractRegExpDependencies(str);

        return Array.from(new Set([...cleanDeps, ...matches]));
    }

    /**
     * @param {DetectedEntity[]} entities
     * @param {string} dependency - in the '@ENTITY' form
     * @returns {DetectedEntity|null}
     */
    _entityByDependency (entities, dependency) {
        return entities.find((e) => `@${e.entity.toUpperCase()}` === dependency) || null;
    }

    /**
     * Converts a regexp (optionally with `@ENTITY` placeholders) into
     * an entity detector.
     *
     * @param {RegExp} regexp
     * @param {object} [options]
     * @param {Function|string} [options.extractValue] - entity extractor
     * @param {boolean} [options.matchWholeWords] - match whole words at regular expression
     * @param {boolean} [options.replaceDiacritics] - replace diacritics when matching regexp
     * @param {string[]} [options.dependencies] - array of dependent entities
     * @param {boolean} [options.caseSensitiveRegex] - make regex case sensitive
     * @returns {EntityDetector}
     */
    _regexpToDetector (regexp, options) {
        const { dependencies = [], extractValue = null } = options;
        const { source } = regexp;

        /**
         * @param {string} text
         * @param {DetectedEntity[]} entities
         * @param {boolean} searchWithinWords
         */
        return (text, entities, searchWithinWords) => {
            if (typeof extractValue === 'string'
                && !this._entityByDependency(entities, extractValue)) {
                return null;
            }

            // substitute @ENTITY placeholders with texts of detected entities
            let replaced = source.replace(/(\()?@([A-Z0-9-]+)(\))?/g, (value, l, ent, r) => {
                const matchingEntities = entities
                    .filter((e) => e.entity.toUpperCase() === ent)
                    .map((e) => this._escapeRegex(e.text, options.replaceDiacritics));

                if (matchingEntities.length === 0) {
                    return optionalWrap(l, r, `@${ent}`);
                }

                // longer alternatives are tried first
                matchingEntities.sort((a, z) => z.length - a.length);

                return optionalWrap(l, r, matchingEntities.join('|'));
            });

            if (options.matchWholeWords && !searchWithinWords) {
                // NOTE: the lookbehind group is capturing, so with matchWholeWords
                // the regexp's own groups start at match[2] for extractValue
                replaced = `(?<=(^|[^a-zA-Z0-9\u00C0-\u017F]))${replaced}(?=([^a-zA-Z0-9\u00C0-\u017F]|$))`;
            }

            const r = new RegExp(replaced, options.caseSensitiveRegex ? '' : 'i');

            const lc = options.caseSensitiveRegex
                ? text
                : text.toLocaleLowerCase();

            const matchText = options.replaceDiacritics
                ? replaceDiacritics(lc)
                : lc;

            const match = matchText.match(r);

            // an empty match is not an entity
            if (!match || match[0].length === 0) {
                return null;
            }

            const start = match.index;
            const end = start + match[0].length;

            // dependent entities within the match
            const useEntities = entities.filter((e) => e.start >= start && e.end <= end);

            let value;
            if (typeof extractValue === 'function') {
                value = extractValue(match, useEntities);
            } else if (typeof extractValue === 'string' || dependencies.length > 0) {
                const entityName = typeof extractValue === 'string'
                    ? extractValue
                    : dependencies[0];

                const entity = this._entityByDependency(useEntities, entityName);
                value = entity ? entity.value : null;
            } else {
                [value] = match;
            }

            // positions only: a `text` hint would be searched again and could
            // resolve to an earlier occurrence than the one actually matched
            return {
                start,
                end,
                value
            };
        };
    }

    /**
     * Registers an entity detector.
     *
     * @param {string} name
     * @param {EntityDetector|RegExp} detector
     * @param {DetectorOptions} [options]
     * @returns {this}
     * @throws {Error} when a string `extractValue` is not used in the RegExp
     */
    setEntityDetector (name, detector, options = {}) {
        let entityDetector = detector;
        let dependencies = [];

        if (detector instanceof RegExp) {
            dependencies = this._extractRegExpDependencies(detector);

            if (typeof options.extractValue === 'string' && !dependencies.includes(options.extractValue)) {
                throw new Error(`RegExp entity detector '${name}' uses ${options.extractValue} extractValue but it's missing in RegExp`);
            }

            entityDetector = this._regexpToDetector(detector, { ...options, dependencies });
        } else if (options.dependencies) {
            // normalize to the '@ENTITY' form used by RegExp dependencies
            dependencies = options.dependencies.map((d) => {
                const dependency = `${d}`.toUpperCase();
                return dependency.startsWith('@') ? dependency : `@${dependency}`;
            });
        }

        this._entityDetectors.set(name, {
            entityDetector,
            detector,
            dependencies,
            anonymize: !!options.anonymize,
            clearOverlaps: !!options.clearOverlaps
        });

        return this;
    }

    /**
     * Sets options to entity detector.
     * Useful for disabling anonymization of local system entities.
     *
     * @param {string} name
     * @param {object} options
     * @param {boolean} [options.anonymize]
     * @param {boolean} [options.clearOverlaps]
     * @returns {this}
     * @throws {Error} when the entity detector does not exist
     * @example
     *
     * ai.register('wingbot-model-name')
     *     .setDetectorOptions('phone', { anonymize: false })
     *     .setDetectorOptions('email', { anonymize: false })
     */
    setDetectorOptions (name, options) {
        if (!this._entityDetectors.has(name)) {
            throw new Error(`Can't set entity detector options. Entity "${name}" does not exist.`);
        }
        Object.assign(this._entityDetectors.get(name), options);
        return this;
    }

    /**
     * @returns {Promise<Phrases>}
     */
    async getPhrases () {
        return this._getPhrases();
    }

    /**
     * Returns no phrases here (presumably meant to be overridden - TODO confirm)
     *
     * @returns {Promise<Phrases>}
     */
    async _getPhrases () {
        return CustomEntityDetectionModel.getEmptyPhrasesObject();
    }

    /**
     * @returns {Phrases}
     */
    static getEmptyPhrasesObject () {
        return { phrases: new Map() };
    }

}
module.exports = CustomEntityDetectionModel;