// webshap
// Version:
// Explain any ML models anywhere
// 87 lines (76 loc) • 2.71 kB
// text/typescript
import { describe, test, expect, beforeEach } from 'vitest';
import { lstsq } from '../../src/explainer/lstsq';
import math from '../../src/utils/math-import';
/**
* Fixture shared by the tests in this file; the `beforeEach` hook assigns
* every field before each test runs.
*/
interface LocalTestContext {
// Built from 50 values reshaped to [10, 5]; first argument to lstsq().
x: math.Matrix;
// Built from 10 values reshaped to [10, 1]; second argument to lstsq().
y: math.Matrix;
// Built from 10 values reshaped to [10, 1]; third argument to lstsq() in the
// 'lstsq() weight' test only (the other tests build their own weights).
w: math.Matrix;
}
/**
 * Rebuild the shared fixture before every test: `x` as a 10 × 5 matrix, and
 * `y` and `w` as 10 × 1 matrices, each created from a flat list of literals.
 */
beforeEach<LocalTestContext>(context => {
  // Wrap a flat list of numbers in a mathjs matrix and give it the target shape.
  const toShapedMatrix = (values: number[], shape: number[]) =>
    math.reshape(math.matrix(values), shape);

  // 50 values, written five per line to match the [10, 5] target shape.
  context.x = toShapedMatrix(
    [
      6.5883575, 9.47744716, 6.38128927, 1.32898554, 1.42451075,
      8.07888215, 1.37843823, 5.01048122, 5.91789496, 6.74303087,
      9.71877437, 2.68876159, 2.15260096, 0.67584554, 8.76868181,
      0.69520003, 3.27889612, 3.53383915, 0.39004321, 8.98698661,
      5.94193232, 5.61643321, 5.80469669, 9.02401475, 9.45376559,
      4.79489526, 1.03856608, 4.44841592, 5.27924568, 6.81838066,
      7.13956214, 6.20247773, 9.34432887, 9.58139555, 9.94923107,
      3.43236642, 2.67356961, 1.87084952, 7.85087359, 0.09172617,
      3.50793606, 1.1197608, 3.40025746, 8.02584101, 5.8887118,
      8.23636548, 1.32038954, 6.14270219, 0.13882923, 7.53383955
    ],
    [10, 5]
  );

  context.y = toShapedMatrix(
    [
      61.767112, 14.04187055, 30.88552799, 51.9603377, 29.28888286,
      86.85199145, 73.45540838, 46.08049002, 43.68663906, 48.23927507
    ],
    [10, 1]
  );

  context.w = toShapedMatrix(
    [
      0.14251289, 0.03092725, 0.07816275, 0.13553988, 0.02129188,
      0.13268995, 0.05887923, 0.15102889, 0.11071796, 0.13824932
    ],
    [10, 1]
  );
});
/**
 * Calls lstsq() with a 10 × 1 column of ones as the weights and checks the
 * returned coefficients against the reference values, element by element.
 */
test<LocalTestContext>('lstsq() without weight', ({ x, y }) => {
  const unitWeights = math.matrix(math.ones([10, 1]));
  const expected = math.reshape(
    math.matrix([-0.38057159, 1.35998903, 6.28966952, 0.75159811, 1.20430188]),
    [5, 1]
  );

  const actual = lstsq(x, y, unitWeights);

  // Guard against a vacuous pass: the loop below only visits entries that
  // exist in `actual`, so an empty or truncated result would otherwise slip
  // through without a single failing assertion.
  expect(actual.size()).toEqual(expected.size());

  actual.forEach((value, index) => {
    // Re-type the callback's index as a [row, column] pair so it can be
    // handed to get() on the expected matrix.
    const position = index as unknown as [number, number];
    // Precision argument 4: values must agree to four decimal places.
    expect(value).toBeCloseTo(expected.get(position) as number, 4);
  });
});
/**
 * Calls lstsq() with the weights supplied as a 10 × 10 diagonal matrix of
 * ones. The reference coefficients are the same as in the
 * 'lstsq() without weight' test.
 */
test<LocalTestContext>('lstsq() diagonal w', ({ x, y }) => {
  const diagonalWeights = math.matrix(math.diag(new Array(10).fill(1)));
  const expected = math.reshape(
    math.matrix([-0.38057159, 1.35998903, 6.28966952, 0.75159811, 1.20430188]),
    [5, 1]
  );

  const actual = lstsq(x, y, diagonalWeights);

  // Guard against a vacuous pass: the loop below only visits entries that
  // exist in `actual`, so an empty or truncated result would otherwise slip
  // through without a single failing assertion.
  expect(actual.size()).toEqual(expected.size());

  actual.forEach((value, index) => {
    // Re-type the callback's index as a [row, column] pair so it can be
    // handed to get() on the expected matrix.
    const position = index as unknown as [number, number];
    // Precision argument 4: values must agree to four decimal places.
    expect(value).toBeCloseTo(expected.get(position) as number, 4);
  });
});
/**
 * Calls lstsq() with the fixture's 10 × 1 weight column `w` and checks the
 * returned coefficients against the reference values, element by element.
 */
test<LocalTestContext>('lstsq() weight', ({ x, y, w }) => {
  const expected = math.reshape(
    math.matrix([-0.28239368, 1.83192929, 5.51801203, 2.44676581, 1.94097814]),
    [5, 1]
  );

  const actual = lstsq(x, y, w);

  // Guard against a vacuous pass: the loop below only visits entries that
  // exist in `actual`, so an empty or truncated result would otherwise slip
  // through without a single failing assertion.
  expect(actual.size()).toEqual(expected.size());

  actual.forEach((value, index) => {
    // Re-type the callback's index as a [row, column] pair so it can be
    // handed to get() on the expected matrix.
    const position = index as unknown as [number, number];
    // Precision argument 4: values must agree to four decimal places.
    expect(value).toBeCloseTo(expected.get(position) as number, 4);
  });
});