/*
 * qminer
 * Version:
 * A C++-based data analytics platform for processing large-scale, real-time streams containing structured and unstructured data.
 * 185 lines (175 loc) • 8.15 kB
 * JavaScript
 */
var assert = require('../../src/nodejs/scripts/assert.js');
var qm = require('../../index.js');
var analytics = qm.analytics;
var la = qm.la;
// Unit tests for qm.analytics.LogReg (logistic regression): construction,
// parameter get/set and -- when qminer is built with BLAS -- fitting,
// prediction and a save/load round-trip.
describe('Logistic Regression Tests', function () {
    describe('Constructor Tests', function () {
        it('should not throw an exception, default constructor', function () {
            assert.doesNotThrow(function () {
                new analytics.LogReg();
            });
        });
        it('should create a default constructor', function () {
            var logreg = new analytics.LogReg();
            var param = logreg.getParams();
            // Defaults expected by this suite: lambda = 1, no intercept.
            assert.strictEqual(param.lambda, 1.0);
            assert.strictEqual(param.intercept, false);
        });
        it('should not throw an exception, given parameters', function () {
            assert.doesNotThrow(function () {
                new analytics.LogReg({ lambda: 10, intercept: true });
            });
        });
        it('should create a constructor out of the given parameters', function () {
            var logreg = new analytics.LogReg({ lambda: 10, intercept: true });
            var param = logreg.getParams();
            assert.strictEqual(param.lambda, 10.0);
            assert.strictEqual(param.intercept, true);
        });
        it('should throw an exception if the parameter is not a json object', function () {
            assert.throws(function () {
                new analytics.LogReg(10, false);
            });
        });
    });

    describe('GetParams Tests', function () {
        it('should not throw an exception, default constructor', function () {
            var logreg = new analytics.LogReg();
            assert.doesNotThrow(function () {
                logreg.getParams();
            });
        });
        it('should give the parameters of the default constructor', function () {
            var logreg = new analytics.LogReg();
            var param = logreg.getParams();
            assert.strictEqual(param.lambda, 1.0);
            assert.strictEqual(param.intercept, false);
        });
        it('should not throw an exception, given parameters', function () {
            var logreg = new analytics.LogReg({ lambda: 10, intercept: true });
            assert.doesNotThrow(function () {
                logreg.getParams();
            });
        });
        it('should return the parameters of the constructor created with the given parameters', function () {
            var logreg = new analytics.LogReg({ lambda: 10, intercept: true });
            var param = logreg.getParams();
            assert.strictEqual(param.lambda, 10);
            assert.strictEqual(param.intercept, true);
        });
        it('should return parameters that can be used for the constructor', function () {
            // getParams() output must round-trip through the constructor.
            var logreg = new analytics.LogReg({ lambda: 10, intercept: true });
            var param = logreg.getParams();
            var logreg2 = new analytics.LogReg(param);
            var param2 = logreg2.getParams();
            assert.strictEqual(param2.lambda, param.lambda);
            assert.strictEqual(param2.intercept, param.intercept);
        });
    });

    describe('SetParams Tests', function () {
        it('should not throw an exception', function () {
            var logreg = new analytics.LogReg();
            assert.doesNotThrow(function () {
                logreg.setParams({ lambda: 10, intercept: true });
            });
        });
        it('should set the parameters of the model', function () {
            var logreg = new analytics.LogReg();
            logreg.setParams({ lambda: 10, intercept: true });
            var param = logreg.getParams();
            assert.strictEqual(param.lambda, 10);
            assert.strictEqual(param.intercept, true);
        });
    });

    // fit/predict/save need qminer built with BLAS (e.g. OpenBLAS);
    // qm.flags.blas tells us whether it was, so skip these suites otherwise.
    if (qm.flags.blas) {
        // Returns a default LogReg fitted on the small 2x2 example shared by
        // the predict and serialization tests. The fit tests below expect
        // fit() to throw when the matrix column count differs from the label
        // vector length, i.e. columns are treated as examples.
        var fitSmallModel = function () {
            var logreg = new analytics.LogReg();
            var mat = new la.Matrix([[1, 1], [1, -1]]);
            var vec = new la.Vector([3, 3]);
            logreg.fit(mat, vec);
            return logreg;
        };

        describe('Fit Tests', function () {
            it('should not throw an exception', function () {
                var logreg = new analytics.LogReg();
                var mat = new la.Matrix([[1, 2], [1, -1]]);
                var vec = new la.Vector([3, 3]);
                assert.doesNotThrow(function () {
                    logreg.fit(mat, vec);
                });
            });
            it('should fit the model with the values', function () {
                var logreg = new analytics.LogReg();
                var mat = new la.Matrix([[1, 1], [1, -1]]);
                var vec = new la.Vector([3, 3]);
                logreg.fit(mat, vec);
                var weights = logreg.weights;
                assert(weights[0] > 0);
                // Weights are floating-point results: compare within
                // assert.eqtol's default tolerance instead of loose `==`.
                assert.eqtol(weights[1], 0);
            });
            it('should throw an exception if matrix has more columns than length of vector', function () {
                var logreg = new analytics.LogReg();
                var mat = new la.Matrix([[1, 1, 1], [1, -1, -2]]);
                var vec = new la.Vector([3, 3]);
                assert.throws(function () {
                    logreg.fit(mat, vec);
                });
            });
            it('should throw an exception if matrix has fewer columns than length of vector', function () {
                var logreg = new analytics.LogReg();
                var mat = new la.Matrix([[1], [1]]);
                var vec = new la.Vector([3, 3]);
                assert.throws(function () {
                    logreg.fit(mat, vec);
                });
            });
        });

        describe('Predict Tests', function () {
            it('should not throw an exception', function () {
                var logreg = fitSmallModel();
                var test = new la.Vector([1, 2]);
                assert.doesNotThrow(function () {
                    logreg.predict(test);
                });
            });
            it('should return the prediction', function () {
                var logreg = fitSmallModel();
                var test = new la.Vector([1, 3]);
                var prediction = logreg.predict(test);
                assert.eqtol(prediction, 1);
            });
            it('should throw an exception if the given vector is longer than model', function () {
                var logreg = fitSmallModel();
                var test = new la.Vector([1, 3, 2]);
                assert.throws(function () {
                    logreg.predict(test);
                });
            });
            it('should throw an exception if the given vector is shorter than model', function () {
                var logreg = fitSmallModel();
                var test = new la.Vector([1]);
                assert.throws(function () {
                    logreg.predict(test);
                });
            });
        });

        describe('Serialization Tests', function () {
            it('should serialize and deserialize', function () {
                // NOTE(review): writes this file relative to the current
                // working directory and does not remove it afterwards.
                var fileName = 'logreg_test.bin';
                var logreg = fitSmallModel();
                // save() returns the output stream, which is closed to flush it.
                logreg.save(qm.fs.openWrite(fileName)).close();
                var logreg2 = new analytics.LogReg(qm.fs.openRead(fileName));
                assert.deepEqual(logreg.getParams(), logreg2.getParams());
                assert.eqtol(logreg.weights.minus(logreg2.weights).norm(), 0, 1e-8);
            });
        });
    }
});