UNPKG

simplegrad

Version:

Simple reverse-mode automatic differentiation of scalar values in JavaScript

366 lines (301 loc) 7.93 kB
// Unit tests for simplegrad. Each test builds a small expression graph from
// Variables that share one ValueStorage, compiles it, sets input values, runs
// the forward pass, seeds the output gradient, runs the backward pass, and then
// checks the forward value and every input gradient against hand-derived
// derivatives (chain rule: input gradient = local derivative * output gradient).
import { test } from 'tap';
import { Variable, ReferenceVariable, ValueStorage } from '../index.js';

test('it can multiply stuff', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const b = new Variable(vs);
  const out = a.add(b).mul(a);
  out.compile();

  a.setValue(2);
  b.setValue(3);
  out.forward();
  t.equal(out.getValue(), 10);

  out.setGradient(1);
  out.backward();
  t.equal(out.getGradient(), 1);
  // out = (a + b) * a
  // d(out)/da = 2 * a + b
  t.equal(a.getGradient(), 2 * a.getValue() + b.getValue());
  // d(out)/db = a
  t.equal(b.getGradient(), a.getValue());
  t.end();
});

test('it can subtract variable', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const b = new Variable(vs);
  const out = a.sub(b);
  out.compile();

  a.setValue(2);
  b.setValue(3);
  out.forward();
  t.equal(out.getValue(), -1);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(a - b)/da = 1, d(a - b)/db = -1
  t.equal(a.getGradient(), 2);
  t.equal(b.getGradient(), -2);
  t.end();
});

test('it can subtract constant', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.sub(3);
  out.compile();

  a.setValue(2);
  out.forward();
  t.equal(out.getValue(), -1);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(a - 3)/da = 1
  t.equal(a.getGradient(), 2);
  t.end();
});

test('it can mul constant', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.mul(3);
  out.compile();

  a.setValue(2);
  out.forward();
  t.equal(out.getValue(), 6);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(3 * a)/da = 3
  t.equal(a.getGradient(), 2 * 3);
  t.end();
});

test('it can add constant', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.add(3);
  out.compile();

  a.setValue(2);
  out.forward();
  t.equal(out.getValue(), 5);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(a + 3)/da = 1
  t.equal(a.getGradient(), 2);
  t.end();
});

test('it can pow variable', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const b = new Variable(vs);
  const out = a.pow(b);
  out.compile();

  a.setValue(2);
  b.setValue(3);
  out.forward();
  t.equal(out.getValue(), 2 ** 3);

  const globalGradient = 2;
  out.setGradient(globalGradient);
  out.backward();
  t.equal(out.getGradient(), 2);
  // gradient of (a ^ b) by a is (b * a ^ (b - 1))
  t.equal(a.getGradient(), 3 * (2 ** (3 - 1)) * globalGradient);
  // gradient of (a ^ b) by b is (log(a) * a ^ b)
  t.equal(b.getGradient(), Math.log(2) * (2 ** 3) * globalGradient);
  t.end();
});

test('it can pow constant', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.pow(3);
  out.compile();

  a.setValue(2);
  out.forward();
  t.equal(out.getValue(), 2 ** 3);

  const globalGradient = 2;
  out.setGradient(globalGradient);
  out.backward();
  t.equal(out.getGradient(), 2);
  // gradient of (a ^ 3) by a is (3 * a ^ 2)
  t.equal(a.getGradient(), 3 * (2 ** (3 - 1)) * globalGradient);
  t.end();
});

test('it can divide by variable', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const b = new Variable(vs);
  const out = a.div(b);
  out.compile();

  a.setValue(2);
  b.setValue(3);
  out.forward();
  t.equal(out.getValue(), 2 / 3);

  const globalGradient = 2;
  out.setGradient(globalGradient);
  out.backward();
  t.equal(out.getGradient(), 2);
  // gradient of (a / b) by a is (1 / b)
  t.equal(a.getGradient(), 1 / 3 * globalGradient);
  // gradient of (a / b) by b is (-a / b ^ 2)
  t.equal(b.getGradient(), -2 / (3 * 3) * globalGradient);
  t.end();
});

test('it can divide by constant', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.div(3);
  out.compile();

  a.setValue(2);
  out.forward();
  t.equal(out.getValue(), 2 / 3);

  const globalGradient = 2;
  out.setGradient(globalGradient);
  out.backward();
  t.equal(out.getGradient(), 2);
  // gradient of (a / 3) by a is (1 / 3)
  t.equal(a.getGradient(), 1 / 3 * globalGradient);
  t.end();
});

test('it can get cosine', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.cos();
  out.compile();

  a.setValue(Math.PI / 2);
  out.forward();
  t.equal(out.getValue(), Math.cos(Math.PI / 2));

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(cos a)/da = -sin a
  t.equal(a.getGradient(), -2 * Math.sin(Math.PI / 2));
  t.end();
});

test('it can get sine', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.sin();
  out.compile();

  a.setValue(Math.PI / 2);
  out.forward();
  t.equal(out.getValue(), Math.sin(Math.PI / 2));

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(sin a)/da = cos a
  t.equal(a.getGradient(), 2 * Math.cos(Math.PI / 2));
  t.end();
});

test('it can get abs', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.abs();
  out.compile();

  a.setValue(-2);
  out.forward();
  t.equal(out.getValue(), 2);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d|a|/da = sign(a) for a != 0
  t.equal(a.getGradient(), 2 * Math.sign(-2));
  t.end();
});

test('it can get exp()', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.exp();
  out.compile();

  a.setValue(2);
  out.forward();
  t.equal(out.getValue(), Math.exp(2));

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(exp a)/da = exp a
  t.equal(a.getGradient(), 2 * Math.exp(2));
  t.end();
});

test('it can get ReLU()', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.ReLU();
  out.compile();

  a.setValue(-2);
  out.forward();
  t.equal(out.getValue(), 0);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // ReLU is flat for negative inputs, so no gradient flows back
  t.equal(a.getGradient(), 0);
  t.end();
});

test('it can get ELU', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.ELU();
  out.compile();

  a.setValue(-2);
  out.forward();
  // For a negative input this test expects exp(a) - 1 (i.e. alpha = 1)
  t.equal(out.getValue(), Math.exp(-2) - 1);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(exp(a) - 1)/da = exp(a)
  t.equal(a.getGradient(), Math.exp(-2) * 2);
  t.end();
});

test('it can get sigmoid', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.sigmoid();
  out.compile();

  a.setValue(-2);
  out.forward();
  // sigmoid(-2) = 1 / (1 + e^2)
  const expected = 1 / (1 + Math.exp(2));
  t.equal(out.getValue(), expected);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(sigmoid a)/da = sigmoid(a) * (1 - sigmoid(a))
  t.equal(a.getGradient(), 2 * expected * (1 - expected));
  t.end();
});

test('it can get tanh', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const out = a.tanh();
  out.compile();

  a.setValue(-2);
  out.forward();
  const expected = Math.tanh(-2);
  t.equal(out.getValue(), expected);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // d(tanh a)/da = 1 - tanh(a)^2
  t.equal(a.getGradient(), 2 * (1 - expected * expected));
  t.end();
});

test('it can reference another variable', (t) => {
  const vs = new ValueStorage();
  const a = new Variable(vs);
  const b = new Variable(vs);
  const aRef = new ReferenceVariable(vs);
  const out = b.mul(aRef);
  out.compile();

  // Point the reference at `a` after compilation; setting a value through the
  // reference is expected to be visible on `a` itself.
  aRef.setReference(a);
  aRef.setValue(2);
  b.setValue(3);
  t.equal(a.getValue(), 2);

  out.forward();
  t.equal(out.getValue(), 6);

  out.setGradient(2);
  out.backward();
  t.equal(out.getGradient(), 2);
  // Gradient accumulated through aRef is expected to land on `a`:
  // d(b * a)/da = b, d(b * a)/db = a
  t.equal(a.getGradient(), 3 * 2);
  t.equal(b.getGradient(), 2 * 2);
  t.end();
});