@bdelab/jscat
Version:
A library to support IRT-based computer adaptive testing in JavaScript
320 lines (279 loc) • 10.5 kB
text/typescript
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { minimize_Powell } from 'optimization-js';
import { Stimulus, Zeta } from './type';
import { itemResponseFunction, fisherInformation, normal, findClosest } from './utils';
import { validateZetaParams, fillZetaDefaults } from './corpus';
import seedrandom from 'seedrandom';
import _clamp from 'lodash/clamp';
import _cloneDeep from 'lodash/cloneDeep';
// Default ability prior, built once at module load by calling normal() with its default arguments.
// Used as the fallback for CatInput.prior; the EAP estimator reads each entry as a [theta, probability] pair.
const abilityPrior = normal();
/**
 * Options accepted by the Cat constructor. Every field is optional; the
 * constructor supplies the default noted on each member.
 */
export interface CatInput {
  // --- ability estimation ---------------------------------------------
  /** Ability estimator, matched case-insensitively: 'MLE' or 'EAP'. Default 'MLE'. */
  method?: string;
  /** Initial ability estimate. Default 0. */
  theta?: number;
  /** Lower bound applied to the MLE estimate. Default -6. */
  minTheta?: number;
  /** Upper bound applied to the MLE estimate. Default 6. */
  maxTheta?: number;
  /** Prior over ability as [theta, probability] pairs, used by the EAP estimator. */
  prior?: number[][];

  // --- item selection -------------------------------------------------
  /** Item selection rule, matched case-insensitively: 'MFI', 'random', 'closest' or 'fixed'. Default 'MFI'. */
  itemSelect?: string;
  /** Number of initial trials selected with startSelect instead of itemSelect. Default 0. */
  nStartItems?: number;
  /** Selection rule for the initial trials: 'random', 'middle' or 'fixed'. Default 'middle'. */
  startSelect?: string;

  // --- reproducibility ------------------------------------------------
  /** Seed for the random number generator; null (the default) leaves it unseeded. */
  randomSeed?: string | null;
}
/**
 * Computer adaptive test session: accumulates item parameters and responses,
 * maintains an ability estimate (theta) with its standard error, and picks
 * the next stimulus from a candidate pool.
 */
export class Cat {
  public method: string;
  public itemSelect: string;
  public minTheta: number;
  public maxTheta: number;
  public prior: number[][];
  private readonly _zetas: Zeta[];
  private readonly _resps: (0 | 1)[];
  private _theta: number;
  private _seMeasurement: number;
  public nStartItems: number;
  public startSelect: string;
  private readonly _rng: ReturnType<seedrandom>;

  /**
   * Create a Cat object. This expects a single object parameter with the following keys
   * @param {{method: string, itemSelect: string, nStartItems: number, startSelect:string, theta: number, minTheta: number, maxTheta: number, prior: number[][], randomSeed: string | null}=} destructuredParam
   * method: ability estimator, e.g. MLE or EAP, default = 'MLE'
   * itemSelect: the method of item selection, e.g. "MFI", "random", "closest", "fixed", default method = 'MFI'
   * nStartItems: first n trials to keep non-adaptive selection
   * startSelect: rule to select first n trials ("random", "middle" or "fixed")
   * theta: initial theta estimate
   * minTheta: lower bound of theta
   * maxTheta: higher bound of theta
   * prior: the prior distribution, as [theta, probability] pairs
   * randomSeed: set a random seed to trace the simulation
   * @throws {Error} if method, itemSelect or startSelect is not a recognised name
   */
  constructor({
    method = 'MLE',
    itemSelect = 'MFI',
    nStartItems = 0,
    startSelect = 'middle',
    theta = 0,
    minTheta = -6,
    maxTheta = 6,
    prior = abilityPrior,
    randomSeed = null,
  }: CatInput = {}) {
    this.method = Cat.validateMethod(method);
    this.itemSelect = Cat.validateItemSelect(itemSelect);
    this.startSelect = Cat.validateStartSelect(startSelect);
    this.minTheta = minTheta;
    this.maxTheta = maxTheta;
    this.prior = prior;
    this._zetas = [];
    this._resps = [];
    this._theta = theta;
    // No items observed yet, so the measurement error starts at the largest finite value.
    this._seMeasurement = Number.MAX_VALUE;
    this.nStartItems = nStartItems;
    this._rng = randomSeed === null ? seedrandom() : seedrandom(randomSeed);
  }

  /** Current ability estimate. */
  public get theta(): number {
    return this._theta;
  }

  /** Standard error of the current ability estimate. */
  public get seMeasurement(): number {
    return this._seMeasurement;
  }

  /**
   * Return the number of items that have been observed so far.
   */
  public get nItems(): number {
    return this._resps.length;
  }

  /** Responses recorded so far (the internal array, not a copy). */
  public get resps() {
    return this._resps;
  }

  /** Item parameters recorded so far (the internal array, not a copy). */
  public get zetas() {
    return this._zetas;
  }

  /**
   * Lower-case `value` and check it against `allowed`.
   * @throws {Error} with `message` when the lower-cased value is not allowed
   */
  private static validateOption(value: string, allowed: readonly string[], message: string) {
    const lowered = value.toLowerCase();
    if (!allowed.includes(lowered)) {
      throw new Error(message);
    }
    return lowered;
  }

  /** Normalise and validate an ability estimator name. */
  private static validateMethod(method: string) {
    // TO DO: add staircase
    return Cat.validateOption(
      method,
      ['mle', 'eap'],
      'The abilityEstimator you provided is not in the list of valid methods',
    );
  }

  /** Normalise and validate an item selection rule. */
  private static validateItemSelect(itemSelect: string) {
    return Cat.validateOption(
      itemSelect,
      ['mfi', 'random', 'closest', 'fixed'],
      'The itemSelector you provided is not in the list of valid methods',
    );
  }

  /** Normalise and validate a start-phase selection rule. */
  private static validateStartSelect(startSelect: string) {
    // TO DO: add staircase
    return Cat.validateOption(
      startSelect,
      ['random', 'middle', 'fixed'],
      'The startSelect you provided is not in the list of valid methods',
    );
  }

  /**
   * use previous response patterns and item params to calculate the estimate ability based on a defined method
   * @param zeta - last item param (or an array of them)
   * @param answer - last response pattern (or an array, one entry per zeta)
   * @param method - estimator to use for this update; defaults to the instance's method
   * @throws {Error} if the method is unknown or zeta and answer differ in length
   */
  public updateAbilityEstimate(
    zeta: Zeta | Zeta[],
    answer: (0 | 1) | (0 | 1)[],
    method: string = this.method,
  ): void {
    const estimator = Cat.validateMethod(method);
    const newZetas = Array.isArray(zeta) ? zeta : [zeta];
    const newAnswers = Array.isArray(answer) ? answer : [answer];

    newZetas.forEach((z) => validateZetaParams(z, true));

    if (newZetas.length !== newAnswers.length) {
      throw new Error('Unmatched length between answers and item params');
    }

    this._zetas.push(...newZetas);
    this._resps.push(...newAnswers);

    if (estimator === 'eap') {
      this._theta = this.estimateAbilityEAP();
    } else if (estimator === 'mle') {
      this._theta = this.estimateAbilityMLE();
    }
    this.calculateSE();
  }

  /**
   * Expected a posteriori estimate: the mean of theta over the prior grid,
   * weighted by likelihood * prior probability.
   *
   * The likelihood is recovered by exponentiating the log-likelihood. The
   * maximum log-likelihood is subtracted first so the exponentials cannot all
   * underflow to zero; the shift cancels in the ratio.
   *
   * If the posterior is degenerate (empty prior, or no grid point with a
   * finite, positive weight) the current estimate is returned unchanged.
   */
  private estimateAbilityEAP() {
    const grid = this.prior.map(([theta, probability]) => ({
      theta,
      probability,
      logLike: this.likelihood(theta),
    }));

    const maxLogLike = grid.reduce((best, point) => (point.logLike > best ? point.logLike : best), -Infinity);
    if (!Number.isFinite(maxLogLike)) {
      return this._theta;
    }

    let num = 0;
    let nf = 0;
    grid.forEach(({ theta, probability, logLike }) => {
      const weight = Math.exp(logLike - maxLogLike) * probability;
      num += theta * weight;
      nf += weight;
    });

    // `nf > 0` is also false for NaN, so a poisoned sum falls back too.
    return nf > 0 ? num / nf : this._theta;
  }

  /**
   * Maximum likelihood estimate, found by minimising the negative
   * log-likelihood from a starting point of 0 and clamping the result to
   * [minTheta, maxTheta].
   */
  private estimateAbilityMLE() {
    const theta0 = [0];
    const solution = minimize_Powell(this.negLikelihood.bind(this), theta0);
    const theta = solution.argument[0];
    return _clamp(theta, this.minTheta, this.maxTheta);
  }

  /** Objective for the optimiser: negative log-likelihood at thetaArray[0]. */
  private negLikelihood(thetaArray: Array<number>) {
    return -this.likelihood(thetaArray[0]);
  }

  /**
   * Log-likelihood of the recorded responses at `theta`: the sum of
   * log(irf) for responses equal to 1 and log(1 - irf) otherwise.
   */
  private likelihood(theta: number) {
    return this._zetas.reduce((acc, zeta, i) => {
      const irf = itemResponseFunction(theta, zeta);
      return this._resps[i] === 1 ? acc + Math.log(irf) : acc + Math.log(1 - irf);
    }, 0);
  }

  /**
   * calculate the standard error of ability estimation
   *
   * Stored as 1 / sqrt(total Fisher information of the observed items at the
   * current theta).
   */
  private calculateSE() {
    const totalInformation = this._zetas.reduce(
      (previousValue, zeta) => previousValue + fisherInformation(this._theta, zeta),
      0,
    );
    this._seMeasurement = 1 / Math.sqrt(totalInformation);
  }

  /**
   * find the next available item from an input array of stimuli based on a selection method
   *
   * While fewer than nStartItems responses have been recorded, startSelect
   * overrides the requested selection method.
   *
   * remainingStimuli is sorted by difficulty for every method except 'fixed', which keeps the input order
   * @param stimuli - an array of stimulus
   * @param itemSelect - the item selection method
   * @param deepCopy - default deepCopy = true
   * @returns {nextStimulus: Stimulus, remainingStimuli: Array<Stimulus>}
   * @throws {Error} if itemSelect is not a recognised name
   */
  public findNextItem(stimuli: Stimulus[], itemSelect: string = this.itemSelect, deepCopy = true) {
    let selector = Cat.validateItemSelect(itemSelect);

    let arr: Array<Stimulus> = deepCopy ? _cloneDeep(stimuli) : stimuli;
    arr = arr.map((stim) => fillZetaDefaults(stim, 'semantic'));

    if (this.nItems < this.nStartItems) {
      selector = this.startSelect;
    }

    if (selector !== 'mfi' && selector !== 'fixed') {
      // for mfi, we sort the arr by fisher information in the private function to select the best item,
      // and then sort by difficulty to return the remainingStimuli
      // for fixed, we want to keep the corpus order as input
      arr.sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!);
    }

    if (selector === 'middle') {
      // middle will only be used in startSelect
      return this.selectorMiddle(arr);
    } else if (selector === 'closest') {
      return this.selectorClosest(arr);
    } else if (selector === 'random') {
      return this.selectorRandom(arr);
    } else if (selector === 'fixed') {
      return this.selectorFixed(arr);
    } else {
      return this.selectorMFI(arr);
    }
  }

  /**
   * Pick the stimulus with the largest Fisher information at the current
   * theta; the rest are returned sorted by difficulty.
   *
   * The information value is carried next to a shallow copy of each stimulus
   * rather than written onto it, so a stimulus's own fields are never
   * overwritten or removed.
   */
  private selectorMFI(inputStimuli: Stimulus[]) {
    const stimuli = inputStimuli.map((stim) => fillZetaDefaults(stim, 'semantic'));
    const ranked = stimuli
      .map((element: Stimulus) => {
        const copy: Stimulus = { ...element };
        return {
          information: fisherInformation(this._theta, fillZetaDefaults(element, 'symbolic')),
          stimulus: copy,
        };
      })
      .sort((a, b) => b.information - a.information)
      .map(({ stimulus }) => stimulus);

    return {
      nextStimulus: ranked[0],
      remainingStimuli: ranked.slice(1).sort((a: Stimulus, b: Stimulus) => a.difficulty! - b.difficulty!),
    };
  }

  /**
   * Pick an item near the middle of `arr`, jittered by a random offset of up
   * to half of nStartItems in either direction when the pool is at least
   * nStartItems long. The chosen item is removed from `arr`.
   */
  private selectorMiddle(arr: Stimulus[]) {
    let index = Math.floor(arr.length / 2);

    if (arr.length >= this.nStartItems) {
      const halfWindow = Math.floor(this.nStartItems / 2);
      index += this.randomInteger(-halfWindow, halfWindow);
    }

    // For an even-length pool the jitter can reach arr.length (e.g. 2 items
    // with nStartItems = 2), which would select nothing; keep it in bounds.
    index = _clamp(index, 0, Math.max(arr.length - 1, 0));

    const nextItem = arr[index];
    arr.splice(index, 1);
    return {
      nextStimulus: nextItem,
      remainingStimuli: arr,
    };
  }

  /**
   * Pick the item whose difficulty findClosest reports as nearest to
   * theta + 0.481, and remove it from `arr`.
   */
  private selectorClosest(arr: Stimulus[]) {
    //findClosest requires arr is sorted by difficulty
    const index = findClosest(arr, this._theta + 0.481);
    const nextItem = arr[index];
    arr.splice(index, 1);
    return {
      nextStimulus: nextItem,
      remainingStimuli: arr,
    };
  }

  /** Pick a uniformly random item and remove it from `arr`. */
  private selectorRandom(arr: Stimulus[]) {
    const index = this.randomInteger(0, arr.length - 1);
    const nextItem = arr.splice(index, 1)[0];
    return {
      nextStimulus: nextItem,
      remainingStimuli: arr,
    };
  }

  /**
   * Picks the next item in line from the given list of stimuli.
   * It grabs the first item from the list, removes it, and then returns it along with the rest of the list.
   *
   * @param arr - The list of stimuli to choose from.
   * @returns {Object} - An object with the next item and the updated list.
   * @returns {Stimulus} return.nextStimulus - The item that was picked from the list (undefined if the list is empty).
   * @returns {Stimulus[]} return.remainingStimuli - The list of what's left after picking the item.
   */
  private selectorFixed(arr: Stimulus[]) {
    const nextItem = arr.shift();
    return {
      nextStimulus: nextItem,
      remainingStimuli: arr,
    };
  }

  /**
   * return a random integer between min and max
   * @param min - The minimum of the random number range (include)
   * @param max - The maximum of the random number range (include)
   * @returns {number} - random integer within the range
   */
  private randomInteger(min: number, max: number) {
    return Math.floor(this._rng() * (max - min + 1)) + min;
  }
}