cleanai
Version:
A fully standalone, terminal-based AI model CLI for training and inference
1,050 lines (1,023 loc) • 141 kB
JavaScript
// app.js
//
//* This file is best read with the "Better Comments" extension for VS Code: it colours
//* each comment according to how important it is, based on the marker symbols
//* (such as *, ! and ?) that I put right after the // of a comment.
// Credits
//
// Me ofc (willmil11)
// Claude
// Chatgpt
//
// Contact me
// Discord: willmil11
// Email: willmil111012@gmail.com
//
// Drop a star on the github repo if you like it :D
var fs = require("fs");
var process = require("process");
var json = JSON;
var path = require("path");
var os = require("os");
var readline = require("readline");
var { Tiktoken } = require("tiktoken/lite");
var cl100k_base = require("tiktoken/encoders/cl100k_base.json");
var uuid = require("uuid");
// Command-line tokens with the first two process.argv entries removed.
var args = process.argv.slice(2);
// Active configuration object. Starts empty and is replaced by the parsed
// config file further down when --new is given.
var config = {};
/**
 * Ask a question on the terminal and wait for one line of input, giving up
 * once a deadline has passed.
 *
 * @param {string} prompt - Text written to stdout before reading.
 * @param {number} timeoutMs - Maximum time to wait, in milliseconds.
 * @returns {Promise<string|null>} The line that was entered, or null when the
 *     deadline passed first.
 */
function inputWithTimeout(prompt, timeoutMs) {
    return new Promise((resolve) => {
        const rl = readline.createInterface({
            input: process.stdin,
            output: process.stdout
        });
        let settled = false;

        // Shared exit path: whichever of "answer" and "deadline" arrives first
        // wins, the other one becomes a no-op.
        const finish = (value) => {
            if (settled) {
                return;
            }
            settled = true;
            clearTimeout(deadline);
            rl.close();
            resolve(value);
        };

        // Deadline: report null when nothing was entered in time.
        const deadline = setTimeout(() => finish(null), timeoutMs);

        // Ask the question
        rl.question(prompt, (answer) => finish(answer));
    });
}
/**
 * Print a framed notice followed by the tree of accepted command lines.
 *
 * @param {string} [issue] - Text shown inside the frame; "No args found." is
 *     used when the argument is omitted.
 */
function help(issue) {
    const notice = issue === undefined ? "No args found." : issue;
    const command = "node " + process.argv[1];
    const pad = (width) => " ".repeat(width);

    const frame = "=====" + "=".repeat(notice.length) + "=====";
    const output = [frame, "==== " + notice + " ====", frame, "", command + " --new"];

    // Column where the option that follows " --new" starts.
    const afterNew = command.length + " --new".length;
    const stageWidth = "--pretrain".length;
    const configVariants = ["--config path/to/config.json", "--config path/to/config.json --verbose"];
    for (const variant of configVariants) {
        const afterConfig = afterNew + variant.length;
        output.push(pad(afterNew) + variant);
        output.push(pad(afterConfig) + "--train");
        output.push(pad(afterConfig + stageWidth) + "--pretrain");
        output.push(pad(afterConfig) + "--pretrain");
        output.push(pad(afterConfig + stageWidth) + "--train");
    }

    output.push(pad(command.length) + " --load path/to/model.json");
    output.push("");

    for (const line of output) {
        console.log(line);
    }
}
// ---------------------------------------------------------------------------
// Command-line parsing
//
// Accepted shapes (see help()):
//   --new --config <file.json> [--verbose] (--train | --pretrain | both)
//   --load <model.json>
//
// Every usage problem prints help() with an explanation and exits. The exit
// status for usage problems is 0, as it has always been for this script.
// ---------------------------------------------------------------------------
var flag = null; // true once --new is seen, false once --load is seen, null until then
var VERBOSE = false;
var training__ = null; // becomes true when --train is given
var pretraining__ = null; // becomes true when --pretrain is given
var config__ = false; // becomes true when --config is given
var skipnext = false; // set when an option has consumed the following token as its operand
var config_location = null;
var model_location = null;

// Every token the parser treats as an option; none of them is accepted as a file operand.
var cliKnownFlags = ["--new", "--load", "--verbose", "--train", "--pretrain", "--config"];

// Print the usage text with an explanation, then stop the process.
function cliFail(message) {
    help(message);
    process.exit(0);
}

// Validate and return the file operand expected at args[position].
// `kind` is "Model" or "Config" (used in messages), `option` is the flag that needs the operand.
function cliPathOperand(position, kind, option) {
    var missing = "You need to specify a " + kind.toLowerCase() + " file after " + option + ".";
    var candidate = args[position];
    // Nothing after the option, or another option where the path should be.
    if (candidate === undefined || cliKnownFlags.indexOf(candidate) !== -1) {
        cliFail(missing);
    }
    var isRegularFile = false;
    try {
        if (!fs.existsSync(candidate)) {
            cliFail(kind + " file " + candidate + " does not exist.");
        }
        isRegularFile = fs.statSync(candidate).isFile();
    } catch (e) {
        // A file system error while inspecting the path is reported like a missing operand.
        cliFail(missing);
    }
    if (!isRegularFile) {
        cliFail(kind + " file " + candidate + " is not a file.");
    }
    if (candidate.slice(-5) !== ".json") {
        cliFail(kind + " file " + candidate + " is not a json file.");
    }
    return candidate;
}

if (args.length === 0) {
    help();
    process.exit(0);
}

for (var i = 0; i < args.length; i++) {
    var arg = args[i];
    if (skipnext) {
        // This token was already consumed as the operand of the previous option.
        skipnext = false;
        continue;
    }
    switch (arg) {
        case "--new":
            if (flag === true) {
                cliFail("You can't specify --new multiple times.");
            }
            if (flag === false) {
                cliFail("You can't specify --new and --load at the same time.");
            }
            flag = true;
            break;
        case "--load":
            if (flag === true) {
                cliFail("You can't specify --new and --load at the same time.");
            }
            if (flag === false) {
                cliFail("You can't specify --load multiple times.");
            }
            flag = false;
            model_location = cliPathOperand(i + 1, "Model", "--load");
            // Without this the model path would be parsed as an option on the
            // next iteration and rejected as "not recognised".
            skipnext = true;
            break;
        case "--verbose":
            if (VERBOSE === true) {
                cliFail("You can't specify --verbose multiple times.");
            }
            VERBOSE = true;
            break;
        case "--train":
            if (training__ === true) {
                cliFail("You can't specify --train multiple times.");
            }
            training__ = true;
            break;
        case "--pretrain":
            if (pretraining__ === true) {
                cliFail("You can't specify --pretrain multiple times.");
            }
            pretraining__ = true;
            break;
        case "--config":
            if (config__ === true) {
                cliFail("You can't specify --config multiple times.");
            }
            config__ = true;
            config_location = cliPathOperand(i + 1, "Config", "--config");
            skipnext = true;
            break;
        default:
            cliFail("Argument " + arg + " not recognised.");
    }
}

// Cross-option requirements.
if (flag === true) {
    if (!config__) {
        cliFail("You need to specify a config file with --config.");
    }
    if (training__ !== true && pretraining__ !== true) {
        cliFail("You need to specify either --train or --pretrain or both with --new.");
    }
}
if (flag === null) {
    cliFail("You need to specify either --new or --load.");
}
console.log("Arguments parsed successfully.");
// ---------------------------------------------------------------------------
// Config loading (only for --new).
//
// Reads the JSON file chosen with --config, validates it and publishes it as
// `config`. Any problem is reported on stdout and ends the process with
// status 1.
// ---------------------------------------------------------------------------
if (args.indexOf("--new") !== -1) {
    console.log("Reading and loading config file...");
    var configtoparse;
    try {
        configtoparse = fs.readFileSync(config_location, "utf-8");
    } catch (error) {
        console.log("Failed to read config file, check if it's corrupted or if you don't have permissions.");
        console.log("JavaScript error:");
        console.log(String(error));
        console.log("Exiting...");
        process.exit(1);
    }
    try {
        configtoparse = json.parse(configtoparse);
    } catch (error) {
        console.log("Failed to load json of config file, check if it's corrupted.");
        console.log("JavaScript error:");
        console.log(String(error));
        console.log("Exiting...");
        process.exit(1);
    }
    // JSON.parse also returns null, numbers, strings and arrays for valid
    // documents; the key lookups below need a plain object.
    if (configtoparse === null || typeof configtoparse !== "object" || Array.isArray(configtoparse)) {
        console.log("Config file must contain a json object at its top level.");
        console.log("Exiting...");
        process.exit(1);
    }

    // Report a wrongly typed parameter and stop.
    const rejectConfig = (message) => {
        console.log(message);
        process.exit(1);
    };
    // Report an absent mandatory parameter and stop.
    const missingConfigKey = (key) => {
        console.log("Config file missing parameter " + key + ", add it.");
        console.log("Exiting...");
        process.exit(1);
    };

    // Type checks; each receives the parameter name and reads it from configtoparse.
    const expectInt = (key) => {
        const value = configtoparse[key];
        if (typeof value !== "number" || Math.floor(value) !== value) {
            rejectConfig("Config file parameter " + key + " must be an int, not a " + typeof value);
        }
    };
    const expectFloat = (key) => {
        const value = configtoparse[key];
        if (typeof value !== "number") {
            rejectConfig("Config file parameter " + key + " must be a float, not a " + typeof value);
        }
    };
    const expectString = (key) => {
        const value = configtoparse[key];
        if (typeof value !== "string") {
            rejectConfig("Config file parameter " + key + " must be a string, not a " + typeof value);
        }
    };
    const expectBoolean = (key) => {
        const value = configtoparse[key];
        if (typeof value !== "boolean") {
            rejectConfig("Config file parameter " + key + " must be a boolean, not a " + typeof value);
        }
    };
    const expectStringArray = (key) => {
        const value = configtoparse[key];
        if (!Array.isArray(value)) {
            rejectConfig("Config file parameter " + key + " must be an array of strings, not a " + typeof value);
        }
        for (const item of value) {
            if (typeof item !== "string") {
                rejectConfig("Config file parameter " + key + " must be an array of strings, not an array of " + typeof item);
            }
        }
    };
    const expectRange = (key) => {
        const value = configtoparse[key];
        if (!Array.isArray(value)) {
            rejectConfig("Config file parameter " + key + " must be an array of two floats, not a " + typeof value);
        }
        if (value.length !== 2) {
            rejectConfig("Config file parameter " + key + " must be an array of two floats, not an array of " + value.length + " floats");
        }
        for (const bound of value) {
            if (typeof bound !== "number") {
                rejectConfig("Config file parameter " + key + " must be an array of two floats, not an array of " + typeof bound);
            }
        }
    };

    const wantsPretraining = Boolean(pretraining__);
    const wantsTraining = Boolean(training__);
    // One row per known parameter:
    //   applies  - whether the parameter matters for the selected mode(s)
    //   required - whether its absence is an error (optional ones are only
    //              type-checked when present)
    const configRules = [
        { key: "pre-training-paths", check: expectStringArray, applies: wantsPretraining, required: true },
        { key: "training-dataset-path", check: expectString, applies: wantsTraining, required: true },
        { key: "contextSize", check: expectInt, applies: true, required: true },
        { key: "embeddingSize", check: expectInt, applies: true, required: true },
        { key: "learningRate", check: expectFloat, applies: true, required: true },
        { key: "maxOutputSize", check: expectInt, applies: true, required: true },
        { key: "layersAmount", check: expectInt, applies: true, required: true },
        { key: "heads", check: expectInt, applies: true, required: true },
        { key: "biasesinitrange", check: expectRange, applies: true, required: true },
        { key: "embeddinginitrange", check: expectRange, applies: true, required: true },
        { key: "pre-train-epochs", check: expectInt, applies: wantsPretraining, required: false },
        { key: "train-epochs", check: expectInt, applies: wantsTraining, required: false },
        { key: "pre-train-optimizer", check: expectString, applies: wantsPretraining, required: false },
        { key: "train-optimizer", check: expectString, applies: wantsTraining, required: false },
        { key: "antiOverfittingOptimisations", check: expectBoolean, applies: true, required: false },
        { key: "microbatchSize", check: expectInt, applies: true, required: false }
    ];

    for (const rule of configRules) {
        if (!rule.applies) {
            continue;
        }
        if (!(rule.key in configtoparse)) {
            if (rule.required) {
                missingConfigKey(rule.key);
            }
            continue;
        }
        rule.check(rule.key);
    }

    // The whole parsed object becomes the configuration, including keys that
    // are not validated above.
    config = configtoparse;
    console.log("Config file loaded successfully.");
}
// Logger that always prints.
var ndprint = console.log;
// Logger that only prints when VERBOSE is truthy; arguments are passed on to ndprint unchanged.
var print = function (...messages) {
    if (!VERBOSE) {
        return;
    }
    ndprint(...messages);
};
/**
 * Start a stopwatch and register it in `timers`.
 *
 * @returns {string} Identifier of the new stopwatch: a v4 UUID with the dashes
 *     removed. Pass it to timer_end() to read the elapsed time.
 */
function timer_() {
    const timer_id = uuid.v4().split("-").join("");
    // Start time is stored in seconds.
    const record = { "id": timer_id, "start": Date.now() / 1000 };
    timers.push(record);
    return timer_id;
}
/**
 * Stop the stopwatch created by timer_() and report how long it ran.
 *
 * The end time is recorded on the matching entry of `timers`; the entry itself
 * stays in the list.
 *
 * @param {string} timer_id - Identifier returned by timer_().
 * @returns {number|null} Elapsed time in milliseconds, or null when no
 *     stopwatch has that identifier.
 */
function timer_end(timer_id) {
    const record = timers.find((candidate) => candidate["id"] === timer_id);
    if (record === undefined) {
        return null;
    }
    record["end"] = Date.now() / 1000;
    // Timestamps are kept in seconds; convert the difference to milliseconds.
    return (record["end"] - record["start"]) * 1000;
}
/**
 * Draw a uniformly distributed random float from a range.
 *
 * @param {number[]} range - Two-element array: [lower bound, upper bound].
 * @returns {number} A value starting at range[0] and scaled by Math.random(),
 *     so the lower bound can be returned but the upper bound is never reached.
 */
function random_range(range) {
    const width = range[1] - range[0];
    return Math.random() * width + range[0];
}
// Registry of stopwatches: timer_() appends { id, start } entries and
// timer_end() adds an "end" field to the matching entry.
var timers = [];
class Transformer {
constructor(newFlag, parameters, path, vocab_path) {
if (vocab_path === undefined) { vocab_path = "vocabulary.json"; }
this.adam_params = {
'beta1': 0.9,
'beta2': 0.98, // From 0.999 to 0.98 to match paper
'epsilon': 1e-9,
't': 0
};
ndprint("Trying to read vocabulary file...");
try {
this.vocab = json.parse(fs.readFileSync("vocabulary.json", "utf-8"));
} catch (e) {
console.log("Failed to read vocabulary file, creating error...");
throw new Error("Failed to read vocabulary file");
}
ndprint("Successfully read vocabulary file");
ndprint("Computing lookup table...");
this.id_to_token = {};
for (var i = 0; i < this.vocab.length; i++) {
var tok = this.vocab[i];
this.id_to_token[tok[1]] = tok[0];
}
ndprint("Computed lookup table");
this.encoder = new Tiktoken(
cl100k_base.bpe_ranks,
{ "<|endoftext|>": 100257 }, // example special token
cl100k_base.pat_str
);
this.temperature = 0.7;
if (newFlag) {
ndprint("Initializing model...");
// Calculate total parameters
var total_params = (
this.vocab.length * parameters["embeddingSize"] +
parameters["contextSize"] * parameters["embeddingSize"] +
parameters["layersAmount"] * (
2 * parameters["embeddingSize"] +
parameters["heads"] * (3 * parameters["embeddingSize"] * parameters["embeddingSize"] / parameters["heads"] + 3 * parameters["embeddingSize"] / parameters["heads"]) +
parameters["embeddingSize"] * parameters["embeddingSize"] + parameters["embeddingSize"] +
2 * parameters["embeddingSize"] +
parameters["embeddingSize"] * (4 * parameters["embeddingSize"]) + 4 * parameters["embeddingSize"] +
(4 * parameters["embeddingSize"]) * parameters["embeddingSize"] + parameters["embeddingSize"]
)
);
var total_ram = total_params * 8; // (64 bit floats take up 8 bytes each)
ndprint("Model is of size " + total_params + " parameters");
total_ram = total_params * 8;
ndprint(" ~" + (total_params / 1e9).toFixed(2) + "b parameters");
ndprint("");
var adam_ram = total_params * 3 * 8; // Assuming 3 times the parameters for Adam
ndprint("Would cost the equivalent of " + (total_params * 3) + " parameters if trained with Adam");
ndprint(" ~" + ((total_params * 3) / 1e9).toFixed(2) + "b parameters if trained with adam");
ndprint("");
var sgd_momentum_ram = total_params * 2 * 8; // Assuming 2 times the parameters for SGD with momentum
ndprint("Would cost the equivalent of " + (total_params * 2) + " parameters if trained with SGD with momentum");
ndprint(" ~" + ((total_params * 2) / 1e9).toFixed(2) + "b parameters if trained with SGD with momentum");
ndprint("");
ndprint("Would not cost more than the original size of the model if trained with vanilla SGD");
var sgtimer = timer_();
ndprint("Initializing parameters...");
var timer = timer_();
this.contextSize = parameters["contextSize"];
this.embeddingSize = parameters["embeddingSize"];
this.learningRate = parameters["learningRate"];
this.maxOutputSize = parameters["maxOutputSize"];
this.layersAmount = parameters["layersAmount"];
if ("use_he_init" in parameters && parameters["use_he_init"]) {
this.weightsinitrange = this.he_init(this.embeddingSize);
console.log("Using He initialization with range: " + this.weightsinitrange);
} else {
this.weightsinitrange = parameters["weightsinitrange"];
}
this.biasesinitrange = parameters["biasesinitrange"];
this.heads = parameters["heads"];
this.embeddinginitrange = parameters["embeddinginitrange"];
this.transformer = {};
this.step_num = 0;
ndprint("Initialized parameters in", timer_end(timer), "ms");
var percentagePrintInterval = 10;
ndprint("Initializing layers...");
var gtimer = timer_();
this.transformer["layers"] = [];
for (var i = 0; i < this.layersAmount; i++) {
var timer_layer = timer_();
console.log("Initializing weights and biases for layer " + i);
this.transformer["layers"].push({
"weights": {
"normalize_1": [],
"attention": {
"heads": (function() {
var arr = [];
for (var h = 0; h < this.heads; h++) {
arr.push({
"query": [],
"key": [],
"value": []
});
}
return arr;
}).call(this),
"output": []
},
"normalize_2": [],
"feed_forward": {
"grow": [],
"shrink": []
}
},
"biases": {
"normalize_1": [],
"attention": {
"heads": (function() {
var arr = [];
for (var h = 0; h < this.heads; h++) {
arr.push({
"query": [],
"key": [],
"value": []
});
}
return arr;
}).call(this),
"output": []
},
"normalize_2": [],
"feed_forward": {
"grow": [],
"shrink": []
}
}
});
var total_params_layer = 2 * this.embeddingSize + 3 * this.heads * (this.embeddingSize * this.embeddingSize + this.embeddingSize) + this.embeddingSize * (this.embeddingSize * this.heads) + this.embeddingSize + 2 * this.embeddingSize + this.embeddingSize * (this.embeddingSize * 4) + (this.embeddingSize * 4) + (this.embeddingSize * 4) * this.embeddingSize + this.embeddingSize;
var params_done = 0;
var last_percent = -percentagePrintInterval;
for (var j = 0; j < this.embeddingSize; j++) {
this.transformer["layers"][i]["weights"]["normalize_1"].push([random_range(this.weightsinitrange), 0, 0]);
this.transformer["layers"][i]["biases"]["normalize_1"].push([random_range(this.biasesinitrange), 0, 0]);
this.transformer["layers"][i]["weights"]["normalize_2"].push([random_range(this.weightsinitrange), 0, 0]);
this.transformer["layers"][i]["biases"]["normalize_2"].push([random_range(this.biasesinitrange), 0, 0]);
params_done += 4;
var percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.heads; j++) {
for (var k = 0; k < this.embeddingSize * this.embeddingSize; k++) {
this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["query"].push([random_range(this.weightsinitrange), 0, 0]);
this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["key"].push([random_range(this.weightsinitrange), 0, 0]);
this.transformer["layers"][i]["weights"]["attention"]["heads"][j]["value"].push([random_range(this.weightsinitrange), 0, 0]);
params_done += 3;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var k = 0; k < this.embeddingSize; k++) {
this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["query"].push([random_range(this.biasesinitrange), 0, 0]);
this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["key"].push([random_range(this.biasesinitrange), 0, 0]);
this.transformer["layers"][i]["biases"]["attention"]["heads"][j]["value"].push([random_range(this.biasesinitrange), 0, 0]);
params_done += 3;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
}
for (var j = 0; j < this.embeddingSize * (this.embeddingSize * this.heads); j++) {
this.transformer["layers"][i]["weights"]["attention"]["output"].push([random_range(this.weightsinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize; j++) {
this.transformer["layers"][i]["biases"]["attention"]["output"].push([random_range(this.biasesinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize * (this.embeddingSize * 4); j++) {
this.transformer["layers"][i]["weights"]["feed_forward"]["grow"].push([random_range(this.weightsinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize * 4; j++) {
this.transformer["layers"][i]["biases"]["feed_forward"]["grow"].push([random_range(this.biasesinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < (this.embeddingSize * 4) * this.embeddingSize; j++) {
this.transformer["layers"][i]["weights"]["feed_forward"]["shrink"].push([random_range(this.weightsinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
for (var j = 0; j < this.embeddingSize; j++) {
this.transformer["layers"][i]["biases"]["feed_forward"]["shrink"].push([random_range(this.biasesinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Layer " + i + ": " + last_percent + "% complete");
}
}
console.log("Initialized weights and biases for layer " + i + " in " + timer_end(timer_layer) + " ms");
}
ndprint("Initialized layers in", timer_end(gtimer), "ms");
ndprint("Initializing embeddings...");
timer = timer_();
this.transformer["embeddings"] = [];
params_done = 0;
total_params_layer = this.vocab.length * this.embeddingSize;
last_percent = -percentagePrintInterval;
for (var i = 0; i < this.vocab.length; i++) {
var embedding = [];
for (var j = 0; j < this.embeddingSize; j++) {
embedding.push([random_range(this.embeddinginitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Embeddings: " + last_percent + "% complete");
}
}
this.transformer["embeddings"].push(embedding);
}
ndprint("Initialized embeddings in", timer_end(timer), "ms");
ndprint("Initializing vocabulary projection weights and biases...");
timer = timer_();
this.transformer["vocab_projection"] = {
"weights": [],
"biases": []
};
params_done = 0;
total_params_layer = this.vocab.length * this.embeddingSize + this.vocab.length;
last_percent = -percentagePrintInterval;
for (var i = 0; i < this.vocab.length * this.embeddingSize; i++) {
this.transformer["vocab_projection"]["weights"].push([random_range(this.weightsinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Vocab projection: " + last_percent + "% complete");
}
}
for (var i = 0; i < this.vocab.length; i++) {
this.transformer["vocab_projection"]["biases"].push([random_range(this.biasesinitrange), 0, 0]);
params_done += 1;
percent = Math.floor((params_done * 100) / total_params_layer);
if (percent >= last_percent + percentagePrintInterval) {
last_percent = Math.floor(percent / percentagePrintInterval) * percentagePrintInterval;
ndprint(" Vocab projection: " + last_percent + "% complete");
}
}
ndprint("Initialized vocabulary projection weights and biases in", timer_end(timer), "ms");
ndprint("Successfully initialized model in", timer_end(sgtimer), "ms");
} else {
ndprint("Reading model from file...");
var timer = timer_();
try {
var model = json.parse(fs.readFileSync(path, "utf-8"));
this.transformer = model["transformer"];
this.contextSize = model["contextSize"];
this.embeddingSize = model["embeddingSize"];
this.learningRate = model["learningRate"];
this.maxOutputSize = model["maxOutputSize"];
this.layersAmount = model["layersAmount"];
this.weightsinitrange = model["weightsinitrange"];
this.biasesinitrange = model["biasesinitrange"];
this.heads = model["heads"];
this.embeddinginitrange = model["embeddinginitrange"];
if ("adam_params" in model) {
this.adam_params = model["adam_params"];
} else {
this.adam_params = {
'beta1': 0.9,
'beta2': 0.98,
'epsilon': 1e-9,
't': 0
};
}
if ("step_num" in model) {
this.step_num = model["step_num"];
} else {
this.step_num = 0;
}
} catch (e) {
console.log("Failed to read model file, creating error...");
throw new Error("Failed to read model file");
}
ndprint("Successfully read model from file in", timer_end(timer), "ms");
}
}
he_init(fan_in) {
var scale = Math.sqrt(2.0 / fan_in);
return [-scale, scale];
}
tokenize(text) {
var timer_id = timer_();
console.log("Tokenizing text...");
// Strip out the literal endoftext token string — it breaks tiktoken in Node.js
text = text.replace(/<\|endoftext\|>/g, "");
var token_ids = this.encoder.encode(text, {
allowed_special: [],
encode_special_tokens: false
});
var result = [];
for (var i = 0; i < token_ids.length; i++) {
var id = token_ids[i];
var token_str = (id in this.id_to_token) ? this.id_to_token[id] : "unknown";
result.push([token_str, id]);
}
console.log("Tokenized in " + timer_end(timer_id) + " ms");
return result;
}
calculate_positional_encoding(sequence_length) {
var positional_encodings = [];
for (var pos = 0; pos < sequence_length; pos++) {
var embedding = [];
for (var i = 0; i < this.embeddingSize; i++) {
var denominator = Math.pow(10000, (2 * Math.floor(i / 2)) / this.embeddingSize);
if (i % 2 === 0) {
embedding.push(Math.sin(pos / denominator));
} else {
embedding.push(Math.cos(pos / denominator));
}
}
positional_encodings.push(embedding);
}
return positional_encodings;
}
get_embedding(token_id) {
var vocab_idx = null;
for (var i = 0; i < this.vocab.length; i++) {
if (this.vocab[i][1] === token_id) {
vocab_idx = i;
break;
}
}
if (vocab_idx !== null) {
return this.transformer["embeddings"][vocab_idx];
} else {
var unknown_idx = null;
for (var i = 0; i < this.vocab.length; i++) {
if (this.vocab[i][0] === "unknown" && this.vocab[i][1] === 16476) {
unknown_idx = i;
break;
}
}
if (unknown_idx !== null) {
console.log("Warning: Token ID " + token_id + " not found in vocabulary, using unknown token instead");
return this.transformer["embeddings"][unknown_idx];
} else {
console.log("Warning: Token ID " + token_id + " not found in vocabulary, using first token as fallback");
return this.transformer["embeddings"][0];
}
}
}
normalize_vector(vector) {
var vector_list = [];
for (var i = 0; i < vector.length; i++) {
var x = vector[i];
try {
if (typeof x === "number") {
vector_list.push(Number(x));
} else if (Array.isArray(x) && x.length > 0) {
vector_list.push(Number(x[0]));
} else {
vector_list.push(Number(x));
}
} catch (e) {
vector_list.push(0.0);
}
}
var mean = vector_list.reduce(function(a, b) { return a + b; }, 0) / vector_list.length;
var squared_diffs = [];
for (var i = 0; i < vector_list.length; i++) {
var diff = vector_list[i] - mean;
if (diff > 1e6) { diff = 1e6; }
else if (diff < -1e6) { diff = -1e6; }
squared_diffs.push(diff * diff);
}
var variance = squared_diffs.reduce(function(a, b) { return a + b; }, 0) / vector_list.length;
var std = Math.sqrt(variance + 1e-10);
if (std < 1e-6) {
var zeros = [];
for (var i = 0; i < vector_list.length; i++) { zeros.push(0.0); }
return zeros;
}
var normalized = [];
for (var i = 0; i < vector_list.length; i++) {
var norm_val = (vector_list[i] - mean) / std;
if (norm_val > 10.0) { norm_val = 10.0; }
else if (norm_val < -10.0) { norm_val = -10.0; }
normalized.push(norm_val);
}
return normalized;
}
dot_product(vec1, vec2) {
var sum = 0;
for (var i = 0; i < vec1.length; i++) {
sum += vec1[i] * vec2[i];
}
return sum;
}
add_vectors(vec1, vec2) {
var result = [];
for (var i = 0; i < vec1.length; i++) {
result.push(vec1[i] + vec2[i]);
}
return result;
}
softmax(scores) {
var float_scores = [];
for (var i = 0; i < scores.length; i++) {
try {
float_scores.push(Number(scores[i]));
} catch (e) {
float_scores.push(0.0);
}
}
var max_score = Math.max.apply(null, float_scores);
var exp_scores = [];
for (var i = 0; i < float_scores.length; i++) {
exp_scores.push(Math.exp(float_scores[i] - max_score));
}
var sum_exp = exp_scores.reduce(function(a, b) { return a + b; }, 0);
if (sum_exp === 0) {
var equal = [];
for (var i = 0; i < float_scores.length; i++) { equal.push(1.0 / float_scores.length); }
return equal;
}
var probs = [];
for (var i = 0; i < exp_scores.length; i++) {
probs.push(exp_scores[i] / sum_exp);
}
return probs;
}
save(path_out) {
if (path_out === undefined) { path_out = "model.json"; }
var transformer_obj = {};
transformer_obj["contextSize"] = this.contextSize;
transformer_obj["embeddingSize"] = this.embeddingSize;
transformer_obj["learningRate"] = this.learningRate;
transformer_obj["maxOutputSize"] = this.maxOutputSize;
transformer_obj["layersAmount"] = this.layersAmount;
transformer_obj["heads"] = this.heads;
transformer_obj["weightsinitrange"] = this.weightsinitrange;
transformer_obj["biasesinitrange"] = this.biasesinitrange;
transformer_obj["embeddinginitrange"] = this.embeddinginitrange;
transformer_obj["vocab"] = this.vocab;
transformer_obj["transformer"] = this.transformer;
transformer_obj["adam_params"] = {
'beta1': this.adam_params['beta1'],
'beta2': this.adam_params['beta2'],
'epsilon': this.adam_params['epsilon'],
't': this.adam_params['t']
};
transformer_obj["step_num"] = this.step_num;
fs.writeFileSync(path_out, json.stringify(transformer_obj));
ndprint("Model saved to", path_out);
}
calculate_loss(predicted_scores, target_token_id) {
var predicted_probs = this.softmax(predicted_scores);
var epsilon = 0;
if (config["antiOverfittingOptimisations"]) {
epsilon = 0.1;
}
var vocab_size = this.vocab.length;
var target_distribution = [];
for (var i = 0; i < vocab_size; i++) {
target_distribution.push(epsilon / (vocab_size - 1));
}
var target_idx = null;
for (var i = 0; i < this.vocab.length; i++) {
if (this.vocab[i][1] === target_token_id) {
target_idx = i;
break;
}
}
if (target_idx === null) {
console.log("Warning: Token ID " + target_token_id + " not found in vocabulary");
target_idx = 0;
}
target_distribution[target_idx] = 1.0 - epsilon;
var loss = 0;
for (var i = 0; i < vocab_size; i++) {
if (predicted_probs[i] > 0) {
loss -= target_distribution[i] * Math.log(predicted_probs[i]);
}
}
return loss;
}
initialize_zero_gradients(structure) {
if (Array.isArray(structure)) {
if (structure.length > 0 && Array.isArray(structure[0]) && structure[0].length === 3) {
var newArr = [];
for (var i = 0; i < structure.length; i++) {
newArr.push([0, 0, 0]);
}
return newArr;
}
return [0, 0, 0];
} else if (typeof structure === "object") {
var zero_dict = {};
for (var key in structure) {
if (key === "heads") {
var arr = [];
for (var h = 0; h < this.heads; h++) {
arr.push({
"query": (function() {
var inner = [];
for (var i = 0; i < this.embeddingSize * this.embeddingSize; i++) { inner.push([0, 0, 0]); }
return inner;
}).call(this),
"key": (function() {
var inner = [];
for (var i = 0; i < this.embeddingSize * this.embeddingSize; i++) { inner.push([0, 0, 0]); }
return inner;
}).call(this),
"value": (function() {
var inner = [];
for (var i = 0; i < this.embeddingSize * this.embeddingSize; i++) { inner.push([0, 0, 0]); }
return inner;
}).call(this)
});
}
zero_dict[key] = arr;
} else {
zero_dict[key] = this.initialize_zero_gradients(structure[key]);
}
}
return zero_dict;
}
return [0, 0, 0];
}
add_in_place(target, source) {
const targetType = typeof target;
const sourceType = type