/*
 * webshap
 * Version:
 * Explain any ML models anywhere
 * 222 lines (195 loc) • 7.64 kB
 * text/typescript
 */
import { describe, test, expect, beforeEach } from 'vitest';
import { KernelSHAP, IrisLinearMultiClass } from '../../src/index';
import math from '../../src/utils/math-import';
/**
 * Seed passed as the third argument to every KernelSHAP instance in this
 * file; presumably it fixes the random coalition sampling so the hard-coded
 * expected values below are reproducible.
 */
const SEED = 0.20071022;

/** Per-test context that the `beforeEach` fixture fills in. */
interface LocalTestContext {
  /** Background dataset handed to the KernelSHAP constructor. */
  data: number[][];
  /** Async prediction function: a batch of input rows in, prediction rows out. */
  model: (inputs: number[][]) => Promise<number[][]>;
}
/**
 * Initialize the fixture for all tests: an IrisLinearMultiClass model exposed
 * as an async prediction function, plus the background data handed to
 * KernelSHAP.
 */
beforeEach<LocalTestContext>(context => {
  // Coefficient matrix and intercepts for the linear model (presumably one
  // row / intercept per class).
  // NOTE(review): some coefficient and data rows are empty in this copy of the
  // file; the numeric values appear to have been lost — restore before relying
  // on these tests.
  const coefs = [
    [-0.39228899, 0.85674351, -2.23337115, -0.98440396],
    [],
    [-0.16645916, -0.60427374, 2.34695022, 1.88953615]
  ];
  const intercepts = [8.85065312, 1.21877625, -10.06942938];
  const model = new IrisLinearMultiClass(coefs, intercepts);

  // KernelSHAP takes an async model. An `async` arrow already returns a
  // promise and turns a synchronous throw from predictProba into a rejection,
  // so a hand-built `new Promise` (with an unused `reject`) is unnecessary.
  context.model = async (x: number[][]): Promise<number[][]> =>
    model.predictProba(x);

  context.data = [
    [],
    [],
    [],
    [],
    []
  ];
});
test<LocalTestContext>('constructor()', async ({ model, data }) => {
  // Expected model predictions after initializeModel(); presumably one row
  // per background entry in `data` (both have 5 rows here).
  const yPredProbaExp = [
    [],
    [],
    [],
    [],
    []
  ];
  const explainer = new KernelSHAP(model, data, SEED);
  await explainer.initializeModel();

  // Pin the shape first, then drive the comparison from the expected matrix.
  // Looping over the actual predictions would let a short or empty
  // `predictions` array pass without checking anything.
  const predictions = explainer.predictions;
  expect(predictions).toHaveLength(yPredProbaExp.length);
  for (let i = 0; i < yPredProbaExp.length; i++) {
    const expectedRow = yPredProbaExp[i];
    const actualRow = predictions[i];
    expect(actualRow).toHaveLength(expectedRow.length);
    for (let j = 0; j < expectedRow.length; j++) {
      expect(actualRow[j]).toBeCloseTo(expectedRow[j], 6);
    }
  }
});
test<LocalTestContext>('inferenceFeatureCoalitions()', async ({
model,
data
}) => {
const explainer = new KernelSHAP(model, data, SEED);
await explainer.initializeModel();
const nSamples = 32;
const x1 = [4.8, 3.8, 2.1, 5.4];
// Inference on the sampled feature coalitions
explainer.sampleFeatureCoalitions(x1, nSamples);
await explainer.inferenceFeatureCoalitions();
// The first 8 masks (40 samples) are deterministic, so we compare the
// results of them with SHAP
const yMat = explainer.yMat!;
const expectedYMat8 = [
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[]
];
for (let i = 0; i < expectedYMat8.length; i++) {
for (let j = 0; j < expectedYMat8[i].length; j++) {
expect(yMat.get([i, j]) as number).toBeCloseTo(expectedYMat8[i][j], 8);
}
}
// Compare the expected y value for all rows (we get the ground truth by
// forcing Python SHAP to use the same kernel weight and mask as TS)
const expectedEyMat = [
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[]
];
for (let i = 0; i < expectedEyMat.length; i++) {
for (let j = 0; j < expectedEyMat[i].length; j++) {
expect(explainer.yExpMat!.get([i, j]) as number).toBeCloseTo(
expectedEyMat[i][j],
8
);
}
}
});
test<LocalTestContext>('explainOneInstance()', async ({ model, data }) => {
  const instance = [4.8, 3.8, 2.1, 5.4];
  const sampleCount = 32;

  const explainer = new KernelSHAP(model, data, SEED);
  const shapValues = await explainer.explainOneInstance(instance, sampleCount);

  // Reference values: 3 rows (matching the model's 3 intercepts) by 4
  // columns (matching the 4 features of `instance`).
  const expected = [
    [],
    [-0.06244084, -0.01309725, 0.05957105, -0.12841866],
    [-0.01206957, -0.00816908, -0.32038983, 0.53181117]
  ];
  for (const [row, expectedRow] of expected.entries()) {
    for (const [col, expectedValue] of expectedRow.entries()) {
      expect(shapValues[row][col]).toBeCloseTo(expectedValue, 6);
    }
  }
});
test<LocalTestContext>('explainOneInstance() with only one background data', async ({
  model,
  data
}) => {
  // Keep just the first fixture row as the whole background dataset
  const [firstRow] = data;
  const background = [firstRow];

  const instance = [4.8, 3.8, 2.1, 5.4];
  const sampleCount = 32;

  const explainer = new KernelSHAP(model, background, SEED);
  const shapValues = await explainer.explainOneInstance(instance, sampleCount);

  const expected = [
    [],
    [-0.03435395, -0.02590338, 0.10440986, -0.08952941],
    [-0.03754184, 0.00898926, -0.51820532, 0.44622462]
  ];
  for (const [row, expectedRow] of expected.entries()) {
    for (const [col, expectedValue] of expectedRow.entries()) {
      expect(shapValues[row][col]).toBeCloseTo(expectedValue, 6);
    }
  }
});
test<LocalTestContext>('explainOneInstance() with same columns', async ({
  model
}) => {
  // The single background row shares columns 0 (4.8) and 2 (2.1) with
  // `instance`, exercising features whose value does not change.
  const background = [[4.8, 2.8, 2.1, 3.3]];
  const instance = [4.8, 3.8, 2.1, 5.4];
  const sampleCount = 32;

  const explainer = new KernelSHAP(model, background, SEED);
  const shapValues = await explainer.explainOneInstance(instance, sampleCount);

  const expected = [
    [],
    [],
    []
  ];
  for (const [row, expectedRow] of expected.entries()) {
    for (const [col, expectedValue] of expectedRow.entries()) {
      expect(shapValues[row][col]).toBeCloseTo(expectedValue, 6);
    }
  }
});