UNPKG

clustex

Version:

Clustex is a lightweight text classification package designed to efficiently categorize text based on similarity metrics and learned token weights.

112 lines (106 loc) 4.62 kB
import getDataset from "./config.js";

/**
 * Greedily collects substrings of the longer string that also occur in the
 * shorter one. For each start position of the longer string, the longest
 * candidate that occurs in the shorter string and is not already contained in
 * a previously collected part is recorded.
 *
 * When both strings have the same length, `a` is the one scanned, so the
 * result is not symmetric in its arguments.
 *
 * @param {string} a
 * @param {string} b
 * @returns {string[]} Collected shared substrings, in discovery order.
 */
function common(a, b) {
  const parts = [];
  const shortest = a.length < b.length ? a : b;
  const longest = shortest === a ? b : a;
  for (let i = 0; i < longest.length; i++) {
    // Try the longest candidate starting at i first.
    for (let j = longest.length; j > i; j--) {
      const part = longest.slice(i, j);
      if (!parts.some(el => el.includes(part)) && shortest.includes(part)) {
        parts.push(part);
        // Every shorter candidate at this start is a prefix of `part` and
        // would be rejected by the containment check, so stop here.
        break;
      }
    }
  }
  return parts;
}

/**
 * Similarity score between two strings based on their shared substrings
 * (see `common`). The result is the total shared length relative to the mean
 * length of the inputs, damped by a factor of (1 + 1 / meanLength) for every
 * extra fragment, so one long shared run beats many short ones.
 *
 * @param {string} a
 * @param {string} b
 * @returns {number} 1 when both strings are empty, 0 when nothing is shared.
 */
function similarity(a, b) {
  const middleLength = (a.length + b.length) / 2;
  if (middleLength === 0) {
    return 1;
  }
  const commons = common(a, b);
  if (commons.length === 0) {
    return 0;
  }
  const commonLength = commons.reduce((acc, el) => acc + el.length, 0);
  return commonLength / middleLength / (1 + 1 / middleLength) ** (commons.length - 1);
}

/**
 * Numerically stable softmax. Scores are shifted by their maximum before
 * exponentiating. Without the shift, scores above ~709 overflow Math.exp to
 * Infinity and every probability becomes Infinity / Infinity = NaN.
 *
 * @param {number[]} array Raw scores.
 * @returns {number[]} Probabilities summing to 1 (empty for empty input).
 */
function softmax(array) {
  const max = array.reduce((acc, el) => Math.max(acc, el), -Infinity);
  const exps = array.map(el => Math.exp(el - max));
  const sum = exps.reduce((acc, el) => acc + el, 0);
  return exps.map(el => el / sum);
}

// Tokenizer: a run of word characters / Latin or Cyrillic letters (including
// ё/Ё), or a single character that is neither whitespace nor such a letter.
const regex = /[\wа-яА-ЯёЁ]+|[^\s\wа-яА-ЯёЁ]/g;

/**
 * Normalizes a dataset sample before training: lower-cases it and strips
 * '!', '?' and '.' characters.
 *
 * @param {string} text
 * @returns {string}
 */
function normalizeSample(text) {
  return text.toLowerCase().replaceAll('!', '').replaceAll('?', '').replaceAll('.', '');
}

/**
 * Text classifier that learns per-token weights for each classification and
 * scores new text by fuzzy-matching its tokens against the learned ones.
 */
class Classifier {
  /**
   * @param {string[]} [classifications=[]] Class labels. Their positions index
   *   each learned token's `weights` array.
   * @param {number} [learningRate=1] Scale applied to weight updates in `example`.
   * @param {number} [threshold=0.8] Minimum `similarity` for a learned token
   *   to match a text token.
   */
  constructor(classifications = [], learningRate = 1, threshold = 0.8) {
    // Learned vocabulary: { token: string, weights: number[] }, with
    // weights[i] belonging to this.classifications[i].
    this.tokens = [];
    this.classifications = classifications;
    this.learningRate = learningRate;
    this.threshold = threshold;
  }

  /**
   * Scores `text` against every classification.
   *
   * @param {string} text
   * @returns {{chance: Object<string, number>, sums: number[]}} `sums[i]` is
   *   the raw score of classification i. `chance` maps each classification
   *   to its softmax probability, or to 1 if that probability is NaN.
   */
  run(text) {
    const tokens = text.match(regex) || [];
    const sums = new Array(this.classifications.length).fill(0);
    // similarity() does not depend on the class, so compute it once per
    // (learned token, text token) pair and add it to every class. The
    // accumulation order per class matches the per-class loop it replaces.
    for (const learned of this.tokens) {
      for (const token of tokens) {
        const score = similarity(token, learned.token);
        if (score >= this.threshold) {
          for (let i = 0; i < sums.length; i++) {
            sums[i] += learned.weights[i] * score;
          }
        }
      }
    }
    const softmaxed = softmax(sums);
    const chance = {};
    this.classifications.forEach((classification, i) => {
      chance[classification] = Number.isNaN(softmaxed[i]) ? 1 : softmaxed[i];
    });
    return { chance, sums };
  }

  /**
   * @param {string} text
   * @returns {string|undefined} The classification with the highest chance
   *   (first in key order on ties), or undefined when there are none.
   */
  classify(text) {
    const { chance } = this.run(text);
    return Object.keys(chance).sort((a, b) => chance[b] - chance[a])[0];
  }

  /**
   * @param {string} text
   * @returns {Object<string, number>} Per-classification probabilities.
   */
  chance(text) {
    const { chance } = this.run(text);
    return chance;
  }

  /**
   * Performs one training step on `text` labelled `classification`.
   * Learned tokens similar to any text token are reinforced for that class.
   * Text tokens that match nothing are added to the vocabulary. The update is
   * boosted by `error`, which is twice the amount by which the mean score of
   * the other classes exceeds the target class's score.
   *
   * @param {string} text
   * @param {string} [classification=this.classifications[0]]
   */
  example(text, classification = this.classifications[0]) {
    const tokens = text.match(regex) || [];
    const { sums } = this.run(text);
    const otherCount = sums.length - 1;
    const otherTotal = sums
      .filter((el, i) => this.classifications[i] !== classification)
      .reduce((acc, el) => acc + el, 0);
    // With a single class there is no competitor. Dividing by zero here would
    // make `error` NaN and permanently poison every weight it touches.
    const otherMiddle = otherCount > 0 ? otherTotal / otherCount : 0;
    const error = 2 * Math.max(otherMiddle - sums[this.classifications.indexOf(classification)], 0);

    // Reinforce known tokens that resemble a token of the text.
    tokens.forEach(token => {
      this.tokens = this.tokens.map(learned => {
        const score = similarity(learned.token, token);
        if (score >= this.threshold) {
          return {
            ...learned,
            weights: learned.weights.map((weight, i) =>
              weight + (classification === this.classifications[i] ? this.learningRate * (score + error) : 0)),
          };
        }
        return learned;
      });
    });

    // Add text tokens that resemble nothing known yet (tokens added earlier
    // in this loop count as known).
    tokens.forEach(token => {
      if (!this.tokens.some(el => similarity(el.token, token) >= this.threshold)) {
        this.tokens.push({
          token,
          weights: this.classifications.map(name => (name === classification ? this.learningRate * (1 + error) : 0)),
        });
      }
    });
  }

  /**
   * Trains on a bundled dataset loaded via `getDataset(name)`. The dataset's
   * classifications are prepended to this classifier's. Samples are trained
   * in order and then in reverse order, each `iterations` times in a row.
   *
   * NOTE(review): calling this twice, or with classes that already exist,
   * duplicates labels in `this.classifications`.
   *
   * @param {string} [name="news"] One of `Classifier.datasets`.
   * @param {number} [iterations=1] Training repetitions per sample per pass.
   */
  dataset(name = "news", iterations = 1) {
    const dataset = getDataset(name);
    const added = dataset.classifications.length;
    this.classifications = [...dataset.classifications, ...this.classifications];
    // Weights are indexed by class position, and the new classes were
    // prepended. Shift existing weight vectors to match; otherwise old
    // weights would be read as the new classes and the old classes would
    // read `undefined` (NaN sums).
    if (added > 0) {
      this.tokens = this.tokens.map(token => ({
        ...token,
        weights: [...new Array(added).fill(0), ...token.weights],
      }));
    }
    // The reverse pass uses a copy so the object returned by getDataset is
    // left untouched (it may be shared between calls — NOTE(review): confirm
    // against config.js).
    const passes = [dataset.data, [...dataset.data].reverse()];
    for (const samples of passes) {
      for (const sample of samples) {
        const text = normalizeSample(sample.text);
        for (let j = 0; j < iterations; j++) {
          this.example(text, sample.classification);
        }
      }
    }
  }

  static datasets = ["news", "spam", "tone", "importance"];
}

export default Classifier;