aiom_pack
Version:
Framework for interdependent (MCMC-like) behavioral experiments
380 lines (362 loc) • 17.1 kB
JavaScript
// controllers/dataController.js
const { pool } = require('../../core/database');
// const { ExperimentConfig } = require('../../core/config');
const sampling = require('../../models/sampling');
const transformer = require('../../models/transformer');
const gk = require('../../models/gatekeeper');
// const converge = require('../../models/convergence');
const fs = require('fs');
const path = require('path');
// Table names are interpolated into SQL text because Postgres cannot bind
// identifiers as query parameters. They stay unquoted (so Postgres keeps
// folding them to lower case, matching tables created by earlier runs), so
// only plain identifiers may be spliced in. This is slightly stricter than
// Postgres itself (no `$` or non-ASCII letters); names starting with a digit
// or containing spaces/punctuation were already SQL errors before.
const SAFE_SQL_IDENTIFIER = /^[A-Za-z_][A-Za-z0-9_]*$/;

/**
 * Build an Error tagged with HTTP status 400 for malformed client input.
 * @param {string} message - Human-readable description of the problem.
 * @returns {Error} Error with `status = 400`.
 */
function badRequest(message) {
  const error = new Error(message);
  error.status = 400;
  return error;
}

/**
 * Ensure a value can be spliced into SQL as an unquoted table-name fragment.
 * @param {*} identifier - Candidate name, often taken straight from a request.
 * @returns {string} The identifier, unchanged.
 * @throws {Error} `status = 400` when the identifier is missing or unsafe.
 */
function assertSafeIdentifier(identifier) {
  if (typeof identifier !== 'string' || !SAFE_SQL_IDENTIFIER.test(identifier)) {
    throw badRequest(`Invalid table identifier: ${JSON.stringify(identifier)}`);
  }
  return identifier;
}

/**
 * Quote a name as a Postgres identifier, doubling any embedded double quotes.
 * @param {string} name - Raw column/table name.
 * @returns {string} The quoted identifier, e.g. `"cat_ss"`.
 */
function quoteIdentifier(name) {
  return `"${String(name).replace(/"/g, '""')}"`;
}

/**
 * Draw uniformly among the entries of `classes` that differ from `exclude`.
 * Equivalent in distribution to re-drawing until the value differs, but
 * cannot spin forever when no other class exists.
 * @param {string[]} classes - Configured class labels.
 * @param {string} exclude - Label that must not be returned.
 * @returns {string|undefined} A different label, or undefined if none exists.
 */
function sampleOtherClass(classes, exclude) {
  const candidates = classes.filter((label) => label !== exclude);
  if (candidates.length === 0) {
    return undefined;
  }
  return candidates[Math.floor(Math.random() * candidates.length)];
}

/**
 * Append one state row to a chain table.
 * @param {string} tableName - Chain table name (already validated).
 * @param {*} stimulus - Value stored JSON-encoded in the `stimulus` column.
 * @param {string} category - Class label stored in `category`.
 * @param {boolean} forPrior - Value for the `for_prior` column.
 * @param {boolean} [gatekeeper] - When given, also written to the `gatekeeper`
 *   column; when omitted the column is left out of the INSERT entirely.
 */
async function appendState(tableName, stimulus, category, forPrior, gatekeeper) {
  if (gatekeeper === undefined) {
    await pool.query(
      `INSERT INTO ${tableName} (stimulus, category, for_prior)
       VALUES ($1, $2, $3)`,
      [JSON.stringify(stimulus), category, forPrior],
    );
    return;
  }
  await pool.query(
    `INSERT INTO ${tableName} (stimulus, category, for_prior, gatekeeper)
     VALUES ($1, $2, $3, $4)`,
    [JSON.stringify(stimulus), category, forPrior, gatekeeper],
  );
}

/**
 * Read the most recently inserted row (highest id) of a chain table.
 * @param {string} tableName - Chain table name (already validated).
 * @returns {Promise<{stimulus: *, category: string, for_prior: boolean}>}
 * @throws {Error} When the table has no rows.
 */
async function readLatestState(tableName) {
  const { rows } = await pool.query(
    `SELECT stimulus, category, for_prior FROM ${tableName} ORDER BY id DESC LIMIT 1`,
  );
  if (rows.length === 0) {
    throw new Error(`No state recorded yet in ${tableName}`);
  }
  return rows[0];
}

/**
 * Build the progress payload returned after each registered choice.
 * @param {number} nTrial - Trials completed so far (NaN counts as finished).
 * @param {number} maxTrial - Configured trial budget.
 * @returns {{finish: number, progress: number}}
 */
function progressPayload(nTrial, maxTrial) {
  if (nTrial < maxTrial) {
    return { finish: 0, progress: nTrial / maxTrial };
  }
  return { finish: 1, progress: 1 };
}

/**
 * Request handlers and sampling logic for one experiment. Each participant
 * owns `n_chain` tables named `<participant>_blockwise_no<k>`; every trial
 * reads the newest row of one chain and proposes the next stimulus or class.
 */
class Controller {
  /**
   * @param {string} experimentPath - Root directory of the experiment; the
   *   gatekeeper and attention-check directories are resolved against it.
   * @param {object} config - Provides get/getNumber/getArray/getBoolean.
   */
  constructor(experimentPath, config) {
    this.config = config;
    this.experimentPath = experimentPath;
    // Initialize experiment settings from config
    this.n_chain = this.config.getNumber('n_chain');
    this.lower_bound = this.config.getNumber('lower_bound');
    this.upper_bound = this.config.getNumber('upper_bound');
    this.classes = this.config.getArray('classes');
    this.class_questions = this.config.getArray('class_questions');
    this.n_class = this.classes.length;
    this.max_trial = this.config.getNumber('trial_per_participant_per_class');
    this.dim = this.config.getNumber('dim');
    this.mode = this.config.get('mode');
    this.n_rest = this.config.getNumber('n_rest');
    this.attention_check = this.config.getBoolean('attention_check');
    this.attention_check_dir = this.config.get('attention_check_dir');
    this.attention_check_rate = Number(this.config.get('attention_check_rate'));
    this.temperature = 2.0;
    // Per-participant count of consecutive "stuck" prior trials. Kept in
    // memory only, so it is empty after a restart; a prototype-free object
    // keeps names such as "constructor" from resolving to inherited members.
    this.stuck_count = Object.create(null);
    this.stuck_patience = this.config.getNumber('stuck_patience');
    this.proposal_cov_v = this.config.getNumber('proposal_cov');
    // Isotropic (diagonal) proposal covariance of size dim x dim.
    // Array(dim) throws a RangeError early if `dim` is not a valid length.
    this.proposal_cov = Array(this.dim).fill().map((_, i) =>
      Array(this.dim).fill().map((__, j) => (i === j ? this.proposal_cov_v : 0)),
    );
    // Stimulus renderers per mode; any other mode leaves them undefined.
    if (this.mode === 'test') {
      this.stimuli_processing = transformer.raw;
      this.stimuli_processing_batch = transformer.raw;
    } else if (this.mode === 'image') {
      this.stimuli_processing = transformer.to_image;
      this.stimuli_processing_batch = transformer.to_image_gsp;
    }
    this.setupGatekeeper();
  }

  /**
   * Load one GaussianKDE gatekeeper per class from
   * `<experimentPath>/<gatekeeper_dir>/<class>.json` when the `gatekeeper`
   * config value is the string 'true'; otherwise disable it (`false`).
   */
  setupGatekeeper() {
    if (this.config.get('gatekeeper') !== 'true') {
      this.gatekeeper = false;
      return;
    }
    const modelsDir = this.config.get('gatekeeper_dir');
    this.gatekeeper = {};
    for (const cate of this.classes) {
      const modelFilePath = path.join(this.experimentPath, modelsDir, `${cate}.json`);
      const gatekeeper_parameters = JSON.parse(fs.readFileSync(modelFilePath, 'utf8'));
      this.gatekeeper[cate] = new gk.GaussianKDE(gatekeeper_parameters);
      console.log(`Gatekeeper ${cate} initialized successfully with custom models in ${modelsDir}`);
    }
  }

  /**
   * Create the `participants` table and one `<class>_ss` counter column per
   * class. Errors are logged, not rethrown (best-effort setup).
   */
  async initialize() {
    try {
      // Create table if it doesn't exist
      await pool.query(`CREATE TABLE IF NOT EXISTS participants (
        id SERIAL PRIMARY KEY,
        participant TEXT NOT NULL,
        attention_check_fail INTEGER DEFAULT 0,
        face_authorization BOOLEAN DEFAULT false,
        completion TEXT,
        bonus_issued BOOLEAN DEFAULT false
      );`);
      for (const colname of this.classes) {
        const column = quoteIdentifier(`${colname}_ss`);
        await pool.query(`ALTER TABLE participants ADD COLUMN IF NOT EXISTS ${column} INTEGER DEFAULT 0;`);
      }
    } catch (error) {
      console.error('Error setting up initialization database:', error);
    }
  }

  /**
   * Register a participant (`req.body.names`) and seed each of their chains
   * with a random class and a starting state (drawn from that class's
   * gatekeeper when enabled, uniformly otherwise).
   * Responds with the class list, questions, rest interval and mode.
   */
  async set_table(req, res, next) {
    const name = req.body.names;
    try {
      assertSafeIdentifier(name);
      this.stuck_count[name] = 0;
      await pool.query(
        `INSERT INTO participants (participant)
         VALUES ($1)`,
        [name],
      );
      // create tables and insert the starting point for each chain
      for (let chain = 1; chain <= this.n_chain; chain++) {
        const table_name = `${name}_blockwise_no${chain}`;
        await pool.query(`CREATE TABLE IF NOT EXISTS ${table_name} (
          id SERIAL PRIMARY KEY,
          stimulus JSON NOT NULL,
          category TEXT NOT NULL,
          for_prior BOOLEAN,
          gatekeeper BOOLEAN
        );`);
        const current_class = this.classes[Math.floor(Math.random() * this.n_class)];
        const current_state = this.gatekeeper
          ? transformer.limit_array_in_range(this.gatekeeper[current_class].sampling(), this.lower_bound, this.upper_bound)
          : sampling.uniform_array(this.dim, this.lower_bound, this.upper_bound);
        await appendState(table_name, current_state, current_class, true);
      }
      res.status(200).json({
        classes: this.classes,
        class_questions: this.class_questions,
        n_rest: this.n_rest,
        mode: this.mode,
      });
    } catch (error) {
      next(error);
    }
  }

  /**
   * Propose the next trial for a chain (no-gatekeeper path).
   * If the newest row is flagged `for_prior`, a Gaussian proposal is drawn
   * (around the current state, or around its reflection through the
   * gatekeeper mean when gatekeepers are loaded) -> 'likelihood' trial.
   * Otherwise a different class is proposed for the current stimulus -> 'prior'.
   * @param {string} table_name - Chain table to read.
   * @throws {Error} When a 'prior' trial is due but only one class exists.
   */
  async generate_stimulus(table_name) {
    assertSafeIdentifier(table_name);
    const latest = await readLatestState(table_name);
    const current_state = latest.stimulus;
    const current_class = latest.category;

    if (latest.for_prior) {
      let proposal_center = current_state;
      if (this.gatekeeper) {
        // Reflect the current state through the class mean.
        const mean = this.gatekeeper[current_class].mean;
        proposal_center = new Array(this.dim).fill(0);
        for (let i = 0; i < current_state.length; i++) {
          proposal_center[i] = mean[i] * 2 - current_state[i];
        }
      }
      const proposal = transformer.limit_array_in_range(
        sampling.gaussian_array(proposal_center, this.proposal_cov),
        this.lower_bound,
        this.upper_bound,
      );
      return { current_state, current_class, proposal, trial_type: 'likelihood' };
    }

    const proposal = sampleOtherClass(this.classes, current_class);
    if (proposal === undefined) {
      // The previous rejection loop spun forever here and froze the server.
      throw new Error(`Cannot propose a class other than "${current_class}": only one class is configured`);
    }
    const pcx = await this.stimuli_processing(current_state);
    return { current_state, current_class, stimulus: pcx.image, proposal, trial_type: 'prior' };
  }

  /**
   * Propose the next trial for a chain when gatekeepers are enabled.
   * 'likelihood': draw an independent sample from the class gatekeeper. If it
   * lies within 10% of `upper_bound` (Euclidean) of the current state it is
   * recorded directly as an accepted step (gatekeeper=true) and the method
   * recurses. After more than `stuck_patience` consecutive stuck prior trials,
   * a different class is forced first.
   * 'prior': the class proposal is `posterior` from `stimuli_processing`; if it
   * equals the current class, the state is re-recorded as for_prior, the
   * participant's stuck counter is incremented, and the method recurses.
   * @param {string} table_name - `<participant>_blockwise_<k>` chain table.
   */
  async generate_stimulus_independence_gatekeeper(table_name) {
    assertSafeIdentifier(table_name);
    const name = table_name.split('_blockwise_')[0];
    const latest = await readLatestState(table_name);
    const current_state = latest.stimulus;
    let current_class = latest.category;

    if (latest.for_prior) {
      if (this.stuck_count[name] > this.stuck_patience) {
        // forced switch to another class (impossible with a single class)
        this.stuck_count[name] = 0;
        const switched = sampleOtherClass(this.classes, latest.category);
        if (switched !== undefined) {
          current_class = switched;
          console.log(`Participant ${name} is stuck in ${latest.category}, switching to another class: ${current_class}`);
        }
      }
      const proposal = transformer.limit_array_in_range(
        this.gatekeeper[current_class].sampling(),
        this.lower_bound,
        this.upper_bound,
      );
      let squared_distance = 0;
      for (let i = 0; i < current_state.length; i++) {
        const difference = current_state[i] - proposal[i];
        squared_distance += difference * difference;
      }
      if (Math.sqrt(squared_distance) <= this.upper_bound * 0.1) {
        // Too close to be a meaningful comparison: record it and move on.
        await appendState(table_name, proposal, current_class, false, true);
        return this.generate_stimulus_independence_gatekeeper(table_name);
      }
      return { current_state, current_class, proposal, trial_type: 'likelihood' };
    }

    const pcx = await this.stimuli_processing(current_state);
    const proposal = pcx.posterior;
    if (proposal === current_class) {
      // `|| 0` covers a counter missing after a restart (previously NaN,
      // which silently disabled the forced switch).
      this.stuck_count[name] = (this.stuck_count[name] || 0) + 1;
      await appendState(table_name, current_state, current_class, true, true);
      return this.generate_stimulus_independence_gatekeeper(table_name);
    }
    this.stuck_count[name] = 0;
    return { current_state, current_class, stimulus: pcx.image, proposal, trial_type: 'prior' };
  }

  /**
   * Automated accept/reject step using the gatekeeper models (currently not
   * called from get_choices). On rejection the current state is re-recorded
   * with gatekeeper=true and 0 is returned; on acceptance the proposal
   * object is returned unchanged.
   * @param {object} new_stimuli - Output of a generate_* method.
   * @param {string} table_name - Chain table to append to on rejection.
   * @returns {Promise<object|0>}
   */
  async bw_gatekeeper(new_stimuli, table_name) {
    assertSafeIdentifier(table_name);
    const { trial_type, current_state, current_class, proposal } = new_stimuli;
    if (trial_type === 'likelihood') {
      const acceptance = this.gatekeeper[current_class].acceptance(current_state, proposal, this.temperature);
      if (Math.random() > acceptance) {
        await appendState(table_name, current_state, current_class, false, true);
      } else {
        return new_stimuli;
      }
    } else if (trial_type === 'prior') {
      const density_current = this.gatekeeper[current_class].density(current_state);
      const density_proposal = this.gatekeeper[proposal].density(current_state);
      // Two-way softmax written as a logistic: equal to
      // exp(p/T) / (exp(c/T) + exp(p/T)) but without Infinity/Infinity = NaN.
      const acceptance_prob = 1 / (1 + Math.exp((density_current - density_proposal) / this.temperature));
      if (Math.random() > acceptance_prob) {
        // reject the proposal
        await appendState(table_name, current_state, current_class, true, true);
      } else {
        return new_stimuli;
      }
    }
    return 0;
  }

  /**
   * Serve the next trial for participant `ID` (header) on a random chain:
   * an attention check with probability `attention_check_rate` (when
   * enabled), otherwise a 'likelihood' or 'prior' trial.
   */
  async get_choices(req, res, next) {
    const name = req.header('ID');
    const current_chain = Math.floor(Math.random() * this.n_chain) + 1;
    const table_name = `${name}_blockwise_no${current_chain}`;
    const attention_check_trial = Math.random() < this.attention_check_rate;
    try {
      assertSafeIdentifier(name);
      if (attention_check_trial && this.attention_check) {
        const { category: current_class } = await readLatestState(table_name);
        const attentionDir = path.join(this.experimentPath, this.attention_check_dir);
        const attention_stimuli = transformer.get_attention_stimuli_path(attentionDir, current_class);
        res.status(200).json({
          trial_type: 'attention_check',
          current_class,
          current: transformer.grab_image(attention_stimuli[0]),
          proposal: transformer.grab_image(attention_stimuli[1]),
          attention_check: [attention_stimuli[2][0], attention_stimuli[2][1]],
        });
        return;
      }

      const new_stimuli = this.gatekeeper
        ? await this.generate_stimulus_independence_gatekeeper(table_name)
        : await this.generate_stimulus(table_name);

      if (new_stimuli.trial_type === 'likelihood') {
        const stimuli_list_processed = await this.stimuli_processing_batch([new_stimuli.current_state, new_stimuli.proposal]);
        res.status(200).json({
          trial_type: new_stimuli.trial_type,
          current_class: new_stimuli.current_class,
          current_chain,
          current_position: new_stimuli.current_state,
          proposal_position: new_stimuli.proposal,
          current: stimuli_list_processed[0],
          proposal: stimuli_list_processed[1],
        });
      } else if (new_stimuli.trial_type === 'prior') {
        res.status(200).json({
          trial_type: new_stimuli.trial_type,
          stimulus_position: new_stimuli.current_state,
          current_stimulus: new_stimuli.stimulus,
          current_chain,
          current: new_stimuli.current_class,
          proposal: new_stimuli.proposal,
        });
      }
    } catch (error) {
      next(error);
    }
  }

  /**
   * Record a participant's answer. Headers: `ID` is the chain table to
   * append to, `name` the participant id, `n_trial` the trial count,
   * `trial_type` ('likelihood' | 'prior'), and for likelihood trials
   * `current_class`. Body: `choice`, plus `current_position` for prior trials.
   * Unknown trial types and unsafe table/class names are passed to `next`
   * as 400 errors instead of leaving the request unanswered.
   */
  async register_choices(req, res, next) {
    const name = req.header('ID');
    const pid = req.header('name');
    const n_trial = Number(req.header('n_trial'));
    const trial_type = req.header('trial_type');
    const selected = req.body.choice;
    console.log(`Participant ${pid} made a choice: ${selected}`);
    try {
      assertSafeIdentifier(name);
      if (trial_type === 'likelihood') {
        const current_class = req.header('current_class');
        if (!this.classes.includes(current_class)) {
          throw badRequest(`Unknown class: ${JSON.stringify(current_class)}`);
        }
        await appendState(name, selected, current_class, false);
        const column = quoteIdentifier(`${current_class}_ss`);
        await pool.query(
          `UPDATE participants SET ${column} = ${column} + 1 WHERE participant = $1`,
          [pid],
        );
      } else if (trial_type === 'prior') {
        await appendState(name, req.body.current_position, selected, true);
      } else {
        throw badRequest(`Unknown trial_type: ${JSON.stringify(trial_type)}`);
      }
      res.status(200).json(progressPayload(n_trial, this.max_trial));
    } catch (error) {
      next(error);
    }
  }
}
// Public API of this module: the Controller class defined above.
module.exports = { Controller };