UNPKG

cleanai

Version:

A fully standalone, terminal-based AI model CLI for training and inference

1,124 lines (1,030 loc) 333 kB
#!/usr/bin/env node
// cleanai.js
//
//* It is much better to view this file with the "Better Comments" VS Code extension, because
//* it colors each comment according to how important it is to read, based on the little symbols
//* like * ! and ? that I put after // in the comments.
// Credits
//
// Me ofc (willmil11)
// Claude (3.5 Sonnet, 3.7 Sonnet)
// ChatGPT (4o, o1, o3, o3-mini-high, o4-mini-high)
// Gemini (2.5 Flash, 2.5 Pro, AI Studio versions)
//
// Contact me
// Discord: willmil11
// Email: willmil111012@gmail.com
//
// Drop a star on the GitHub repo if you like it :D

// Worker-count threshold used only by Remote mode, which is not enabled by default, so you can ignore it.
var tresholdworkersworthit = 27;
var mode = "SharedMemory"; // Could also be "Remote"
var submode = "Turtle"; // Could also be "Rabbit"
// Remote mode is legacy: I built it before I learnt about SharedArrayBuffer. It has immense overhead and
// is very slow, but it works perfectly. It is inferior to the other modes; I am leaving it here as a kind
// of museum piece, a display of the hours I wasted building it.
// Rabbit submode means the worker asks for data when it needs it, which adds some overhead but means the
// workers are spawned almost instantly.
// Turtle submode means the worker is given all data on spawn, so less overhead but a much longer spawn
// (initialisation) time.
// Neither submode uses more or less memory (or the difference is negligible, a few kilobytes at most).
// Do not change mode or submode (recommended is SharedMemory/Turtle).
// All modes and submodes should nevertheless work out of the box.
// Process-wide bookkeeping.
// workerPids: PIDs of spawned worker processes; killWorkers() terminates them (only outside SharedMemory mode).
// timers:     timing records appended by timer_() further down the file.
var workerPids = [];
var timers = [];
var fs = require("fs");
var process = require("process");
// Latched to true the first time killWorkers() runs, so the shutdown sequence happens only once
// (process.exit() inside it emits 'exit' again, which would otherwise re-enter it).
var ranKill = false;

/**
 * Copies a Float32Array into a new Float32Array backed by a SharedArrayBuffer.
 * The source array is not modified.
 * @param {Float32Array} arr - values to copy
 * @returns {Float32Array} shared-memory copy with the same length and contents
 */
var toSharedFloat32 = function(arr) {
    var shared = new SharedArrayBuffer(arr.byteLength);
    var sharedView = new Float32Array(shared);
    sharedView.set(arr);
    return sharedView;
};

/**
 * Builds a Float32Array backed by a SharedArrayBuffer.
 * @param {number|Float32Array|number[]|ArrayBufferView} input -
 *   either a non-negative integer element count (the result is zero-filled), or values to copy
 *   (a Float32Array, a plain array, or any typed array; a DataView is rejected)
 * @returns {Float32Array} shared-memory Float32Array
 * @throws {Error} "sharedFloat32Array: invalid input" for any other input
 */
var sharedFloat32Array = function(input) {
    if (typeof input === 'number') {
        // Reject NaN, negatives and fractions up front: some of them would otherwise silently
        // yield an empty array, others an unrelated RangeError from the buffer constructors.
        if (!Number.isInteger(input) || input < 0) {
            throw new Error("sharedFloat32Array: invalid input");
        }
        return new Float32Array(new SharedArrayBuffer(input * 4)); // 4 bytes per float32
    }
    if (input instanceof Float32Array) return toSharedFloat32(input);
    // A DataView passes ArrayBuffer.isView() but is not array-like, so new Float32Array(view)
    // would silently produce an empty array instead of copying anything.
    if (Array.isArray(input) || (ArrayBuffer.isView(input) && !(input instanceof DataView))) {
        return toSharedFloat32(new Float32Array(input));
    }
    throw new Error("sharedFloat32Array: invalid input");
};

/**
 * Terminates every PID in workerPids and exits the process. Does nothing in SharedMemory mode
 * or if it has already run.
 * @param {*} [error] - an uncaught exception to report before shutting down
 * @param {number} [exitCode] - exit code to use; defaults to 1 when `error` is truthy, else 0
 */
var killWorkers = function(error, exitCode) {
    if (mode === "SharedMemory") return;
    if (ranKill) return;
    ranKill = true;
    if (error) {
        console.error("[Exit] Uncaught exception:", error);
    }
    console.log("[Exit] Killing workers...");
    for (var pid of workerPids) {
        console.log("[Exit] Killing worker " + pid + "...");
        try {
            // process.kill() sends 'SIGTERM' by default on every platform (Windows included),
            // so no per-platform branch is needed.
            process.kill(pid, 'SIGTERM');
        } catch (err) {
            // ESRCH: the worker is already gone, which is fine during shutdown.
            if (err.code !== 'ESRCH') throw err;
        }
        console.log("[Exit] Killed worker " + pid + ".");
    }
    console.log("[Exit] Killed workers.");
    workerPids = [];
    console.log("[Exit] Exiting...");
    if (exitCode === undefined) {
        exitCode = error ? 1 : 0;
    }
    process.exit(exitCode);
};

if (mode !== "SharedMemory") {
    // Wrap each listener: 'exit' listeners receive the exit code and signal listeners receive the
    // signal name. Neither is an error, so they must not be forwarded as killWorkers' `error` argument.
    process.on('exit', function(code) { killWorkers(undefined, code); });
    ['SIGINT', 'SIGUSR1', 'SIGUSR2'].forEach(function(signal) {
        process.on(signal, function() { killWorkers(); });
    });
    // A crash must not be reported to the shell as success.
    process.on('uncaughtException', function(err) { killWorkers(err, 1); });
}
var path = require("path");
var os = require("os");
var readline = require("readline");
var spawnchild = require("child_process").spawn;
var hasExposeGC = process.execArgv.includes('--expose-gc');
var hasMemoryLimit = process.execArgv.some(function(arg) { return arg.startsWith('--max-old-space-size='); });
// This value is a meme and will work fine until Node panics or the
kernel screams // - Chatgpt 4o 19.04.2025 var desiredMemFlag = '--max-old-space-size=9999999999999'; if (!hasExposeGC || !hasMemoryLimit) { var newArgs = [ !hasMemoryLimit ? desiredMemFlag : null, !hasExposeGC ? '--expose-gc' : null, "--no-opt", "--interpreted-frames-native-stack" ] .concat(process.execArgv) .concat([process.argv[1]]) .concat(process.argv.slice(2)) .filter(Boolean); // Remove nulls spawnchild(process.argv[0], newArgs, { stdio: 'inherit' }) .on('exit', function(code) { process.exit(code); }); return; // Let the child process take over } var processids = []; var randomRangeInclusive = function(range) { return Math.floor(Math.random() * (range[1] - range[0] + 1)) + range[0]; }; var generateProcessId = function() { var processid = ""; while (true) { processid = ""; for (var index = 0; index < 5; index++) { processid += randomRangeInclusive([0, 9]); } var found = false; for (var i = 0; i < processids.length; i++) { if (processids[i] === processid) { found = true; break; } } if (!found) { processids.push(processid); break; } } return { processid: processid, revoke: function() { console.log("[" + this.processid + "] Revoking processid..."); var index = processids.indexOf(this.processid); if (index !== -1) { processids.splice(index, 1); } var old = this.processid; this.processid = null; console.log("[" + old + "] Revoked processid."); } }; }; var bridgeids = []; var generateBridgeId = function() { var bridgeid = ""; while (true) { bridgeid = ""; for (var index = 0; index < 5; index++) { bridgeid += randomRangeInclusive([0, 9]); } var found = false; for (var i = 0; i < bridgeids.length; i++) { if (bridgeids[i] === bridgeid) { found = true; break; } } if (!found) { bridgeids.push(bridgeid); break; } } return { bridgeid: bridgeid, revoke: function() { console.log("[" + this.bridgeid + "] Revoking bridgeid..."); var index = bridgeids.indexOf(this.bridgeid); if (index !== -1) { bridgeids.splice(index, 1); } var old = this.bridgeid; this.bridgeid = null; 
console.log("[" + old + "] Revoked bridgeid."); } }; }; var readline_async_synclike = async function(query){ return await new Promise((resolve) => { var rl = readline.createInterface({ input: process.stdin, output: process.stdout }); rl.question(query, function(answer) { rl.close(); resolve(answer); }); }); } var wait = async function(ms){ await new Promise(function(resolve){ setTimeout(function(){ resolve(); }, ms) }) } var resolveDependency = async function(dependency){ console.log("Missing dependency: " + dependency) console.log("Would you like to auto-install all dependencies?") console.log("Dependencies are:") console.log(" - readline-sync") console.log(" - tiktoken") console.log(" - archiver") console.log(" - yauzl") console.log(" - uuid") console.log(" - ws") console.log("I am not liable if auto-installing the dependencies using the command below causes any damage of any sort to your system.") console.log("Command: \"npm install readline-sync tiktoken archiver yauzl uuid ws\"") console.log("Correct answers to the following prompt are \"y\" for yes and \"n\" for no without the quotes.") while (true){ res = await readline_async_synclike("Auto-install? 
(y/n) › ") if (res === "y"){ console.log("Auto-installing dependencies...") var install = spawnchild('npm', [ 'install', 'readline-sync', 'tiktoken', 'archiver', 'yauzl', 'uuid', 'ws' ], { stdio: 'inherit' // This is the key: pipes stdout, stderr, and stdin }); finishedJob = false; install.on('exit', code => { if (code === 0) { console.log("Dependencies auto-installed successfully."); finishedJob = true; } else { console.log("Failed to auto-install dependencies."); console.log("Please check you have npm installed, an internet connection and enough disk space available then retry running the command, perhaps try running it manually."); console.log("Exiting...") process.exit(1); } }); while (!finishedJob) { await wait(100); } if (finishedJob){ break } } else{ if (res === "n"){ console.log("Dependencies not installed, and you refused to install them automatically. Cannot continue.") console.log("Exiting...") process.exit(1); } else{ console.log("Invalid answer, valid answers are \"y\" for yes and \"n\" for no without the quotes. 
Please try again.") } } } } ;(async function(){ try{ var readlineSync = require("readline-sync"); } catch (error){ await resolveDependency("readline-sync") try{ var readlineSync = require("readline-sync"); } catch (error){ console.log("Even though auto-installation was successful, the readline-sync dependency could not be loaded.") console.log("Exiting...") process.exit(1); } } try{ var Tiktoken = require("tiktoken/lite").Tiktoken; } catch (error){ await resolveDependency("tiktoken") try{ var Tiktoken = require("tiktoken/lite").Tiktoken; } catch (error){ console.log("Even though auto-installation was successful, the tiktoken dependency could not be loaded.") console.log("Exiting...") process.exit(1); } } var cl100k_base = require("tiktoken/encoders/cl100k_base.json"); try{ var uuid = require("uuid"); } catch (error){ await resolveDependency("uuid") try{ var uuid = require("uuid"); } catch (error){ console.log("Even though auto-installation was successful, the uuid dependency could not be loaded.") console.log("Exiting...") process.exit(1); } } try{ var archiver = require('archiver'); } catch (error){ await resolveDependency("archiver") try{ var archiver = require('archiver'); } catch (error){ console.log("Even though auto-installation was successful, the archiver dependency could not be loaded.") console.log("Exiting...") process.exit(1); } } try{ var yauzl = require('yauzl'); } catch (error){ await resolveDependency("yauzl") try{ var yauzl = require('yauzl'); } catch (error){ console.log("Even though auto-installation was successful, the yauzl dependency could not be loaded.") console.log("Exiting...") process.exit(1); } } try{ var WebSocket = require("ws"); } catch (error){ await resolveDependency("ws") try{ var WebSocket = require("ws"); } catch (error){ console.log("Even though auto-installation was successful, the ws dependency could not be loaded.") console.log("Exiting...") process.exit(1); } } var args = process.argv.slice(2); var config = {}; var 
CHUNK_THRESHOLD_BYTES = 100 * 1024 * 1024; function inputWithTimeout(prompt, timeoutMs) { return new Promise(function (resolve) { var rl = readline.createInterface({ input: process.stdin, output: process.stdout }); var done = false; // Timer to enforce the timeout var timer = setTimeout(function () { if (!done) { done = true; rl.close(); resolve(null); // Return null on timeout } }, timeoutMs); // Ask the question rl.question(prompt, function (answer) { if (!done) { done = true; clearTimeout(timer); rl.close(); resolve(answer); } }); }); } function help(issue) { if (issue === undefined) { issue = "No args found."; } function spacing() { return " ".repeat(("cleanai").length); } console.log("=====" + "=".repeat(issue.length) + "====="); console.log("==== " + issue + " ===="); console.log("=====" + "=".repeat(issue.length) + "====="); console.log(""); console.log("cleanai" + " --new"); console.log(spacing() + " ".repeat(" --new".length) + "--config path/to/config.json"); console.log(spacing() + " ".repeat(" --new".length) + " ".repeat("--config path/to/config.json".length) + "--train"); console.log(spacing() + " ".repeat(" --new".length) + " ".repeat("--config path/to/config.json".length) + " ".repeat("--pretrain".length) + "[--pretrain]"); console.log(spacing() + " ".repeat(" --new".length) + " ".repeat("--config path/to/config.json".length) + "--pretrain"); console.log(spacing() + " ".repeat(" --new".length) + " ".repeat("--config path/to/config.json".length) + " ".repeat("--pretrain".length) + "[--train]"); console.log(spacing() + " --load path/to/model.zip"); console.log(spacing() + " ".repeat(" --load path/to/model.zip".length) + "[--config path/to/config.json]"); console.log(spacing() + " ".repeat(" --load path/to/model.zip".length) + " ".repeat("[--config path/to/config.json]".length) + "[--train]"); console.log(spacing() + " ".repeat(" --load path/to/model.zip".length) + " ".repeat("[--config path/to/config.json]".length) + " ".repeat("[--train]".length) + " 
[--pretrain]"); console.log(spacing() + " ".repeat(" --load path/to/model.zip".length) + " ".repeat("[--config path/to/config.json]".length) + "[--pretrain]"); console.log(spacing() + " ".repeat(" --load path/to/model.zip".length) + " ".repeat("[--config path/to/config.json]".length) + " ".repeat("[--pretrain]".length) + "[--train]"); console.log(""); console.log("Note: Arguments between square brackets ([...]) are optional."); //console.log("") } var flag = null; var VERBOSE = false; var training__ = null; var pretraining__ = null; var config__ = false; var skipnext = false; var config_location = null; var model_location = null; if (args.length === 0) { help(); process.exit(0); } else { for (var i = 0; i < args.length; i++) { var arg = args[i]; if (skipnext) { skipnext = false; continue; } if (arg === "--new") { if (flag === true) { help("You can't specify --new multiple times."); process.exit(0); } else { if (flag === false) { help("You can't specify --new and --load at the same time."); process.exit(0); } } flag = true; } else if (arg === "--load") { if (flag === true) { help("You can't specify --new and --load at the same time."); process.exit(0); } else { if (flag === false) { help("You can't specify --load multiple times."); process.exit(0); } } flag = false; try { if (args[i + 1] !== "--new") { if (args[i + 1] !== "--train") { if (args[i + 1] !== "--pretrain") { if (args[i + 1] !== "--config") { model_location = args[i + 1]; if (!fs.existsSync(model_location)) { help("Model file " + model_location + " does not exist."); process.exit(0); } if (!fs.statSync(model_location).isFile()) { help("Model file " + model_location + " is not a file."); process.exit(0); } if (model_location.slice(-4) !== ".zip") { help("Model file " + model_location + " is not a zip file."); process.exit(0); } skipnext = true; } else { help("You need to specify a model file after --load."); process.exit(0); } } else { help("You need to specify a model file after --load."); 
process.exit(0); } } else { help("You need to specify a model file after --load."); process.exit(0); } } else { help("You need to specify a model file after --load."); process.exit(0); } } catch (e) { help("You need to specify a model file after --load."); process.exit(0); } } else if (arg === "--verbose") { if (VERBOSE === true) { help("You can't specify --verbose multiple times."); process.exit(0); } else { VERBOSE = true; } } else if (arg === "--train") { if (training__ === true) { help("You can't specify --train multiple times."); process.exit(0); } else { training__ = true; } } else if (arg === "--pretrain") { if (pretraining__ === true) { help("You can't specify --pretrain multiple times."); process.exit(0); } else { pretraining__ = true; } } else if (arg === "--config") { if (config__ === true) { help("You can't specify --config multiple times."); process.exit(0); } config__ = true; try { if (args[i + 1] !== "--new") { if (args[i + 1] !== "--train") { if (args[i + 1] !== "--pretrain") { if (args[i + 1] !== "--config") { config_location = args[i + 1]; if (!fs.existsSync(config_location)) { help("Config file " + config_location + " does not exist."); process.exit(0); } if (!fs.statSync(config_location).isFile()) { help("Config file " + config_location + " is not a file."); process.exit(0); } if (config_location.slice(-5) !== ".json") { help("Config file " + config_location + " is not a json file."); process.exit(0); } } else { help("You need to specify a config file after --config."); process.exit(0); } } else { help("You need to specify a config file after --config."); process.exit(0); } } else { help("You need to specify a config file after --config."); process.exit(0); } } else { help("You need to specify a config file after --config."); process.exit(0); } } catch (e) { help("You need to specify a config file after --config."); process.exit(0); } skipnext = true; continue; } else { help("Argument " + arg + " not recognised."); process.exit(0); } } } if 
(args.indexOf("--new") !== -1) { if (args.indexOf("--config") === -1) { help("You need to specify a config file with --config."); process.exit(0); } else { if (args.indexOf("--train") === -1) { if (args.indexOf("--pretrain") === -1) { help("You need to specify either --train or --pretrain or both with --new."); process.exit(0); } } else { if (args.indexOf("--pretrain") === -1) { if (args.indexOf("--train") === -1) { help("You need to specify either --train or --pretrain or both with --new."); process.exit(0); } } } } } if (args.indexOf("--load") !== -1){ if (args.indexOf("--train") !== -1) { if (args.indexOf("--config") === -1) { help("You need to specify a config file with --config."); process.exit(0); } } else{ if (args.indexOf("--pretrain") !== -1) { if (args.indexOf("--config") === -1) { help("You need to specify a config file with --config."); process.exit(0); } } } } if (args.indexOf("--new") === -1 && args.indexOf("--load") === -1) { help("You need to specify either --new or --load."); process.exit(0); } if (!VERBOSE) { VERBOSE = false; } console.log("Arguments parsed successfully."); if (config__) { console.log("Reading and loading config file..."); var configtoparse; try { configtoparse = fs.readFileSync(config_location, "utf-8"); } catch (error) { console.log("Failed to read config file, check if it's corrupted or if you don't have permissions."); console.log("JavaScript error:"); console.log(String(error)); console.log("Exiting..."); process.exit(1); } try { configtoparse = JSON.parse(configtoparse); } catch (error) { console.log("Failed to load json of config file, check if it's corrupted."); console.log("JavaScript error:"); console.log(String(error)); console.log("Exiting..."); process.exit(1); } //key logic to check: //if new model aka if flag is true: // pre-training-paths must be an array of string(s) (must contain at least one string), every string in the array must lead to an existing txt file // training-dataset-path must be a string that leads 
to an existing json file // train-epochs must be an int >0 // pre-train-optimizer must be a string, can only be adam, sgd_momentum or sgd // train-optimizer must be a string, can only be adam, sgd_momentum or sgd // contextSize must be an int >0 // embeddingSize must be an int>0 // learningRate must be a float >0 // maxOutputSize must be an int >0 // layersAmount must be an int >0 // heads must be an int >0 // batchSize must be an int >0 // biasesinitrange must be an array of two floats >0 // embeddinginitrange must be an array of two floats >0 // antiOverfittingOptimisations must be a boolean //if loaded model aka if flag is false: // then everything must remain the same except it is invalid to provide the following options: // embeddingSize // layersAmount // heads // biasesinitrange // embeddinginitrange // Validate configtoparse for required fields and types var isInt = function(n) { return typeof n === "number" && isFinite(n) && Math.floor(n) === n; }; var isFloat = function(n) { return typeof n === "number" && isFinite(n); }; var isString = function(s) { return typeof s === "string"; }; var isBool = function(b) { return typeof b === "boolean"; }; var isArray = function(a) { return Array.isArray(a); }; var fileExists = function(path) { try { return fs.existsSync(path); } catch (e) { return false; } }; // If --new, validate all required fields for new model // Otherwise, validate for loaded model (less strict) var isNew = args.indexOf("--new") !== -1; // Helper for error var configError = function(msg) { console.log("Config validation error: " + msg); process.exit(1); }; // Required for both new and loaded // Only require pre-training-paths if --pretrain flag is present var needsPretrain = args.indexOf("--pretrain") !== -1; if (needsPretrain) { if (!isArray(configtoparse["pre-training-paths"]) || configtoparse["pre-training-paths"].length < 1) { configError("pre-training-paths must be a non-empty array of strings."); } for (var i = 0; i < 
configtoparse["pre-training-paths"].length; i++) { var p = configtoparse["pre-training-paths"][i]; if (!isString(p)) { configError("pre-training-paths must only contain strings."); } if (!fileExists(p)) { configError("pre-training-paths file does not exist: " + p); } } } // Only require training-dataset-path if --train flag is present var needsTrain = args.indexOf("--train") !== -1; if (needsTrain) { if (!isString(configtoparse["training-dataset-path"])) { configError("training-dataset-path must be a string."); } if (!fileExists(configtoparse["training-dataset-path"])) { configError("training-dataset-path file does not exist: " + configtoparse["training-dataset-path"]); } } var valid_optim = ["adam", "sgd_momentum", "sgd"]; // Only require train-epochs and train-optimizer if --train flag is present if (needsTrain) { if (!isInt(configtoparse["train-epochs"]) || configtoparse["train-epochs"] <= 0) { configError("train-epochs must be an integer > 0."); } if (!isString(configtoparse["train-optimizer"]) || valid_optim.indexOf(configtoparse["train-optimizer"]) === -1) { configError("train-optimizer must be one of: " + valid_optim.join(", ")); } } // Only require pre-train-epochs and pre-train-optimizer if --pretrain flag is present if (needsPretrain) { if (!isInt(configtoparse["pre-train-epochs"]) || configtoparse["pre-train-epochs"] <= 0) { configError("pre-train-epochs must be an integer > 0."); } if (!isString(configtoparse["pre-train-optimizer"]) || valid_optim.indexOf(configtoparse["pre-train-optimizer"]) === -1) { configError("pre-train-optimizer must be one of: " + valid_optim.join(", ")); } } if (!isInt(configtoparse["contextSize"]) || configtoparse["contextSize"] <= 0) { configError("contextSize must be an integer > 0."); } if (!isInt(configtoparse["maxOutputSize"]) || configtoparse["maxOutputSize"] <= 0) { configError("maxOutputSize must be an integer > 0."); } if (!isInt(configtoparse["batchSize"]) || configtoparse["batchSize"] <= 0) { configError("batchSize 
must be an integer > 0."); } if (!isFloat(configtoparse["learningRate"]) || configtoparse["learningRate"] <= 0) { configError("learningRate must be a float > 0."); } if (!isBool(configtoparse["antiOverfittingOptimisations"])) { configError("antiOverfittingOptimisations must be a boolean."); } // For new model, check all architecture params if (isNew) { if (!isInt(configtoparse["embeddingSize"]) || configtoparse["embeddingSize"] <= 0) { configError("embeddingSize must be an integer > 0."); } if (!isInt(configtoparse["layersAmount"]) || configtoparse["layersAmount"] <= 0) { configError("layersAmount must be an integer > 0."); } if (!isInt(configtoparse["heads"]) || configtoparse["heads"] <= 0) { configError("heads must be an integer > 0."); } if (!isArray(configtoparse["biasesinitrange"]) || configtoparse["biasesinitrange"].length !== 2 || !isFloat(configtoparse["biasesinitrange"][0]) || !isFloat(configtoparse["biasesinitrange"][1]) || configtoparse["biasesinitrange"][0] >= configtoparse["biasesinitrange"][1]) { configError("biasesinitrange must be an array of two floats [min, max] with min < max."); } if (!isArray(configtoparse["embeddinginitrange"]) || configtoparse["embeddinginitrange"].length !== 2 || !isFloat(configtoparse["embeddinginitrange"][0]) || !isFloat(configtoparse["embeddinginitrange"][1]) || configtoparse["embeddinginitrange"][0] >= configtoparse["embeddinginitrange"][1]) { configError("embeddinginitrange must be an array of two floats [min, max] with min < max."); } } else { // For loaded model, these fields must NOT be present var forbidden = ["embeddingSize", "layersAmount", "heads", "biasesinitrange", "embeddinginitrange"]; for (var j = 0; j < forbidden.length; j++) { if (configtoparse.hasOwnProperty(forbidden[j])) { configError("Config for loaded model must not contain: " + forbidden[j]); } } } if (configtoparse.noSweetSpotSaving !== undefined && typeof configtoparse.noSweetSpotSaving !== "boolean") { configError("If specified, noSweetSpotSaving 
must be a boolean.") } config = configtoparse; console.log("Config file loaded successfully."); } var ndprint = function(...args) { console.log(...args); }; var print = function() { if (VERBOSE) { ndprint.apply(null, arguments); } }; function timer_() { // Generate random 32 char string as timer id using uuid v4 without dashes var timer_id = uuid.v4().replace(/-/g, ""); timers.push({ "id": timer_id, "start": Date.now() / 1000 }); return timer_id; } function timer_end(timer_id) { for (var i = 0; i < timers.length; i++) { if (timers[i]["id"] === timer_id) { timers[i]["end"] = Date.now() / 1000; // Return time in ms return (timers[i]["end"] - timers[i]["start"]) * 1000; } } return null; } function random_range(range) { // Return random float between range[0] and range[1] (inclusive) return Math.random() * (range[1] - range[0]) + range[0]; } class Transformer { constructor(newFlag, parameters, path, vocab_path) { if (vocab_path === undefined) { vocab_path = "vocabulary.json"; } this.adam_params = { 'beta1': 0.9, 'beta2': 0.98, // From 0.999 to 0.98 to match paper 'epsilon': 1e-9, 't': 0 }; ndprint("Trying to read vocabulary file..."); try { this.vocab = JSON.parse(fs.readFileSync(__dirname + "/vocabulary.json", "utf-8")); } catch (e) { console.log("Failed to read vocabulary file, creating error..."); throw new Error("Failed to read vocabulary file"); } ndprint("Successfully read vocabulary file"); ndprint("Computing lookup table..."); this.id_to_token = {}; for (var i = 0; i < this.vocab.length; i++) { var tok = this.vocab[i]; this.id_to_token[tok[1]] = tok[0]; } ndprint("Computed lookup table"); this.encoder = new Tiktoken( cl100k_base.bpe_ranks, { "<|endoftext|>": 100257 }, // example special token cl100k_base.pat_str ); this.temperature = 0.7; this.nan_checks_enabled = true; // Control logging easily this.nan_count_this_step = 0; this.nan_forward_pass_count_epoch = 0; this.nan_backprop_calc_count_epoch = 0; this.nan_final_gradient_count_epoch = 0; 
this.steps_with_nan_epoch = 0; if (newFlag) { ndprint("Initializing model..."); // Calculate total parameters var total_params = ( this.vocab.length * parameters["embeddingSize"] + parameters["contextSize"] * parameters["embeddingSize"] + parameters["layersAmount"] * ( 2 * parameters["embeddingSize"] + parameters["heads"] * (3 * parameters["embeddingSize"] * parameters["embeddingSize"] / parameters["heads"] + 3 * parameters["embeddingSize"] / parameters["heads"]) + parameters["embeddingSize"] * parameters["embeddingSize"] + parameters["embeddingSize"] + 2 * parameters["embeddingSize"] + parameters["embeddingSize"] * (4 * parameters["embeddingSize"]) + 4 * parameters["embeddingSize"] + (4 * parameters["embeddingSize"]) * parameters["embeddingSize"] + parameters["embeddingSize"] ) ); var total_ram = total_params * 4; // (32 bit floats take up 4 bytes each) ndprint("Model is of size " + total_params + " parameters"); total_ram = total_params * 4; ndprint(" ~" + (total_params / 1e9).toFixed(2) + "b parameters"); ndprint(""); var adam_ram = total_params * 3 * 4; // Assuming 3 times the parameters for Adam ndprint("Would cost the equivalent of " + (total_params * 3) + " parameters if trained with Adam"); ndprint(" ~" + ((total_params * 3) / 1e9).toFixed(2) + "b parameters if trained with adam"); ndprint(""); var sgd_momentum_ram = total_params * 2 * 4; // Assuming 2 times the parameters for SGD with momentum ndprint("Would cost the equivalent of " + (total_params * 2) + " parameters if trained with SGD with momentum"); ndprint(" ~" + ((total_params * 2) / 1e9).toFixed(2) + "b parameters if trained with SGD with momentum"); ndprint(""); ndprint("Would not cost more than the original size of the model if trained with vanilla SGD"); var sgtimer = timer_(); ndprint("Initializing parameters..."); var timer = timer_(); this.contextSize = parameters["contextSize"]; this.embeddingSize = parameters["embeddingSize"]; this.learningRate = parameters["learningRate"]; 
this.maxOutputSize = parameters["maxOutputSize"]; this.layersAmount = parameters["layersAmount"]; if ("use_he_init" in parameters && parameters["use_he_init"]) { this.weightsinitrange = this.he_init(this.embeddingSize); console.log("Using He initialization with range: " + this.weightsinitrange); } else { this.weightsinitrange = parameters["weightsinitrange"]; } this.biasesinitrange = parameters["biasesinitrange"]; this.heads = parameters["heads"]; this.embeddinginitrange = parameters["embeddinginitrange"]; this.transformer = {}; this.step_num = 0; ndprint("Initialized parameters in", timer_end(timer), "ms"); var percentagePrintInterval = 10; ndprint("Initializing layers..."); var gtimer = timer_(); this.transformer["layers"] = []; for (var i = 0; i < this.layersAmount; i++) { var timer_layer = timer_(); console.log("Initializing weights and biases for layer " + i); this.transformer["layers"].push({ "weights": { "normalize_1": sharedFloat32Array(this.embeddingSize * 3).fill(0), "attention": { "heads": (function() { var arr = []; for (var h = 0; h < this.heads; h++) { arr.push({ "query": sharedFloat32Array(this.embeddingSize * this.embeddingSize * 3).fill(0), "key": sharedFloat32Array(this.embeddingSize * this.embeddingSize * 3).fill(0), "value": sharedFloat32Array(this.embeddingSize * this.embeddingSize * 3).fill(0) }); } return arr; }).call(this), "output": sharedFloat32Array(this.embeddingSize * (this.embeddingSize * this.heads) * 3).fill(0) }, "normalize_2": sharedFloat32Array(this.embeddingSize * 3).fill(0), "feed_forward": { "grow": sharedFloat32Array(this.embeddingSize * (this.embeddingSize * 4) * 3).fill(0), "shrink": sharedFloat32Array((this.embeddingSize * 4) * this.embeddingSize * 3).fill(0) } }, "biases": { "normalize_1": sharedFloat32Array(this.embeddingSize * 3).fill(0), "attention": { "heads": (function() { var arr = []; for (var h = 0; h < this.heads; h++) { arr.push({ "query": sharedFloat32Array(this.embeddingSize * 3).fill(0), "key": 
sharedFloat32Array(this.embeddingSize * 3).fill(0), "value": sharedFloat32Array(this.embeddingSize * 3).fill(0) }); } return arr; }).call(this), "output": sharedFloat32Array(this.embeddingSize * 3).fill(0) }, "normalize_2": sharedFloat32Array(this.embeddingSize * 3).fill(0), "feed_forward": { "grow": sharedFloat32Array((this.embeddingSize * 4) * 3).fill(0), "shrink": sharedFloat32Array(this.embeddingSize * 3).fill(0) } } }); var total_params_layer = 2 * this.embeddingSize + 3 * this.heads * (this.embeddingSize * this.embeddingSize + this.embeddingSize) + this.embeddingSize * (this.embeddingSize * this.heads) + this.embeddingSize + 2 * this.embeddingSize + this.embeddingSize * (this.embeddingSize * 4) + (this.embeddingSize * 4) + (this.embeddingSize * 4) * this.embeddingSize + this.embeddingSize; var params_done = 0; var last_percent = -percentagePrintInterval; for (var j = 0; j < this.embeddingSize; j++) { this.transformer["layers"][i]["weights"]["normalize_1"][j * 3] = random_range(this.weightsinitrange); this.transformer["layers"][i]["biases"]["normalize_1"][j * 3] = random_range(this.biasesinitrange); this.transformer["layers"][i]["weights"]["normalize_2"][j * 3] = random_range(this.weightsinitrange); this.transformer["layers"][i]["biases"]["normalize_2"][j * 3] = random_range(this.biasesinitrange); params_done += 4; var percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_percent + "% complete"); } } for (var j = 0; j < this.heads; j++) { for (var k = 0; k < this.embeddingSize * this.embeddingSize; k++) { this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["query"][k * 3] = random_range(this.weightsinitrange); this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["key"][k * 3] = random_range(this.weightsinitrange); 
this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["value"][k * 3] = random_range(this.weightsinitrange); params_done += 3; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_percent + "% complete"); } } for (var k = 0; k < this.embeddingSize; k++) { this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["query"][k * 3] = random_range(this.biasesinitrange); this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["key"][k * 3] = random_range(this.biasesinitrange); this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["value"][k * 3] = random_range(this.biasesinitrange); params_done += 3; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_percent + "% complete"); } } } for (var j = 0; j < this.embeddingSize * (this.embeddingSize * this.heads); j++) { this.transformer["layers"][i]["weights"]["attention"]["output"][j * 3] = random_range(this.weightsinitrange); params_done += 1; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_percent + "% complete"); } } for (var j = 0; j < this.embeddingSize; j++) { this.transformer["layers"][i]["biases"]["attention"]["output"][j * 3] = random_range(this.biasesinitrange); params_done += 1; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + 
": " + last_percent + "% complete"); } } for (var j = 0; j < this.embeddingSize * (this.embeddingSize * 4); j++) { this.transformer["layers"][i]["weights"]["feed_forward"]["grow"][j * 3] = random_range(this.weightsinitrange); params_done += 1; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_percent + "% complete"); } } for (var j = 0; j < this.embeddingSize * 4; j++) { this.transformer["layers"][i]["biases"]["feed_forward"]["grow"][j * 3] = random_range(this.biasesinitrange); params_done += 1; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_percent + "% complete"); } } for (var j = 0; j < (this.embeddingSize * 4) * this.embeddingSize; j++) { this.transformer["layers"][i]["weights"]["feed_forward"]["shrink"][j * 3] = random_range(this.weightsinitrange); params_done += 1; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_percent + "% complete"); } } for (var j = 0; j < this.embeddingSize; j++) { this.transformer["layers"][i]["biases"]["feed_forward"]["shrink"][j * 3] = random_range(this.biasesinitrange); params_done += 1; percent = Math.floor((params_done * 100) / total_params_layer); if (percent >= last_percent + percentagePrintInterval) { last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval; ndprint(" Layer " + i + ": " + last_perce