UNPKG

qminer

Version:

A C++ based data analytics platform for processing large-scale real-time streams containing structured and unstructured data

271 lines (251 loc) 10.7 kB
/**
 * Copyright (c) 2015, Jozef Stefan Institute, Quintelligence d.o.o. and contributors
 * All rights reserved.
 *
 * This source code is licensed under the FreeBSD license found in the
 * LICENSE file in the root directory of this source tree.
 */

var assert = require('../../src/nodejs/scripts/assert.js');
var analytics = require('../../index.js').analytics;
var la = require('../../index.js').la;
var fs = require('../../index.js').fs;

/**
 * Mocha suite for analytics.RidgeReg: construction, getParams/setParams,
 * fit, getModel, predict and save/load round-tripping.
 */
describe('Ridge Regression Tests', function () {

    // Weights expected after fitting the reference data with gamma = 0 and gamma = 1.
    var UNREGULARIZED_WEIGHTS = [2, 1];
    var GAMMA_ONE_WEIGHTS = [1.588235294117647, 0.529411764705883];

    /**
     * Builds the 2x2 matrix most tests train on.
     * NOTE(review): the expected weights [2, 1] only solve the system when the
     * columns (1, 1) and (2, -1) are read as the training examples - confirm
     * against the RidgeReg.fit documentation.
     * @returns {la.Matrix} a fresh matrix, so tests never share state.
     */
    function referenceMatrix() {
        return new la.Matrix([[1, 2], [1, -1]]);
    }

    /**
     * Builds the target vector that goes with referenceMatrix().
     * @returns {la.Vector} a fresh vector [3, 3].
     */
    function referenceTargets() {
        return new la.Vector([3, 3]);
    }

    /**
     * Creates a model and fits it to the reference data.
     * @param {Object} [params] - constructor parameters; when omitted the
     *     constructor is called without arguments, exactly like the default case.
     * @returns {analytics.RidgeReg} the fitted model.
     */
    function fitReferenceModel(params) {
        var model = params === undefined ? new analytics.RidgeReg() : new analytics.RidgeReg(params);
        model.fit(referenceMatrix(), referenceTargets());
        return model;
    }

    /**
     * Asserts that a weight vector has the expected length and entries.
     * The length is compared exactly, the entries with assert.eqtol's default tolerance.
     * @param {la.Vector} weights - weights read from the model.
     * @param {number[]} expected - expected entries, in order.
     */
    function assertWeights(weights, expected) {
        assert.strictEqual(weights.length, expected.length);
        for (var i = 0; i < expected.length; i++) {
            assert.eqtol(weights[i], expected[i]);
        }
    }

    describe('Constructor Tests', function () {
        it('should not throw an exception, empty constructor', function () {
            assert.doesNotThrow(function () {
                new analytics.RidgeReg();
            });
        });
        it('should create a new default ridge regression model', function () {
            var RR = new analytics.RidgeReg();
            assert.strictEqual(RR.getParams().gamma, 0);
        });
        it('should not throw an exception, json parameter', function () {
            assert.doesNotThrow(function () {
                new analytics.RidgeReg({ gamma: 5 });
            });
        });
        it('should create a new ridge regression, json parameter', function () {
            var RR = new analytics.RidgeReg({ gamma: 5 });
            assert.strictEqual(RR.getParams().gamma, 5);
        });
        it('should create a default ridge regression if json doesn\'t have the gamma key', function () {
            var RR = new analytics.RidgeReg({ alpha: 5 });
            assert.strictEqual(RR.getParams().gamma, 0);
        });
        it('should throw an exception if the parameter is not valid', function () {
            assert.throws(function () {
                new analytics.RidgeReg('gamma');
            });
        });
    });

    describe('GetParams Tests', function () {
        it('should not throw an exception', function () {
            var RR = new analytics.RidgeReg();
            assert.doesNotThrow(function () {
                RR.getParams();
            });
        });
        it('should return the parameters of the model', function () {
            var RR = new analytics.RidgeReg();
            assert.strictEqual(RR.getParams().gamma, 0);
        });
        it('should throw an exception if parameters are given', function () {
            var RR = new analytics.RidgeReg();
            assert.throws(function () {
                RR.getParams(5);
            });
        });
        it('should be used for construction of a new Ridge Regression model', function () {
            var RR = new analytics.RidgeReg({ gamma: 4 });
            // The object returned by getParams must be accepted by the constructor.
            var otherRR = new analytics.RidgeReg(RR.getParams());
            assert.eqtol(otherRR.getParams().gamma, 4);
        });
    });

    describe('SetParams Tests', function () {
        it('should not throw an exception, Json parameter', function () {
            var RR = new analytics.RidgeReg();
            assert.doesNotThrow(function () {
                RR.setParams({ gamma: 5 });
            });
        });
        it('should set the parameter gamma to 5, Json parameter', function () {
            var RR = new analytics.RidgeReg();
            RR.setParams({ gamma: 5 });
            assert.strictEqual(RR.getParams().gamma, 5);
        });
        it('should throw an exception if the parameter is not valid', function () {
            var RR = new analytics.RidgeReg();
            assert.throws(function () {
                RR.setParams("something");
            });
        });
        it('does not changes the model', function () {
            var RR = fitReferenceModel();
            assertWeights(RR.weights, UNREGULARIZED_WEIGHTS);

            // Changing gamma after fitting must leave the fitted weights untouched.
            RR.setParams({ gamma: 5 });
            assertWeights(RR.weights, UNREGULARIZED_WEIGHTS);
        });
    });

    describe('Fit Tests', function () {
        it('should not throw an exception', function () {
            var RR = new analytics.RidgeReg();
            // 5x2 matrix paired with a vector of length 2.
            var A = new la.Matrix([[1, -2], [2, 0], [3, 0], [0, 2], [1, 1]]);
            var b = new la.Vector([3, 2]);
            assert.doesNotThrow(function () {
                RR.fit(A, b);
            });
        });
        it('should fit the model to the test values, default parameters', function () {
            var RR = new analytics.RidgeReg();
            RR.fit(referenceMatrix(), referenceTargets());
            assertWeights(RR.weights, UNREGULARIZED_WEIGHTS);
        });
        it('should fit the model to the test values, Json parameters', function () {
            var RR = new analytics.RidgeReg({ gamma: 1 });
            RR.fit(referenceMatrix(), referenceTargets());
            assertWeights(RR.weights, GAMMA_ONE_WEIGHTS);
        });
        it('should throw an exception if the number of columns and the length of the vector are not equal', function () {
            var RR = new analytics.RidgeReg({ gamma: 1 });
            // Three columns, but only two target values.
            var A = new la.Matrix([[1, 2, 3], [1, -1, 2]]);
            var b = new la.Vector([3, 3]);
            assert.throws(function () {
                RR.fit(A, b);
            });
        });
        it('should forget the previous models', function () {
            var RR = fitReferenceModel();
            assertWeights(RR.weights, UNREGULARIZED_WEIGHTS);

            // A second fit on different data must fully replace the first model.
            var A2 = new la.Matrix([[1, 1], [2, -1]]);
            var b2 = new la.Vector([3, 3]);
            RR.fit(A2, b2);
            assertWeights(RR.weights, [3, 0]);
        });
    });

    describe('GetModel Tests', function () {
        it('should not throw an exception', function () {
            var RR = fitReferenceModel();
            assert.doesNotThrow(function () {
                RR.getModel();
            });
        });
        it('should return the model, no fit previously made', function () {
            var RR = new analytics.RidgeReg();
            // Without a fit the model is expected to expose an empty weight vector.
            assertWeights(RR.getModel().weights, []);
        });
        it('should return the model, default parameters', function () {
            var RR = fitReferenceModel();
            assertWeights(RR.getModel().weights, UNREGULARIZED_WEIGHTS);
        });
        it('should return the model, Json parameters', function () {
            var RR = fitReferenceModel({ gamma: 1 });
            assertWeights(RR.getModel().weights, GAMMA_ONE_WEIGHTS);
        });
    });

    describe('Predict Tests', function () {
        it('should not throw an exception', function () {
            var RR = fitReferenceModel();
            var vec = new la.Vector([3, 4]);
            assert.doesNotThrow(function () {
                RR.predict(vec);
            });
        });
        it('should return the prediction of the vector [3, 4]', function () {
            var RR = fitReferenceModel();
            var vec = new la.Vector([3, 4]);
            // With weights [2, 1]: 2 * 3 + 1 * 4 = 10.
            assert.eqtol(RR.predict(vec), 10);
        });
        it('should throw an exception if the vector is longer than the model', function () {
            var RR = fitReferenceModel();
            var vec = new la.Vector([3, 4, 5]);
            assert.throws(function () {
                RR.predict(vec);
            });
        });
        it('should throw an exception if the vector is shorter than the model', function () {
            var RR = fitReferenceModel();
            var vec = new la.Vector([3]);
            assert.throws(function () {
                RR.predict(vec);
            });
        });
    });

    describe('Serialization Tests', function () {
        it('should serialize and deserialize', function () {
            var fileName = 'ridgereg_test.bin';
            var RR = fitReferenceModel();

            // Write the model out, close the output stream, then rebuild from the file.
            RR.save(fs.openWrite(fileName)).close();
            var RR2 = new analytics.RidgeReg(fs.openRead(fileName));

            assert.deepEqual(RR.getParams(), RR2.getParams());
            assert.eqtol(RR.weights.minus(RR2.weights).norm(), 0, 1e-8);
        });
    });
});