UNPKG

webshap

Version:

Explain any ML models anywhere

309 lines (255 loc) 9.92 kB
import { describe, test, expect, beforeEach } from 'vitest';
import { KernelSHAP, IrisLinearBinary } from '../../src/index';
import math from '../../src/utils/math-import';

/** Seed passed as the third constructor argument of every KernelSHAP below. */
const SEED = 0.20071022;

/**
 * Per-test fixture injected by `beforeEach`.
 */
interface LocalTestContext {
  /** Asynchronous prediction function handed to KernelSHAP. */
  model: (x: number[][]) => Promise<number[][]>;
  /** Background data handed to KernelSHAP (one inner array per row). */
  data: number[][];
}

/**
 * Initialize the fixture for all tests: an async wrapper around an
 * `IrisLinearBinary` model and a five-row background data set.
 */
beforeEach<LocalTestContext>(context => {
  const coef = [-0.1991, 0.3426, 0.0478, 1.03745];
  const intercept = -1.6689;
  const model = new IrisLinearBinary(coef, intercept);

  // `predictProba` is synchronous; an async arrow function gives the
  // promise-returning signature without an explicit `new Promise`, and a
  // synchronous throw still surfaces as a rejected promise.
  context.model = async (x: number[][]): Promise<number[][]> =>
    model.predictProba(x);

  context.data = [
    [5.8, 2.8, 5.1, 2.4],
    [5.8, 2.7, 5.1, 1.9],
    [7.2, 3.6, 6.1, 2.5],
    [6.2, 2.8, 4.8, 1.8],
    [4.9, 3.1, 1.5, 0.1]
  ];
});

/**
 * After `initializeModel()`, the first entry of every item in
 * `explainer.predictions` should match the reference probabilities.
 */
test<LocalTestContext>('constructor()', async ({ model, data }) => {
  const yPredProbaExp = [
    0.7045917, 0.57841617, 0.73422101, 0.53812833, 0.19671004
  ];

  const explainer = new KernelSHAP(model, data, SEED);
  await explainer.initializeModel();

  for (const [i, pred] of explainer.predictions.entries()) {
    expect(pred[0]).toBeCloseTo(yPredProbaExp[i], 6);
  }
});

/**
 * `prepareSampling(n)` should allocate `n * data.length` rows that are
 * filled with repetitions of the background data.
 */
test<LocalTestContext>('prepareSampling()', async ({ model, data }) => {
  const explainer = new KernelSHAP(model, data, SEED);
  await explainer.initializeModel();

  const nSamples = 14;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  // Initialize the sample data
  explainer.varyingIndexes = explainer.getVaryingIndexes(x1);
  explainer.nVaryFeatures = explainer.varyingIndexes.length;
  explainer.prepareSampling(nSamples);

  // The sample data should be initialized to repeat the background data:
  // the first row of the first two repetitions and the last row overall are
  // spot-checked against `data`.
  const sampledData = explainer.sampledData!;
  const totalRows = sampledData.size()[0];
  expect(totalRows).toBe(nSamples * data.length);
  expect(sampledData.subset(math.index(0, 0))).toBe(data[0][0]);
  expect(sampledData.subset(math.index(data.length, 1))).toBe(data[0][1]);
  expect(sampledData.subset(math.index(sampledData.size()[0] - 1, 2))).toBe(
    data[data.length - 1][2]
  );
});

/**
 * `sampleFeatureCoalitions()` should cap the requested sample count, fully
 * enumerate the size-1 coalitions and their complements, and normalize the
 * kernel weights.
 */
test<LocalTestContext>('sampleFeatureCoalitions()', async ({ model, data }) => {
  const explainer = new KernelSHAP(model, data, SEED);
  await explainer.initializeModel();

  const requestedSamples = 32;
  // NOTE(review): presumably 2^4 - 2 = 14 coalitions for 4 varying features;
  // confirm against the KernelSHAP implementation.
  const expectedSamples = 14;
  const x1 = [4.8, 3.8, 2.1, 5.4];
  explainer.sampleFeatureCoalitions(x1, requestedSamples);

  // The number of samples should be overridden
  const sampledData = explainer.sampledData!;
  expect(sampledData.size()[0]).toBe(expectedSamples * data.length);

  // Size = 1 should be fully sampled
  const maskMat = explainer.maskMat!;
  const maskStrings = new Set<string>();
  for (let i = 0; i < maskMat.size()[0]; i++) {
    const row = math.row(maskMat, i).toArray()[0] as number[];
    maskStrings.add(KernelSHAP.getMaskStr(row));
  }

  const size1Comb = ['1000', '0100', '0010', '0001'];
  const size1CombComp = ['0111', '1011', '1101', '1110'];

  for (const comb of size1Comb) {
    expect(maskStrings.has(comb)).toBeTruthy();
  }

  for (const comb of size1CombComp) {
    expect(maskStrings.has(comb)).toBeTruthy();
  }

  // Kernel weights should sum up to 1
  const weightSum = math.sum(explainer.kernelWeight!) as number;
  expect(weightSum).toBeCloseTo(1, 6);

  // The weights for the enumeration of the size-1 combinations (four masks
  // plus their four complements, i.e. the first 8 entries) should be
  // 0.09090909
  const nEnumerated = size1Comb.length + size1CombComp.length;
  for (let i = 0; i < nEnumerated; i++) {
    expect(explainer.kernelWeight!.get([i, 0])).toBeCloseTo(0.09090909, 8);
  }

  // Verify tracker variables
  expect(explainer.nSamplesAdded).toBe(expectedSamples);
});

/**
 * A single `addSample()` call should overwrite only the masked features in
 * the first repetition of the background data and record its weight.
 */
test<LocalTestContext>('addSample() basic', async ({ model, data }) => {
  const explainer = new KernelSHAP(model, data, SEED);
  await explainer.initializeModel();

  const nSamples = 14;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  // Initialize the sample data
  explainer.varyingIndexes = explainer.getVaryingIndexes(x1);
  explainer.nVaryFeatures = explainer.varyingIndexes.length;
  explainer.prepareSampling(nSamples);

  // Test adding a sample
  const mask1 = [1.0, 0.0, 1.0, 0.0];
  const weight1 = 0.52;
  explainer.addSample(x1, mask1, weight1);

  const sampledData = explainer.sampledData!;

  // Only the first and third elements are changed from the background
  for (let i = 0; i < data.length; i++) {
    const row = math.row(sampledData, i).toArray()[0];
    const rowExp = [x1[0], data[i][1], x1[2], data[i][3]];
    expect(row).toEqual(rowExp);
  }

  // Test if all other repetitions of the background data remain the same
  for (let i = 1; i < nSamples; i++) {
    for (let j = 0; j < data.length; j++) {
      const row = math.row(sampledData, i * data.length + j).toArray()[0];
      expect(row).toEqual(data[j]);
    }
  }

  // Test tracking variables
  expect(explainer.kernelWeight!.get([0, 0])).toBe(weight1);
  expect(explainer.nSamplesAdded).toBe(1);
});

/**
 * Two consecutive `addSample()` calls should fill the first two repetitions
 * independently and leave the remaining repetitions untouched.
 */
test<LocalTestContext>('addSample() more complex', async ({ model, data }) => {
  const explainer = new KernelSHAP(model, data, SEED);
  await explainer.initializeModel();

  const nSamples = 14;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  // Initialize the sample data
  explainer.varyingIndexes = explainer.getVaryingIndexes(x1);
  explainer.nVaryFeatures = explainer.varyingIndexes.length;
  explainer.prepareSampling(nSamples);

  // Test adding two samples
  const mask1 = [1.0, 0.0, 1.0, 0.0];
  const weight1 = 0.52;
  explainer.addSample(x1, mask1, weight1);

  const x2 = [11.2, 11.2, 11.2, 11.2];
  const mask2 = [1.0, 1.0, 0.0, 1.0];
  const weight2 = 0.99;
  explainer.addSample(x2, mask2, weight2);

  const sampledData = explainer.sampledData!;

  // The first repetition should match x_1 and mask_1
  for (let i = 0; i < data.length; i++) {
    const row = math.row(sampledData, i).toArray()[0];
    const rowExp = [x1[0], data[i][1], x1[2], data[i][3]];
    expect(row).toEqual(rowExp);
  }

  // The second repetition should match x_2 and mask_2
  for (let i = 0; i < data.length; i++) {
    const r = data.length + i;
    const row = math.row(sampledData, r).toArray()[0];
    const rowExp = [x2[0], x2[1], data[i][2], x2[3]];
    expect(row).toEqual(rowExp);
  }

  // Test if all other repetitions of the background data remain the same
  for (let i = 2; i < nSamples; i++) {
    for (let j = 0; j < data.length; j++) {
      const row = math.row(sampledData, i * data.length + j).toArray()[0];
      expect(row).toEqual(data[j]);
    }
  }

  // Test tracking variables
  expect(explainer.kernelWeight!.get([1, 0])).toBe(weight2);
  expect(explainer.nSamplesAdded).toBe(2);
});

/**
 * `inferenceFeatureCoalitions()` should populate `yMat` and `yExpMat` with
 * values matching the reference numbers for the deterministic masks.
 */
test<LocalTestContext>('inferenceFeatureCoalitions()', async ({
  model,
  data
}) => {
  const explainer = new KernelSHAP(model, data, SEED);
  await explainer.initializeModel();

  const nSamples = 32;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  // Inference on the sampled feature coalitions
  explainer.sampleFeatureCoalitions(x1, nSamples);
  await explainer.inferenceFeatureCoalitions();

  // The first 8 masks (40 samples) are deterministic, so we compare the
  // results of them with SHAP
  const expectedYMat8 = [
    0.74428491, 0.62606565, 0.81667565, 0.60624373, 0.19987513,
    0.98494403, 0.98494403, 0.98019991, 0.98371625, 0.98738284,
    0.77062789, 0.66666396, 0.74737578, 0.62138006, 0.23736781,
    0.98266107, 0.98206758, 0.98676269, 0.98266107, 0.9843281,
    0.67389612, 0.54311144, 0.69528502, 0.50593722, 0.20128136,
    0.98926348, 0.98926348, 0.98975948, 0.9891101, 0.98727311,
    0.98168608, 0.98105986, 0.98244577, 0.9799177, 0.98356063,
    0.78032477, 0.6789248, 0.79759091, 0.65590315, 0.2462757
  ];

  for (let i = 0; i < expectedYMat8.length; i++) {
    expect(explainer.yMat!.get([i, 0]) as number).toBeCloseTo(
      expectedYMat8[i],
      8
    );
  }

  // Compare the expected y value (one entry per deterministic mask)
  const expectedEyMat8 = [
    0.59862901, 0.98423741, 0.6086831, 0.9836961, 0.52390223, 0.98893393,
    0.98173401, 0.63180387
  ];

  for (let i = 0; i < expectedEyMat8.length; i++) {
    expect(explainer.yExpMat!.get([i, 0]) as number).toBeCloseTo(
      expectedEyMat8[i],
      8
    );
  }
});

/**
 * `explainOneInstance()` with the full background data should return the
 * reference values in `results[0]`.
 */
test<LocalTestContext>('explainOneInstance()', async ({ model, data }) => {
  const explainer = new KernelSHAP(model, data, SEED);

  const nSamples = 32;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  const results = await explainer.explainOneInstance(x1, nSamples);
  const values = results[0];
  const valuesExp = [0.02968265, 0.03134839, -0.0162967, 0.39248069];

  for (const [i, value] of values.entries()) {
    expect(value).toBeCloseTo(valuesExp[i], 6);
  }
});

/**
 * `explainOneInstance()` should also work when the background data holds a
 * single row.
 */
test<LocalTestContext>('explainOneInstance() with only one background data should not fail', async ({
  model,
  data
}) => {
  const singleData = [data[0]];
  const explainer = new KernelSHAP(model, singleData, SEED);

  const nSamples = 32;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  const results = await explainer.explainOneInstance(x1, nSamples);
  const values = results[0];
  const valuesExp = [0.02457882, 0.03561506, -0.01863908, 0.24148199];

  for (const [i, value] of values.entries()) {
    expect(value).toBeCloseTo(valuesExp[i], 6);
  }
});

/**
 * When `x1` differs from the single background row only in columns 1 and 3,
 * those are the indexes reported in `varyingIndexes`.
 */
test<LocalTestContext>('getVaryingIndexes()', async ({ model }) => {
  const data = [[4.8, 2.8, 2.1, 3.3]];
  const explainer = new KernelSHAP(model, data, SEED);

  const nSamples = 32;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  await explainer.explainOneInstance(x1, nSamples);

  expect(explainer.varyingIndexes).toEqual([1, 3]);
  expect(explainer.nVaryFeatures).toBe(2);
});

/**
 * Columns where `x1` equals the single background row (0 and 2) should get
 * a value of 0 in the explanation.
 */
test<LocalTestContext>('explainOneInstance() with same columns', async ({
  model
}) => {
  const data = [[4.8, 2.8, 2.1, 3.3]];
  const explainer = new KernelSHAP(model, data, SEED);

  const nSamples = 32;
  const x1 = [4.8, 3.8, 2.1, 5.4];

  const results = await explainer.explainOneInstance(x1, nSamples);
  const values = results[0];
  const valuesExp = [0, 0.0200946, 0, 0.10239262];

  for (const [i, value] of values.entries()) {
    expect(value).toBeCloseTo(valuesExp[i], 6);
  }
});