clustex
Version:
Clustex is a lightweight text classification package designed to efficiently categorize text based on similarity metrics and learned token weights.
112 lines (106 loc) • 4.62 kB
JavaScript
import getDataset from "./config.js";
/**
 * Find the maximal common substrings of `a` and `b`.
 *
 * Scans every substring of the longer string (longest-first per start
 * position) and keeps those that occur in the shorter string, skipping any
 * substring already contained in a previously kept one, so only "maximal"
 * shared parts survive. Quadratic-ish; intended for short token strings.
 *
 * @param {string} a
 * @param {string} b
 * @returns {string[]} maximal common substrings
 */
function common(a, b) {
  // Fix: the result array no longer shadows the function's own name, and the
  // inner loop starts at `longest.length` instead of `longest.length + 1` —
  // slice() clamped the extra index, so the first iteration was a duplicate
  // of the second (wasted work with identical output).
  const parts = [];
  const shortest = a.length < b.length ? a : b;
  const longest = shortest === a ? b : a;
  for (let i = 0; i < longest.length; i++) {
    for (let j = longest.length; j > i; j--) {
      const part = longest.slice(i, j);
      if (parts.every(el => !el.includes(part)) && shortest.includes(part)) {
        parts.push(part);
      }
    }
  }
  return parts;
}
/**
 * Similarity score in [0, 1] between two strings, based on the total length
 * of their common substrings relative to the mean string length, with a
 * penalty for fragmentation: many small shared parts score lower than one
 * large shared part of the same total length.
 *
 * @param {string} a
 * @param {string} b
 * @returns {number} similarity in [0, 1]
 */
function similarity(a, b) {
  const sharedParts = common(a, b);
  const averageLength = (a.length + b.length) / 2;
  let sharedLength = 0;
  for (const part of sharedParts) {
    sharedLength += part.length;
  }
  // Two empty strings are treated as identical.
  if (averageLength === 0) {
    return 1;
  }
  // Nothing in common at all.
  if (sharedParts.length === 0) {
    return 0;
  }
  const fragmentationPenalty = (1 + 1 / averageLength) ** (sharedParts.length - 1);
  return sharedLength / averageLength / fragmentationPenalty;
}
/**
 * Numerically stable softmax: maps an array of scores to probabilities that
 * sum to 1 (empty input yields an empty array).
 *
 * Fix: the original exponentiated raw scores, so large inputs (≳710)
 * overflowed Math.exp to Infinity and every output became NaN. Subtracting
 * the maximum score first is mathematically equivalent (the shift cancels in
 * the ratio) and keeps exp() in range.
 *
 * @param {number[]} array - raw scores
 * @returns {number[]} probabilities
 */
function softmax(array) {
  if (array.length === 0) {
    return [];
  }
  const max = Math.max(...array);
  const exps = array.map(el => Math.exp(el - max));
  const sum = exps.reduce((acc, el) => acc + el, 0);
  return exps.map(el => el / sum);
}
// Tokenizer: each match is either a run of word characters (ASCII \w plus the
// Cyrillic ranges а-я/А-Я and ёЁ) or a single non-space, non-word character,
// so punctuation marks come out as individual tokens.
const regex = /[\wа-яА-ЯёЁ]+|[^\s\wа-яА-ЯёЁ]/g;
class Classifier {
  /**
   * Similarity-based text classifier with per-token learned weights.
   *
   * @param {string[]} classifications - Class labels; one weight per label is
   *   kept for every learned token.
   * @param {number} learningRate - Step size used when reinforcing weights.
   * @param {number} threshold - Minimum token similarity (0..1) that counts
   *   as a match.
   */
  constructor(classifications = [], learningRate = 1, threshold = 0.8) {
    // Learned vocabulary: [{token, weights: number[]}], weights aligned with
    // `classifications` by index.
    this.tokens = [];
    this.classifications = classifications;
    this.learningRate = learningRate;
    this.threshold = threshold;
  }
  /**
   * Score `text` against every classification.
   * @returns {{chance: Object<string, number>, sums: number[]}} softmaxed
   *   probability per label, plus the raw per-label weighted sums.
   */
  run(text) {
    const tokens = text.match(regex) || [];
    const sums = new Array(this.classifications.length).fill(0);
    // Fix: compute each (learned token, input token) similarity once and
    // apply it to every class. The original recomputed the expensive
    // similarity() once per classification. Per-class accumulation order is
    // unchanged, so the resulting sums are bit-identical.
    this.tokens.forEach(learned => {
      tokens.forEach(token => {
        const score = similarity(token, learned.token);
        if (score >= this.threshold) {
          for (let i = 0; i < sums.length; i++) {
            sums[i] += learned.weights[i] * score;
          }
        }
      });
    });
    const softmaxed = softmax(sums);
    const chance = {};
    this.classifications.forEach((classification, i) => {
      // Guard against a degenerate softmax value; fall back to certainty.
      chance[classification] = Number.isNaN(softmaxed[i]) ? 1 : softmaxed[i];
    });
    return { chance, sums };
  }
  /** @returns {string} the most probable classification for `text`. */
  classify(text) {
    const { chance } = this.run(text);
    return Object.keys(chance).sort((a, b) => chance[b] - chance[a])[0];
  }
  /** @returns {Object<string, number>} probability per classification. */
  chance(text) {
    const { chance } = this.run(text);
    return chance;
  }
  /**
   * Learn from one labelled example: reinforce weights of learned tokens
   * matching the example's tokens, and add unseen tokens to the vocabulary.
   *
   * @param {string} text - example text
   * @param {string} classification - the correct label for `text`
   */
  example(text, classification = this.classifications[0]) {
    const tokens = text.match(regex) || [];
    const { sums } = this.run(text);
    const target = this.classifications.indexOf(classification);
    // Mean score of the *other* classes; the error term boosts learning when
    // the model currently prefers a wrong class.
    // Fix: with exactly one classification the original divided 0 by 0,
    // producing a NaN error that poisoned every weight; treat that case as
    // no error instead.
    const otherMiddle = sums.length > 1
      ? sums.filter((el, i) => this.classifications[i] !== classification).reduce((acc, el) => acc + el, 0) / (sums.length - 1)
      : 0;
    const error = 2 * Math.max(otherMiddle - sums[target], 0);
    tokens.forEach(token => {
      this.tokens = this.tokens.map(learned => {
        const score = similarity(learned.token, token);
        if (score < this.threshold) {
          return learned;
        }
        // Reinforce only the weight belonging to the correct class.
        return {
          ...learned,
          weights: learned.weights.map((weight, i) =>
            weight + (classification === this.classifications[i] ? this.learningRate * (score + error) : 0)),
        };
      });
    });
    // Tokens with no sufficiently-similar vocabulary entry become new entries.
    tokens.forEach(token => {
      if (!this.tokens.some(el => similarity(el.token, token) >= this.threshold)) {
        this.tokens.push({
          token,
          weights: this.classifications.map(label =>
            label === classification ? this.learningRate * (1 + error) : 0),
        });
      }
    });
  }
  /**
   * Train on a bundled dataset: `iterations` passes per example in original
   * order, then again in reverse order, so neither end of the data dominates
   * the final weights.
   *
   * @param {string} name - one of `Classifier.datasets`
   * @param {number} iterations - repetitions per example per pass
   */
  dataset(name = "news", iterations = 1) {
    const dataset = getDataset(name);
    this.classifications = [...dataset.classifications, ...this.classifications];
    // Fix: the two training passes were duplicated inline; share one helper.
    const train = (examples) => {
      for (const example of examples) {
        // Strip sentence punctuation so tokens match regardless of it.
        const text = example.text.toLowerCase().replaceAll('!', '').replaceAll('?', '').replaceAll('.', '');
        for (let j = 0; j < iterations; j++) {
          this.example(text, example.classification);
        }
      }
    };
    train(dataset.data);
    // Fix: reverse a copy instead of mutating the dataset object in place.
    train([...dataset.data].reverse());
  }
  /** Dataset names accepted by `dataset()`. */
  static datasets = ["news", "spam", "tone", "importance"];
}
export default Classifier;