cleanai
Version:
A fully standalone, terminal-based AI model CLI for training and inference
1,124 lines (1,030 loc) • 333 kB
JavaScript
#!/usr/bin/env node
// cleanai.js
//
//* This file reads best with the "Better Comments" extension for VS Code: it colours each comment
//* according to how important it is to read, based on the small marker (* ! or ?) that I put
//* right after // on the comments.
// Credits
//
// Me ofc (willmil11)
// Claude (3.5 sonnet, 3.7 sonnet)
// Chatgpt (4o, o1, o3, o3-mini-high, o4-mini-high)
// Gemini (2.5 flash, 2.5 pro. aistudio versions)
//
// Contact me
// Discord: willmil11
// Email: willmil111012@gmail.com
//
// Drop a star on the github repo if you like it :D
// Only used by "Remote" mode, which is not active by default, so it can be ignored.
var tresholdworkersworthit = 27;
// Worker mode: "SharedMemory" (default) or "Remote".
var mode = "SharedMemory"
// Worker submode: "Turtle" (default) or "Rabbit".
var submode = "Turtle"
// Remote mode is legacy: it was built before the author learned about SharedArrayBuffer. It has
// immense overhead and is very slow, although it works. It is inferior to the other modes and is
// kept only as a kind of museum piece, a display of the hours spent building it.
// Rabbit submode: a worker asks for data when it needs it, which adds some overhead but means
// workers are spawned near-instantly.
// Turtle submode: a worker is given all data on spawn, so there is less overhead but a much longer
// spawn (initialisation) time.
// Neither submode uses noticeably more memory than the other (a few kilobytes at most).
// Changing mode or submode is not recommended (recommended: SharedMemory/Turtle), although all
// modes and submodes should work.
// PIDs of spawned workers; killWorkers() signals each of them on shutdown.
var workerPids = [];
// Records appended by timer_() and completed by timer_end().
var timers = [];
var fs = require("fs");
var process = require("process");
// Set to true once killWorkers() has started, so the shutdown sequence runs only once.
var ranKill = false
/**
 * Copies a typed array into a new Float32Array backed by a SharedArrayBuffer.
 *
 * @param {Float32Array} arr Source values; its byteLength sizes the shared buffer.
 * @returns {Float32Array} A view over freshly allocated shared memory holding the same values.
 */
var toSharedFloat32 = function(arr) {
    var view = new Float32Array(new SharedArrayBuffer(arr.byteLength));
    view.set(arr);
    return view;
};
/**
 * Builds a Float32Array that lives in shared memory.
 *
 * @param {Float32Array|Array|ArrayBufferView|number} input
 *   - Float32Array: copied as-is into shared memory.
 *   - plain array or any other buffer view: converted to Float32Array first, then copied.
 *   - number: treated as an element count; a zero-filled shared array of that length is returned.
 * @returns {Float32Array} View backed by a SharedArrayBuffer.
 * @throws {Error} When the input is none of the supported kinds.
 */
var sharedFloat32Array = function(input) {
    if (input instanceof Float32Array) {
        return toSharedFloat32(input);
    }
    if (Array.isArray(input) || ArrayBuffer.isView(input)) {
        return toSharedFloat32(new Float32Array(input));
    }
    if (typeof input === 'number') {
        // 4 bytes per 32-bit float
        var buffer = new SharedArrayBuffer(input * 4);
        return new Float32Array(buffer);
    }
    throw new Error("sharedFloat32Array: invalid input");
};
/**
 * Sends SIGTERM to every PID in workerPids and then exits the main process.
 * Does nothing in "SharedMemory" mode, and runs at most once (guarded by ranKill).
 *
 * The function is used directly as a listener for 'exit' and for signals, so its first
 * argument is not always an exception:
 *   - 'exit' listeners receive the numeric exit code,
 *   - signal listeners receive the signal name (e.g. "SIGINT"),
 *   - the 'uncaughtException' wrapper passes the thrown value.
 * Only the last case is reported as an uncaught exception.
 *
 * Exit status: 1 after an uncaught exception, the original code when invoked from 'exit',
 * otherwise 0.
 *
 * @param {*} [error] Thrown value, signal name, or exit code (see above).
 */
var killWorkers = function(error) {
    if (mode === "SharedMemory") return;
    if (ranKill) return;
    ranKill = true;
    var isExitCode = typeof error === "number";
    var isSignalName = typeof error === "string" && /^SIG[A-Z0-9]+$/.test(error);
    var failed = Boolean(error) && !isExitCode && !isSignalName;
    if (failed) {
        console.error("[Exit] Uncaught exception:", error);
    }
    console.log("[Exit] Killing workers...");
    // Indexed loop: for-in walks property keys, which is not meant for arrays.
    for (var i = 0; i < workerPids.length; i++) {
        console.log("[Exit] Killing worker " + workerPids[i] + "...");
        try {
            // 'SIGTERM' is also what process.kill() sends when no signal is given, so one call
            // covers the former win32 / non-win32 branches.
            process.kill(workerPids[i], 'SIGTERM');
        } catch (err) {
            // ESRCH: the worker is already gone, which is fine.
            if (err.code !== 'ESRCH') throw err;
        }
        console.log("[Exit] Killed worker " + workerPids[i] + ".");
    }
    console.log("[Exit] Killed workers.");
    workerPids = [];
    console.log("[Exit] Exiting...");
    process.exit(failed ? 1 : (isExitCode ? error : 0));
};
// Outside SharedMemory mode, run killWorkers however the main process ends: normal exit,
// Ctrl+C, the two user-defined signals, or an uncaught exception.
if (mode !== "SharedMemory") {
    ['exit', 'SIGINT', 'SIGUSR1', 'SIGUSR2'].forEach(function(eventName) {
        process.on(eventName, killWorkers);
    });
    process.on('uncaughtException', function(err) {
        killWorkers(err);
    });
}
//var json = JSON;
var path = require("path");
var os = require("os");
var readline = require("readline");
var spawnchild = require("child_process").spawn;
var hasExposeGC = process.execArgv.includes('--expose-gc');
var hasMemoryLimit = process.execArgv.some(function(arg) {
return arg.startsWith('--max-old-space-size=');
});
// This value is a meme and will work fine until Node panics or the kernel screams
// - Chatgpt 4o 19.04.2025
var desiredMemFlag = '--max-old-space-size=9999999999999';
if (!hasExposeGC || !hasMemoryLimit) {
var newArgs = [
!hasMemoryLimit ? desiredMemFlag : null,
!hasExposeGC ? '--expose-gc' : null,
"--no-opt",
"--interpreted-frames-native-stack"
]
.concat(process.execArgv)
.concat([process.argv[1]])
.concat(process.argv.slice(2))
.filter(Boolean); // Remove nulls
spawnchild(process.argv[0], newArgs, { stdio: 'inherit' })
.on('exit', function(code) {
process.exit(code);
});
return; // Let the child process take over
}
// Process ids currently handed out by generateProcessId().
var processids = [];
/**
 * Picks a random integer between two bounds, both included.
 *
 * @param {number[]} range Two-element array [min, max].
 * @returns {number} Integer n with min <= n <= max.
 */
var randomRangeInclusive = function(range) {
    var low = range[0];
    var choices = range[1] - low + 1;
    return Math.floor(Math.random() * choices) + low;
};
/**
 * Reserves a five-digit process id that is not currently listed in processids.
 *
 * @returns {{processid: (string|null), revoke: function(): void}} Handle holding the id.
 *   revoke() removes the id from processids, sets handle.processid to null and logs both steps.
 */
var generateProcessId = function() {
    var candidate;
    // Keep drawing five random digits until the result is not already taken.
    do {
        candidate = "";
        for (var digit = 0; digit < 5; digit++) {
            candidate += randomRangeInclusive([0, 9]);
        }
    } while (processids.indexOf(candidate) !== -1);
    processids.push(candidate);
    return {
        processid: candidate,
        revoke: function() {
            var revoked = this.processid;
            console.log("[" + revoked + "] Revoking processid...");
            var position = processids.indexOf(revoked);
            if (position !== -1) {
                processids.splice(position, 1);
            }
            this.processid = null;
            console.log("[" + revoked + "] Revoked processid.");
        }
    };
};
// Bridge ids currently handed out by generateBridgeId().
var bridgeids = [];
/**
 * Reserves a five-digit bridge id that is not currently listed in bridgeids.
 *
 * @returns {{bridgeid: (string|null), revoke: function(): void}} Handle holding the id.
 *   revoke() removes the id from bridgeids, sets handle.bridgeid to null and logs both steps.
 */
var generateBridgeId = function() {
    var candidate;
    // Keep drawing five random digits until the result is not already taken.
    do {
        candidate = "";
        for (var digit = 0; digit < 5; digit++) {
            candidate += randomRangeInclusive([0, 9]);
        }
    } while (bridgeids.indexOf(candidate) !== -1);
    bridgeids.push(candidate);
    return {
        bridgeid: candidate,
        revoke: function() {
            var revoked = this.bridgeid;
            console.log("[" + revoked + "] Revoking bridgeid...");
            var position = bridgeids.indexOf(revoked);
            if (position !== -1) {
                bridgeids.splice(position, 1);
            }
            this.bridgeid = null;
            console.log("[" + revoked + "] Revoked bridgeid.");
        }
    };
};
/**
 * Prompts on stdout and resolves with the line the user types on stdin.
 * The readline interface is closed before the answer is delivered.
 *
 * @param {string} query Prompt text.
 * @returns {Promise<string>} The answer entered by the user.
 */
var readline_async_synclike = async function(query){
    var ask = function(deliver) {
        var terminal = readline.createInterface({
            input: process.stdin,
            output: process.stdout
        });
        terminal.question(query, function(reply) {
            terminal.close();
            deliver(reply);
        });
    };
    return await new Promise(ask);
}
/**
 * Resolves after roughly `ms` milliseconds (setTimeout based).
 *
 * @param {number} ms Delay in milliseconds.
 * @returns {Promise<void>}
 */
var wait = async function(ms){
    var sleep = new Promise(function(wake){
        setTimeout(function(){ wake(); }, ms);
    });
    await sleep;
}
/**
 * Reports a missing dependency and offers to install every third-party dependency with npm.
 *
 * Keeps asking until the user answers "y" or "n":
 *   - "y": runs `npm install ...` with inherited stdio and resolves once it exits with code 0;
 *          any other outcome (non-zero exit, killed by a signal, npm could not be started)
 *          prints the failure advice and exits the process with status 1.
 *   - "n": exits the process with status 1.
 *
 * @param {string} dependency Name of the dependency that failed to load (only used in the message).
 * @returns {Promise<void>} Resolves after a successful installation.
 */
var resolveDependency = async function(dependency){
    var packages = ["readline-sync", "tiktoken", "archiver", "yauzl", "uuid", "ws"];
    console.log("Missing dependency: " + dependency)
    console.log("Would you like to auto-install all dependencies?")
    console.log("Dependencies are:")
    packages.forEach(function(name) {
        console.log(" - " + name)
    });
    console.log("I am not liable if auto-installing the dependencies using the command below causes any damage of any sort to your system.")
    console.log("Command: \"npm install " + packages.join(" ") + "\"")
    console.log("Correct answers to the following prompt are \"y\" for yes and \"n\" for no without the quotes.")
    while (true){
        // Declared locally: the answer used to leak into an implicit global.
        var res = await readline_async_synclike("Auto-install? (y/n) › ")
        if (res === "n"){
            console.log("Dependencies not installed, and you refused to install them automatically. Cannot continue.")
            console.log("Exiting...")
            process.exit(1);
        }
        if (res !== "y"){
            console.log("Invalid answer, valid answers are \"y\" for yes and \"n\" for no without the quotes. Please try again.")
            continue;
        }
        console.log("Auto-installing dependencies...")
        // Wait for the child directly instead of polling a flag every 100 ms.
        var exitCode = await new Promise(function(resolve) {
            var install;
            if (process.platform === 'win32') {
                // On Windows npm is a .cmd script, which can only be spawned through a shell.
                install = spawnchild('npm install ' + packages.join(' '), {
                    stdio: 'inherit',
                    shell: true
                });
            } else {
                install = spawnchild('npm', ['install'].concat(packages), {
                    stdio: 'inherit' // The child uses this process's stdin, stdout and stderr directly
                });
            }
            // 'error' fires when npm cannot be started at all (for example, it is not installed).
            install.on('error', function() { resolve(null); });
            install.on('exit', function(code) { resolve(code); });
        });
        if (exitCode !== 0) {
            console.log("Failed to auto-install dependencies.");
            console.log("Please check you have npm installed, an internet connection and enough disk space available then retry running the command, perhaps try running it manually.");
            console.log("Exiting...")
            process.exit(1);
        }
        console.log("Dependencies auto-installed successfully.");
        return;
    }
}
;(async function(){
// Loads one third-party dependency. When the first attempt throws, the user is offered the
// auto-install (resolveDependency) and the load is retried once; a second failure is fatal.
//   name: dependency name shown in messages
//   load: function that performs the require() and returns the wanted export
var requireOrInstall = async function(name, load) {
    try {
        return load();
    }
    catch (error) {
        await resolveDependency(name)
        try {
            return load();
        }
        catch (retryError) {
            console.log("Even though auto-installation was successful, the " + name + " dependency could not be loaded.")
            console.log("Exiting...")
            process.exit(1);
        }
    }
};
var readlineSync = await requireOrInstall("readline-sync", function() {
    return require("readline-sync");
});
var Tiktoken = await requireOrInstall("tiktoken", function() {
    return require("tiktoken/lite").Tiktoken;
});
// Not guarded: tiktoken itself has been loaded successfully at this point.
var cl100k_base = require("tiktoken/encoders/cl100k_base.json");
var uuid = await requireOrInstall("uuid", function() {
    return require("uuid");
});
var archiver = await requireOrInstall("archiver", function() {
    return require('archiver');
});
var yauzl = await requireOrInstall("yauzl", function() {
    return require('yauzl');
});
var WebSocket = await requireOrInstall("ws", function() {
    return require("ws");
});
// Command-line arguments with the node binary and the script path removed.
var args = process.argv.slice(2);
// Active configuration; stays an empty object unless a config file is loaded and validated below.
var config = {};
// 100 MiB expressed in bytes. NOTE(review): where this threshold is used is not visible in this part of the file.
var CHUNK_THRESHOLD_BYTES = 100 * 1024 * 1024;
/**
 * Asks a question on the terminal but gives up after a deadline.
 *
 * @param {string} prompt Text shown to the user.
 * @param {number} timeoutMs Maximum time to wait for an answer, in milliseconds.
 * @returns {Promise<string|null>} The typed answer, or null when the deadline passes first.
 */
function inputWithTimeout(prompt, timeoutMs) {
    return new Promise(function (resolve) {
        var rl = readline.createInterface({
            input: process.stdin,
            output: process.stdout
        });
        var settled = false;
        // Single exit path shared by the answer and the timeout; whichever comes second is ignored.
        var finish = function (value) {
            if (settled) {
                return;
            }
            settled = true;
            clearTimeout(deadline);
            rl.close();
            resolve(value);
        };
        var deadline = setTimeout(function () {
            finish(null); // null signals a timeout
        }, timeoutMs);
        rl.question(prompt, function (answer) {
            finish(answer);
        });
    });
}
/**
 * Prints a banner containing `issue` followed by the command-line usage tree.
 *
 * @param {string} [issue] Headline for the banner; "No args found." when omitted.
 */
function help(issue) {
    var headline = issue === undefined ? "No args found." : issue;
    // Blank string as wide as `text`, used to line continuation rows up under earlier columns.
    var pad = function(text) {
        return " ".repeat(text.length);
    };
    var rule = "=====" + "=".repeat(headline.length) + "=====";
    var base = pad("cleanai");
    var newColumn = base + pad(" --new");
    var newOptions = newColumn + pad("--config path/to/config.json");
    var loadColumn = base + pad(" --load path/to/model.zip");
    var loadOptions = loadColumn + pad("[--config path/to/config.json]");
    var lines = [
        rule,
        "==== " + headline + " ====",
        rule,
        "",
        "cleanai" + " --new",
        newColumn + "--config path/to/config.json",
        newOptions + "--train",
        newOptions + pad("--pretrain") + "[--pretrain]",
        newOptions + "--pretrain",
        newOptions + pad("--pretrain") + "[--train]",
        base + " --load path/to/model.zip",
        loadColumn + "[--config path/to/config.json]",
        loadOptions + "[--train]",
        loadOptions + pad("[--train]") + " [--pretrain]",
        loadOptions + "[--pretrain]",
        loadOptions + pad("[--pretrain]") + "[--train]",
        "",
        "Note: Arguments between square brackets ([...]) are optional."
    ];
    for (var row = 0; row < lines.length; row++) {
        console.log(lines[row]);
    }
}
// ---- Command-line parsing ----
// Results of the parse:
//   flag            true when --new was given, false when --load was given, null before either is seen
//   VERBOSE         true when --verbose was given
//   training__      true when --train was given (null otherwise)
//   pretraining__   true when --pretrain was given (null otherwise)
//   config__        true when --config was given
//   config_location path following --config (null when absent)
//   model_location  path following --load (null when absent)
var flag = null;
var VERBOSE = false;
var training__ = null;
var pretraining__ = null;
var config__ = false;
var skipnext = false;
var config_location = null;
var model_location = null;
// Prints the usage banner with `message` and stops the process.
// NOTE(review): usage errors keep exiting with status 0, exactly as before; confirm whether
// a non-zero status would be preferable.
var cliUsageExit = function(message) {
    help(message);
    process.exit(0);
};
// Every option the parser understands. None of them may be used as the file operand of
// --load / --config ("--load" and "--verbose" used to slip through as file names).
var cliKnownOptions = ["--new", "--load", "--train", "--pretrain", "--config", "--verbose"];
// Checks the operand that follows `option` and returns it.
//   label:     "Model" or "Config", used in the messages
//   extension: required file extension including the dot
// A missing operand (option given last) is now reported as missing instead of producing
// "... file undefined does not exist.".
var cliFileOperand = function(option, operand, label, extension) {
    var missing = "You need to specify a " + label.toLowerCase() + " file after " + option + ".";
    if (operand === undefined || cliKnownOptions.indexOf(operand) !== -1) {
        cliUsageExit(missing);
    }
    var described = label + " file " + operand;
    try {
        if (!fs.existsSync(operand)) {
            cliUsageExit(described + " does not exist.");
        }
        if (!fs.statSync(operand).isFile()) {
            cliUsageExit(described + " is not a file.");
        }
    } catch (e) {
        // A file system error here is reported like a missing operand, as before.
        cliUsageExit(missing);
    }
    if (operand.slice(-extension.length) !== extension) {
        cliUsageExit(described + " is not a " + extension.slice(1) + " file.");
    }
    return operand;
};
if (args.length === 0) {
    help();
    process.exit(0);
}
for (var i = 0; i < args.length; i++) {
    var arg = args[i];
    if (skipnext) {
        // This position holds the file operand consumed by the previous option.
        skipnext = false;
        continue;
    }
    switch (arg) {
        case "--new":
            if (flag === true) {
                cliUsageExit("You can't specify --new multiple times.");
            }
            if (flag === false) {
                cliUsageExit("You can't specify --new and --load at the same time.");
            }
            flag = true;
            break;
        case "--load":
            if (flag === true) {
                cliUsageExit("You can't specify --new and --load at the same time.");
            }
            if (flag === false) {
                cliUsageExit("You can't specify --load multiple times.");
            }
            flag = false;
            model_location = cliFileOperand("--load", args[i + 1], "Model", ".zip");
            skipnext = true;
            break;
        case "--verbose":
            if (VERBOSE === true) {
                cliUsageExit("You can't specify --verbose multiple times.");
            }
            VERBOSE = true;
            break;
        case "--train":
            if (training__ === true) {
                cliUsageExit("You can't specify --train multiple times.");
            }
            training__ = true;
            break;
        case "--pretrain":
            if (pretraining__ === true) {
                cliUsageExit("You can't specify --pretrain multiple times.");
            }
            pretraining__ = true;
            break;
        case "--config":
            if (config__ === true) {
                cliUsageExit("You can't specify --config multiple times.");
            }
            config__ = true;
            config_location = cliFileOperand("--config", args[i + 1], "Config", ".json");
            skipnext = true;
            break;
        default:
            cliUsageExit("Argument " + arg + " not recognised.");
    }
}
// Cross-option rules:
//   --new  requires --config and at least one of --train / --pretrain
//   --load requires --config only when --train or --pretrain is requested
//   one of --new / --load is mandatory
var wantsTraining = training__ === true || pretraining__ === true;
if (flag === true) {
    if (!config__) {
        cliUsageExit("You need to specify a config file with --config.");
    }
    if (!wantsTraining) {
        cliUsageExit("You need to specify either --train or --pretrain or both with --new.");
    }
} else if (flag === false) {
    if (wantsTraining && !config__) {
        cliUsageExit("You need to specify a config file with --config.");
    }
} else {
    cliUsageExit("You need to specify either --new or --load.");
}
console.log("Arguments parsed successfully.");
// ---- Config file loading and validation (only when --config was given) ----
if (config__) {
    console.log("Reading and loading config file...");
    var configtoparse;
    try {
        configtoparse = fs.readFileSync(config_location, "utf-8");
    } catch (error) {
        console.log("Failed to read config file, check if it's corrupted or if you don't have permissions.");
        console.log("JavaScript error:");
        console.log(String(error));
        console.log("Exiting...");
        process.exit(1);
    }
    try {
        configtoparse = JSON.parse(configtoparse);
    } catch (error) {
        console.log("Failed to load json of config file, check if it's corrupted.");
        console.log("JavaScript error:");
        console.log(String(error));
        console.log("Exiting...");
        process.exit(1);
    }
    // Rules enforced below (the first violation prints an error and exits with status 1):
    //   with --pretrain:
    //     pre-training-paths            non-empty array of strings, each an existing path
    //     pre-train-epochs              int > 0
    //     pre-train-optimizer           one of adam, sgd_momentum, sgd
    //   with --train:
    //     training-dataset-path         string, existing path
    //     train-epochs                  int > 0
    //     train-optimizer               one of adam, sgd_momentum, sgd
    //   always:
    //     contextSize, maxOutputSize, batchSize   int > 0
    //     learningRate                  finite number > 0
    //     antiOverfittingOptimisations  boolean
    //     noSweetSpotSaving             boolean, optional
    //   with --new:
    //     embeddingSize, layersAmount, heads      int > 0
    //     biasesinitrange, embeddinginitrange     [min, max], finite numbers, min < max
    //   with --load:
    //     embeddingSize, layersAmount, heads, biasesinitrange and embeddinginitrange
    //     must NOT be present
    var isInt = function(n) {
        return Number.isInteger(n);
    };
    var isFloat = function(n) {
        return typeof n === "number" && Number.isFinite(n);
    };
    var isString = function(s) {
        return typeof s === "string";
    };
    var isBool = function(b) {
        return typeof b === "boolean";
    };
    var isArray = function(a) {
        return Array.isArray(a);
    };
    // Parameter deliberately not called `path`, to avoid shadowing the path module.
    var fileExists = function(filePath) {
        try {
            return fs.existsSync(filePath);
        } catch (e) {
            return false;
        }
    };
    var isNew = args.indexOf("--new") !== -1;
    // Prints a validation error and stops the process.
    var configError = function(msg) {
        console.log("Config validation error: " + msg);
        process.exit(1);
    };
    // JSON.parse can legitimately return null, a number, a string or an array; reading
    // properties of null would otherwise crash with a raw TypeError.
    if (configtoparse === null || typeof configtoparse !== "object" || isArray(configtoparse)) {
        configError("config file must contain a JSON object.");
    }
    var valid_optim = ["adam", "sgd_momentum", "sgd"];
    // Requires config[key] to be an integer > 0.
    var cfgPositiveInt = function(key) {
        if (!isInt(configtoparse[key]) || configtoparse[key] <= 0) {
            configError(key + " must be an integer > 0.");
        }
    };
    // Requires config[key] to name one of the supported optimizers.
    var cfgOptimizer = function(key) {
        if (!isString(configtoparse[key]) || valid_optim.indexOf(configtoparse[key]) === -1) {
            configError(key + " must be one of: " + valid_optim.join(", "));
        }
    };
    // Requires config[key] to be [min, max] with finite numbers and min < max.
    var cfgRange = function(key) {
        var range = configtoparse[key];
        if (!isArray(range) || range.length !== 2 ||
            !isFloat(range[0]) || !isFloat(range[1]) ||
            range[0] >= range[1]) {
            configError(key + " must be an array of two floats [min, max] with min < max.");
        }
    };
    // Only require pre-training-paths if --pretrain flag is present
    var needsPretrain = args.indexOf("--pretrain") !== -1;
    if (needsPretrain) {
        if (!isArray(configtoparse["pre-training-paths"]) || configtoparse["pre-training-paths"].length < 1) {
            configError("pre-training-paths must be a non-empty array of strings.");
        }
        for (var i = 0; i < configtoparse["pre-training-paths"].length; i++) {
            var p = configtoparse["pre-training-paths"][i];
            if (!isString(p)) {
                configError("pre-training-paths must only contain strings.");
            }
            if (!fileExists(p)) {
                configError("pre-training-paths file does not exist: " + p);
            }
        }
    }
    // Only require training-dataset-path if --train flag is present
    var needsTrain = args.indexOf("--train") !== -1;
    if (needsTrain) {
        if (!isString(configtoparse["training-dataset-path"])) {
            configError("training-dataset-path must be a string.");
        }
        if (!fileExists(configtoparse["training-dataset-path"])) {
            configError("training-dataset-path file does not exist: " + configtoparse["training-dataset-path"]);
        }
        cfgPositiveInt("train-epochs");
        cfgOptimizer("train-optimizer");
    }
    if (needsPretrain) {
        cfgPositiveInt("pre-train-epochs");
        cfgOptimizer("pre-train-optimizer");
    }
    cfgPositiveInt("contextSize");
    cfgPositiveInt("maxOutputSize");
    cfgPositiveInt("batchSize");
    if (!isFloat(configtoparse["learningRate"]) || configtoparse["learningRate"] <= 0) {
        configError("learningRate must be a float > 0.");
    }
    if (!isBool(configtoparse["antiOverfittingOptimisations"])) {
        configError("antiOverfittingOptimisations must be a boolean.");
    }
    if (isNew) {
        // A new model needs its full architecture description.
        cfgPositiveInt("embeddingSize");
        cfgPositiveInt("layersAmount");
        cfgPositiveInt("heads");
        cfgRange("biasesinitrange");
        cfgRange("embeddinginitrange");
    } else {
        // For a loaded model the architecture fields must NOT be present.
        var forbidden = ["embeddingSize", "layersAmount", "heads", "biasesinitrange", "embeddinginitrange"];
        for (var j = 0; j < forbidden.length; j++) {
            // Called through Object.prototype so a "hasOwnProperty" key in the JSON cannot break the check.
            if (Object.prototype.hasOwnProperty.call(configtoparse, forbidden[j])) {
                configError("Config for loaded model must not contain: " + forbidden[j]);
            }
        }
    }
    if (configtoparse.noSweetSpotSaving !== undefined && typeof configtoparse.noSweetSpotSaving !== "boolean") {
        configError("If specified, noSweetSpotSaving must be a boolean.")
    }
    config = configtoparse;
    console.log("Config file loaded successfully.");
}
// Unconditional print: forwards every argument to console.log, whatever VERBOSE is set to.
var ndprint = function() {
    console.log.apply(console, arguments);
};
// Verbose-only print: forwards its arguments to ndprint when VERBOSE is truthy, otherwise does nothing.
var print = function(...parts) {
    if (!VERBOSE) {
        return;
    }
    ndprint(...parts);
};
/**
 * Starts a timer and registers it in `timers`.
 *
 * @returns {string} Timer id: a uuid v4 with the dashes removed. Pass it to timer_end().
 */
function timer_() {
    var record = {
        "id": uuid.v4().split("-").join(""),
        "start": Date.now() / 1000 // seconds since the epoch
    };
    timers.push(record);
    return record["id"];
}
/**
 * Stops the timer with the given id and reports how long it ran.
 * The record stays in `timers`, with its "end" field set (in seconds).
 *
 * @param {string} timer_id Id returned by timer_().
 * @returns {number|null} Elapsed time in milliseconds, or null when the id is unknown.
 */
function timer_end(timer_id) {
    var record = timers.find(function(candidate) {
        return candidate["id"] === timer_id;
    });
    if (record === undefined) {
        return null;
    }
    record["end"] = Date.now() / 1000;
    // Timestamps are kept in seconds; convert the difference to milliseconds.
    return (record["end"] - record["start"]) * 1000;
}
/**
 * Draws a random float from a range.
 * Math.random() is always below 1, so range[0] can be returned but range[1] is
 * effectively excluded.
 *
 * @param {number[]} range Two-element array [min, max].
 * @returns {number} Random float starting at min and staying below max.
 */
function random_range(range) {
    var low = range[0];
    var width = range[1] - low;
    return Math.random() * width + low;
}
class Transformer {
constructor(newFlag, parameters, path, vocab_path) {
if (vocab_path === undefined) { vocab_path = "vocabulary.json"; }
this.adam_params = {
'beta1': 0.9,
'beta2': 0.98, // From 0.999 to 0.98 to match paper
'epsilon': 1e-9,
't': 0
};
ndprint("Trying to read vocabulary file...");
try {
this.vocab = JSON.parse(fs.readFileSync(__dirname + "/vocabulary.json", "utf-8"));
} catch (e) {
console.log("Failed to read vocabulary file, creating error...");
throw new Error("Failed to read vocabulary file");
}
ndprint("Successfully read vocabulary file");
ndprint("Computing lookup table...");
this.id_to_token = {};
for (var i = 0; i < this.vocab.length; i++) {
var tok = this.vocab[i];
this.id_to_token[tok[1]] = tok[0];
}
ndprint("Computed lookup table");
this.encoder = new Tiktoken(
cl100k_base.bpe_ranks,
{ "<|endoftext|>": 100257 }, // example special token
cl100k_base.pat_str
);
this.temperature = 0.7;
this.nan_checks_enabled = true; // Control logging easily
this.nan_count_this_step = 0;
this.nan_forward_pass_count_epoch = 0;
this.nan_backprop_calc_count_epoch = 0;
this.nan_final_gradient_count_epoch = 0;
this.steps_with_nan_epoch = 0;
if (newFlag) {
ndprint("Initializing model...");
// Calculate total parameters
var total_params = (
this.vocab.length * parameters["embeddingSize"] +
parameters["contextSize"] * parameters["embeddingSize"] +
parameters["layersAmount"] * (
2 * parameters["embeddingSize"] +
parameters["heads"] * (3 * parameters["embeddingSize"] * parameters["embeddingSize"] / parameters["heads"] + 3 * parameters["embeddingSize"] / parameters["heads"]) +
parameters["embeddingSize"] * parameters["embeddingSize"] + parameters["embeddingSize"] +
2 * parameters["embeddingSize"] +
parameters["embeddingSize"] * (4 * parameters["embeddingSize"]) + 4 * parameters["embeddingSize"] +
(4 * parameters["embeddingSize"]) * parameters["embeddingSize"] + parameters["embeddingSize"]
)
);
var total_ram = total_params * 4; // (32 bit floats take up 4 bytes each)
ndprint("Model is of size " + total_params + " parameters");
total_ram = total_params * 4;
ndprint(" ~" + (total_params / 1e9).toFixed(2) + "b parameters");
ndprint("");
var adam_ram = total_params * 3 * 4; // Assuming 3 times the parameters for Adam
ndprint("Would cost the equivalent of " + (total_params * 3) + " parameters if trained with Adam");
ndprint(" ~" + ((total_params * 3) / 1e9).toFixed(2) + "b parameters if trained with adam");
ndprint("");
var sgd_momentum_ram = total_params * 2 * 4; // Assuming 2 times the parameters for SGD with momentum
ndprint("Would cost the equivalent of " + (total_params * 2) + " parameters if trained with SGD with momentum");
ndprint(" ~" + ((total_params * 2) / 1e9).toFixed(2) + "b parameters if trained with SGD with momentum");
ndprint("");
ndprint("Would not cost more than the original size of the model if trained with vanilla SGD");
var sgtimer = timer_();
ndprint("Initializing parameters...");
var timer = timer_();
this.contextSize = parameters["contextSize"];
this.embeddingSize = parameters["embeddingSize"];
this.learningRate = parameters["learningRate"];
this.maxOutputSize = parameters["maxOutputSize"];
this.layersAmount = parameters["layersAmount"];
if ("use_he_init" in parameters && parameters["use_he_init"]) {
this.weightsinitrange = this.he_init(this.embeddingSize);
console.log("Using He initialization with range: " + this.weightsinitrange);
} else {
this.weightsinitrange = parameters["weightsinitrange"];
}
this.biasesinitrange = parameters["biasesinitrange"];
this.heads = parameters["heads"];
this.embeddinginitrange = parameters["embeddinginitrange"];
this.transformer = {};
this.step_num = 0;
ndprint("Initialized parameters in", timer_end(timer), "ms");
var percentagePrintInterval = 10;
ndprint("Initializing layers...");
var gtimer = timer_();
this.transformer["layers"] = [];
for (var i = 0; i < this.layersAmount; i++) {
var timer_layer = timer_();
console.log("Initializing weights and biases for layer " + i);
this.transformer["layers"].push({
"weights": {
"normalize_1": sharedFloat32Array(this.embeddingSize * 3).fill(0),
"attention": {
"heads": (function() {
var arr = [];
for (var h = 0; h < this.heads; h++) {
arr.push({
"query": sharedFloat32Array(this.embeddingSize * this.embeddingSize * 3).fill(0),
"key": sharedFloat32Array(this.embeddingSize * this.embeddingSize * 3).fill(0),
"value": sharedFloat32Array(this.embeddingSize * this.embeddingSize * 3).fill(0)
});
}
return arr;
}).call(this),
"output": sharedFloat32Array(this.embeddingSize * (this.embeddingSize * this.heads) * 3).fill(0)
},
"normalize_2": sharedFloat32Array(this.embeddingSize * 3).fill(0),
"feed_forward": {
"grow": sharedFloat32Array(this.embeddingSize * (this.embeddingSize * 4) * 3).fill(0),
"shrink": sharedFloat32Array((this.embeddingSize * 4) * this.embeddingSize * 3).fill(0)
}
},
"biases": {
"normalize_1": sharedFloat32Array(this.embeddingSize * 3).fill(0),
"attention": {
"heads": (function() {
var arr = [];
for (var h = 0; h < this.heads; h++) {
arr.push({
"query": sharedFloat32Array(this.embeddingSize * 3).fill(0),
"key": sharedFloat32Array(this.embeddingSize * 3).fill(0),
"value": sharedFloat32Array(this.embeddingSize * 3).fill(0)
});
}
return arr;
}).call(this),
"output": sharedFloat32Array(this.embeddingSize * 3).fill(0)
},
"normalize_2": sharedFloat32Array(this.embeddingSize * 3).fill(0),
"feed_forward": {
"grow": sharedFloat32Array((this.embeddingSize * 4) * 3).fill(0),
"shrink": sharedFloat32Array(this.embeddingSize * 3).fill(0)
}
}
});
var total_params_layer = 2 * this.embeddingSize + 3 * this.heads * (this.embeddingSize * this.embeddingSize + this.embeddingSize) + this.embeddingSize * (this.embeddingSize * this.heads) + this.embeddingSize + 2 * this.embeddingSize + this.embeddingSize * (this.embeddingSize * 4) + (this.embeddingSize * 4) + (this.embeddingSize * 4) * this.embeddingSize + this.embeddingSize;
var params_done = 0;
var last_percent = -percentagePrintInterval;
for (var j = 0; j < this.embeddingSize; j++) {
this.transformer["layers"][i]["weights"]["normalize_1"][j * 3] = random_range(this.weightsinitrange);
this.transformer["layers"][i]["biases"]["normalize_1"][j * 3] = random_range(this.biasesinitrange);
this.transformer["layers"][i]["weights"]["normalize_2"][j * 3] = random_range(this.weightsinitrange);
this.transformer["layers"][i]["biases"]["normalize_2"][j * 3] = random_range(this.biasesinitrange);
params_done += 4;
var percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.heads; j++) {
for (var k = 0; k < this.embeddingSize * this.embeddingSize; k++) {
this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["query"][k * 3] = random_range(this.weightsinitrange);
this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["key"][k * 3] = random_range(this.weightsinitrange);
this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["value"][k * 3] = random_range(this.weightsinitrange);
params_done += 3;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var k = 0; k < this.embeddingSize; k++) {
this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["query"][k * 3] = random_range(this.biasesinitrange);
this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["key"][k * 3] = random_range(this.biasesinitrange);
this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["value"][k * 3] = random_range(this.biasesinitrange);
params_done += 3;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
}
for (var j = 0; j < this.embeddingSize * (this.embeddingSize * this.heads); j++) {
this.transformer["layers"][i]["weights"]["attention"]["output"][j * 3] = random_range(this.weightsinitrange);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize; j++) {
this.transformer["layers"][i]["biases"]["attention"]["output"][j * 3] = random_range(this.biasesinitrange);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize * (this.embeddingSize * 4); j++) {
this.transformer["layers"][i]["weights"]["feed_forward"]["grow"][j * 3] = random_range(this.weightsinitrange);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize * 4; j++) {
this.transformer["layers"][i]["biases"]["feed_forward"]["grow"][j * 3] = random_range(this.biasesinitrange);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < (this.embeddingSize * 4) * this.embeddingSize; j++) {
this.transformer["layers"][i]["weights"]["feed_forward"]["shrink"][j * 3] = random_range(this.weightsinitrange);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize; j++) {
this.transformer["layers"][i]["biases"]["feed_forward"]["shrink"][j * 3] = random_range(this.biasesinitrange);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_perce