UNPKG

lande

Version:

A tiny neural network for natural language detection.

256 lines (173 loc) 8.31 kB
/* IMPORT */

import _ from 'lodash';
import fs from 'node:fs';
import path from 'node:path';
import {NeuralNetwork, Tensor, Trainers} from 'toygrad';
import {DATASET_PATH, DATASET_TRAIN_LENGTH_MIN, DATASET_TRAIN_LIMIT, DATASET_TRAIN_PERC, CONFIGS} from './constants';
import {forEachLine, getNormalized, getNgrams, getTopKeys, padEnd} from './utils';
import type {DatasetRaw, DatumRaw, Dataset, Datum, Config, Result} from './types';

/* HELPERS */

// Returns the deduplicated list of every language referenced by any of the given configs.
const getConfigsLangs = ( configs: Config[] ): string[] => {

  return _.uniq ( configs.flatMap ( config => config.langs ) );

};

// Reads the dataset file at DATASET_PATH and collects, for each requested language, up to
// DATASET_TRAIN_LIMIT sentences together with their 1- to 4-gram tables.
// Sentences of at least DATASET_TRAIN_LENGTH_MIN characters are preferred; shorter ones are
// kept in a separate fallback bucket and only used to fill up languages that are short on data.
const getDatasetRaw = ( langs: string[] ): DatasetRaw => {

  const datasetRaw: DatasetRaw = {};
  const datasetFallbackRaw: DatasetRaw = {};
  const langsSet = new Set ( langs );
  const csv = fs.readFileSync ( DATASET_PATH );

  forEachLine ( csv, line => {

    // Each usable line has exactly three tab-separated fields: the language is the second, the sentence the third
    const parts = line.split ( '\t' );

    if ( parts.length !== 3 ) return; // Something went wrong with this line

    const lang = parts[1];
    const sentence = parts[2];

    if ( !langsSet.has ( lang ) ) return;

    const longNr = datasetRaw[lang]?.length || 0;
    const shortNr = datasetFallbackRaw[lang]?.length || 0;

    if ( longNr >= DATASET_TRAIN_LIMIT ) return; // Already parsed enough sentences

    // NOTE(review): this check runs before the length test below, so once long + short sentences
    // reach the limit, later long sentences for this language are skipped too -- confirm this is intended
    if ( shortNr >= ( DATASET_TRAIN_LIMIT - longNr ) ) return; // Already parsed enough fallback sentences

    const isLongEnough = ( sentence.length >= DATASET_TRAIN_LENGTH_MIN );
    const bucket = isLongEnough ? datasetRaw : datasetFallbackRaw;

    const sentenceNorm = getNormalized ( sentence );
    const unigrams = getNgrams ( sentenceNorm, 1 );
    const bigrams = getNgrams ( sentenceNorm, 2 );
    const trigrams = getNgrams ( sentenceNorm, 3 );
    const quadgrams = getNgrams ( sentenceNorm, 4 );
    const datumRaw: DatumRaw = { lang, sentence, unigrams, bigrams, trigrams, quadgrams };

    bucket[lang] ||= [];
    bucket[lang].push ( datumRaw );

  });

  // Top up each language with its longest short sentences, never exceeding DATASET_TRAIN_LIMIT in total
  langs.forEach ( lang => {

    const long = datasetRaw[lang] || [];
    const short = datasetFallbackRaw[lang] || [];
    const shortSorted = short.sort ( ( a, b ) => b.sentence.length - a.sentence.length );
    const fallbackNr = Math.max ( 0, ( DATASET_TRAIN_LIMIT - long.length ) );
    const fallback = shortSorted.slice ( 0, fallbackNr );

    datasetRaw[lang] = long.concat ( fallback );

  });

  return datasetRaw;

};

// Selects the n-gram values of the given type that become input features for a config.
// Counts are summed per language, then values are picked round-robin across the languages
// (one new value per language per round, skipping values already picked), so that every
// language contributes features. The result always has exactly `config.network[type]` entries:
// it is truncated when too many values were picked and padded with '' when too few exist.
const getDatasetRawTopNgrams = ( dataset: DatasetRaw, config: Config, type: 'unigrams' | 'bigrams' | 'trigrams' | 'quadgrams' ): string[] => {

  const ngrams: Record<string, Record<string, number>> = {};

  config.langs.forEach ( lang => {
    dataset[lang]?.forEach ( datum => {
      Object.values ( datum[type] ).forEach ( ngram => {
        ngrams[lang] ||= {};
        ngrams[lang][ngram.value] ||= 0;
        ngrams[lang][ngram.value] += ngram.count;
      });
    });
  });

  // One queue of candidate values per language -- presumably ordered by descending count, verify against getTopKeys
  const valuesByLangs = Object.values ( ngrams ).map ( getTopKeys );
  const values: string[] = [];
  const valuesSet = new Set<string> ();
  const valuesLimit = config.network[type];

  while ( values.length < valuesLimit ) {

    for ( const valuesByLang of valuesByLangs ) {

      while ( true ) {

        const value = valuesByLang.shift ();

        if ( !value ) break;

        if ( valuesSet.has ( value ) ) continue;

        values.push ( value );
        valuesSet.add ( value );

        break;

      }

    }

    // Without this the loop would spin forever when there are fewer unique values than the limit
    // (or no data at all): nothing is left to shift, so the padding below has to fill the gap
    if ( valuesByLangs.every ( valuesByLang => valuesByLang.length === 0 ) ) break;

  }

  // A round adds up to one value per language, so the limit may be overshot: trim, then pad
  const valuesLimited = padEnd ( values.slice ( 0, valuesLimit ), valuesLimit, '' );

  return valuesLimited;

};

// Turns the raw per-language data into train/test sets for one config.
// Each datum's input is a 1x1xN tensor holding, for every selected n-gram (unigrams, then
// bigrams, trigrams and quadgrams), its frequency in the sentence, or 0 when absent.
// The output is the index of the datum's language within `config.langs`.
// For each language the first DATASET_TRAIN_PERC share of the data is used for training, the rest for testing.
const getDataset = ( dataset: DatasetRaw, config: Config ): Dataset => {

  const unigrams = getDatasetRawTopNgrams ( dataset, config, 'unigrams' );
  const bigrams = getDatasetRawTopNgrams ( dataset, config, 'bigrams' );
  const trigrams = getDatasetRawTopNgrams ( dataset, config, 'trigrams' );
  const quadgrams = getDatasetRawTopNgrams ( dataset, config, 'quadgrams' );

  let train: Datum[] = [];
  let test: Datum[] = [];

  config.langs.forEach ( lang => {

    const data: Datum[] = [];

    dataset[lang]?.forEach ( datumRaw => {

      const inputUnigrams = unigrams.map ( value => datumRaw.unigrams[value]?.frequency || 0 );
      const inputBigrams = bigrams.map ( value => datumRaw.bigrams[value]?.frequency || 0 );
      const inputTrigrams = trigrams.map ( value => datumRaw.trigrams[value]?.frequency || 0 );
      const inputQuadgrams = quadgrams.map ( value => datumRaw.quadgrams[value]?.frequency || 0 );
      const inputNgrams = [...inputUnigrams, ...inputBigrams, ...inputTrigrams, ...inputQuadgrams];

      const input = new Tensor ( 1, 1, inputNgrams.length, new Float32Array ( inputNgrams ) );
      const output = config.langs.indexOf ( datumRaw.lang );
      const datum: Datum = { lang, sentence: datumRaw.sentence, input, output };

      data.push ( datum );

    });

    const trainLength = Math.floor ( data.length * DATASET_TRAIN_PERC );

    train = train.concat ( data.slice ( 0, trainLength ) );
    test = test.concat ( data.slice ( trainLength ) );

  });

  return { train, test };

};

/* MAIN */

// The raw dataset is parsed once, for the union of all languages, and shared by every config
const langs = getConfigsLangs ( CONFIGS );
const datasetRaw = getDatasetRaw ( langs );

for ( const config of CONFIGS ) {

  console.log ( `=== ${config.id} ===` );

  const dataset = getDataset ( datasetRaw, config );

  /* TRAINING */

  const nn = new NeuralNetwork ({
    layers: [
      { type: 'input', sx: 1, sy: 1, sz: config.network.unigrams + config.network.bigrams + config.network.trigrams + config.network.quadgrams },
      { type: 'dense', filters: config.network.hidden, bias: 0.1 },
      { type: 'relu' },
      { type: 'dense', filters: config.langs.length },
      { type: 'softmax' }
    ]
  });

  const trainer = new Trainers.Adadelta ( nn, {
    batchSize: config.network.batchSize
  });

  for ( let epoch = 0; epoch < config.network.epochs; epoch++ ) {

    // The training set is reshuffled at every epoch
    const batch = _.shuffle ( dataset.train );

    for ( let i = 0, l = batch.length; i < l; i++ ) {

      if ( i % 10000 === 0 ) console.log ( `Epoch ${epoch + 1}/${config.network.epochs} - ${i}/${batch.length}` );

      trainer.train ( batch[i].input, batch[i].output );

    }

  }

  /* SAVING */

  // Languages, selected n-grams and network options are written as ES modules under "<cwd>/standalone"

  const langs = config.langs;
  const langsPath = path.join ( process.cwd (), 'standalone', `${config.id}-langs.js` );
  const langsModule = `export default ${JSON.stringify ( langs )};`;

  const unigrams = getDatasetRawTopNgrams ( datasetRaw, config, 'unigrams' );
  const bigrams = getDatasetRawTopNgrams ( datasetRaw, config, 'bigrams' );
  const trigrams = getDatasetRawTopNgrams ( datasetRaw, config, 'trigrams' );
  const quadgrams = getDatasetRawTopNgrams ( datasetRaw, config, 'quadgrams' );
  const ngrams = { unigrams, bigrams, trigrams, quadgrams };
  const ngramsPath = path.join ( process.cwd (), 'standalone', `${config.id}-ngrams.js` );
  const ngramsModule = `export default ${JSON.stringify ( ngrams )};`;

  const nnPath = path.join ( process.cwd (), 'standalone', `${config.id}-options.js` );
  const nnOptions = nn.getAsOptions ( 'f8' );
  const nnModule = `export default ${JSON.stringify ( nnOptions )};`;

  fs.writeFileSync ( langsPath, langsModule );
  fs.writeFileSync ( ngramsPath, ngramsModule );
  fs.writeFileSync ( nnPath, nnModule );

  /* TESTING */

  // Presumably this standalone entry point loads the three modules written above -- verify against the standalone folder
  const {default: lande} = await import ( `../standalone/${config.id}.js` );

  let pass = 0;
  let fail = 0;
  let loss = 0;

  for ( let i = 0; i < dataset.test.length; i++ ) {

    const datum = dataset.test[i];
    const result: Result = lande ( datum.sentence );

    const expectedLang = datum.lang;
    const actualLang = result[0][0];

    // The loss is the average gap between 1 and the probability assigned to the expected language
    const expectedProbability = 1;
    const actualProbability = result.find ( entry => entry[0] === expectedLang )?.[1] || 0;

    loss += ( expectedProbability - actualProbability ) / dataset.test.length;

    if ( expectedLang === actualLang ) {

      pass += 1;

    } else {

      fail += 1;

    }

  }

  console.log ( `=== results ===` );
  console.log ( 'Pass:', pass );
  console.log ( 'Fail:', fail );
  console.log ( 'Loss:', loss );
  console.log ( 'Accuracy:', ( pass * 100 ) / ( pass + fail ) );
  // Weights and biases of the two dense layers: (inputs * hidden + hidden) + (hidden * langs + langs)
  console.log ( 'Weights:', ( ( config.network.unigrams + config.network.bigrams + config.network.trigrams + config.network.quadgrams ) * config.network.hidden ) + config.network.hidden + ( ( config.network.hidden * config.langs.length ) + config.langs.length ) );

}