UNPKG

aiom_pack

Version:

Framework for interdependent (MCMC-like) behavioral experiments

380 lines (362 loc) 17.1 kB
// controllers/dataController.js
const { pool } = require('../../core/database');
// const { ExperimentConfig } = require('../../core/config');
const sampling = require('../../models/sampling');
const transformer = require('../../models/transformer');
const gk = require('../../models/gatekeeper');
// const converge = require('../../models/convergence');
const fs = require('fs');
const path = require('path');

/**
 * Per-participant chain tables are named from request data and spliced into
 * SQL text, because identifiers cannot be bound as `$n` parameters. Only
 * names shaped like a valid unquoted PostgreSQL identifier are allowed, so
 * request data cannot inject SQL. (Names outside this shape would already
 * fail as unquoted identifiers.)
 */
const SAFE_IDENTIFIER = /^[A-Za-z_][A-Za-z0-9_$]*$/;

/**
 * Build an Error carrying `status = 400`, which Express's default error
 * handler uses as the HTTP status code.
 * @param {string} message - Error description.
 * @returns {Error}
 */
function badRequest(message) {
    const error = new Error(message);
    error.status = 400;
    return error;
}

/**
 * Ensure `value` is safe to interpolate into SQL as an unquoted identifier.
 * @param {*} value - Candidate identifier, usually taken from the request.
 * @param {string} what - Label used in the error message.
 * @returns {string} `value`, unchanged.
 * @throws {Error} with `status = 400` when `value` is not a safe identifier.
 */
function assertSafeIdentifier(value, what) {
    if (typeof value !== 'string' || !SAFE_IDENTIFIER.test(value)) {
        throw badRequest(`Invalid ${what}: ${JSON.stringify(value)}`);
    }
    return value;
}

/**
 * Pick one element of a non-empty array uniformly at random.
 * @param {Array} items
 * @returns {*}
 */
function pickRandom(items) {
    return items[Math.floor(Math.random() * items.length)];
}

/**
 * Pick a class different from `excluded` uniformly at random.
 * This has the same distribution as re-drawing from `classes` until the draw
 * differs. Unlike that loop, it terminates when no alternative exists.
 * @param {string[]} classes - All configured classes.
 * @param {string} excluded - Class to avoid.
 * @returns {string|undefined} Another class, or undefined if there is none.
 */
function pickOtherClass(classes, excluded) {
    const candidates = classes.filter((c) => c !== excluded);
    return candidates.length > 0 ? pickRandom(candidates) : undefined;
}

/**
 * Read the most recent row (highest id) of a chain table.
 * @param {string} table_name - Chain table; must already be a safe identifier.
 * @returns {Promise<{stimulus: *, category: string, for_prior: boolean}>}
 * @throws {Error} when the table has no rows.
 */
async function fetchLatestRow(table_name) {
    const result = await pool.query(`
        SELECT stimulus, category, for_prior
        FROM ${table_name}
        ORDER BY id DESC
        LIMIT 1
    `);
    if (result.rows.length === 0) {
        throw new Error(`No chain state found in table ${table_name}`);
    }
    return result.rows[0];
}

/**
 * Response body for `register_choices`. Reports progress while `n_trial` is
 * below `max_trial`, and the finish signal otherwise (including when
 * `n_trial` is NaN).
 * @param {number} n_trial
 * @param {number} max_trial
 * @returns {{finish: number, progress: number}}
 */
function progressPayload(n_trial, max_trial) {
    if (n_trial < max_trial) {
        return { finish: 0, progress: n_trial / max_trial };
    }
    return { finish: 1, progress: 1 };
}

/**
 * Express controller that drives per-participant sampling chains.
 * Each chain alternates between two trial types:
 * - "likelihood": compare the current state against a proposed state.
 * - "prior": choose a class for the current state.
 * The current position of a chain is the last row of its table; its
 * `for_prior` flag decides which trial type comes next.
 */
class Controller {
    /**
     * @param {string} experimentPath - Base directory; gatekeeper models and
     *   attention-check stimuli are resolved relative to it.
     * @param {object} config - Provides get/getNumber/getArray/getBoolean.
     */
    constructor(experimentPath, config) {
        this.config = config;
        this.experimentPath = experimentPath;
        // Initialize experiment settings from config
        this.n_chain = this.config.getNumber('n_chain');
        this.lower_bound = this.config.getNumber('lower_bound');
        this.upper_bound = this.config.getNumber('upper_bound');
        this.classes = this.config.getArray('classes');
        this.class_questions = this.config.getArray('class_questions');
        this.n_class = this.classes.length;
        this.max_trial = this.config.getNumber('trial_per_participant_per_class');
        this.dim = this.config.getNumber('dim');
        this.mode = this.config.get('mode');
        this.n_rest = this.config.getNumber('n_rest');
        this.attention_check = this.config.getBoolean('attention_check');
        this.attention_check_dir = this.config.get('attention_check_dir');
        this.attention_check_rate = Number(this.config.get('attention_check_rate'));
        this.temperature = 2.0;
        // In-memory only: counts consecutive prior trials auto-resolved to the
        // same class, per participant.
        this.stuck_count = {};
        this.stuck_patience = this.config.getNumber('stuck_patience');
        this.proposal_cov_v = this.config.getNumber('proposal_cov');
        // dim x dim diagonal covariance with proposal_cov on the diagonal.
        this.proposal_cov = Array(this.dim).fill().map((_, i) =>
            Array(this.dim).fill().map((_, j) => (i === j ? this.proposal_cov_v : 0))
        ); // align with process.env.dim
        if (this.mode === 'test') {
            this.stimuli_processing = transformer.raw;
            this.stimuli_processing_batch = transformer.raw;
        } else if (this.mode === 'image') {
            this.stimuli_processing = transformer.to_image;
            this.stimuli_processing_batch = transformer.to_image_gsp;
        }
        // NOTE(review): any other mode leaves stimuli_processing undefined, so
        // stimulus generation will throw at request time.
        // Setup other config-dependent properties
        this.setupGatekeeper();
    }

    /**
     * Load one GaussianKDE gatekeeper per class from
     * `<experimentPath>/<gatekeeper_dir>/<class>.json` when config
     * `gatekeeper` is the string 'true'. Any other value disables it.
     */
    setupGatekeeper() {
        if (this.config.get('gatekeeper') !== 'true') {
            this.gatekeeper = false;
            return;
        }
        const modelsDir = this.config.get('gatekeeper_dir');
        this.gatekeeper = {};
        for (const cate of this.classes) {
            const modelFilePath = path.join(this.experimentPath, modelsDir, `${cate}.json`);
            const gatekeeper_parameters = JSON.parse(fs.readFileSync(modelFilePath, 'utf8'));
            this.gatekeeper[cate] = new gk.GaussianKDE(gatekeeper_parameters);
            console.log(`Gatekeeper ${cate} initialized successfully with custom models in ${modelsDir}`);
        }
    }

    /**
     * Create the `participants` table and add one `<class>_ss` counter column
     * per configured class. Errors are logged, not rethrown.
     */
    async initialize() {
        try {
            // Create table if it doesn't exist
            await pool.query(`CREATE TABLE IF NOT EXISTS participants (
                id SERIAL PRIMARY KEY,
                participant TEXT NOT NULL,
                attention_check_fail INTEGER DEFAULT 0,
                face_authorization BOOLEAN DEFAULT false,
                completion TEXT,
                bonus_issued BOOLEAN DEFAULT false
            );`);
            // Class names come from trusted config, not from requests.
            for (const colname of this.classes) {
                await pool.query(`ALTER TABLE participants ADD COLUMN IF NOT EXISTS "${colname}_ss" INTEGER DEFAULT 0;`);
            }
        } catch (error) {
            console.error('Error setting up initialization database:', error);
        }
    }

    /**
     * POST handler that registers participant `req.body.names`. It creates
     * `n_chain` tables named `<name>_blockwise_no<i>`, each seeded with a
     * random class and a starting state flagged `for_prior = true`. The
     * state is drawn from that class's gatekeeper when enabled, and
     * uniformly otherwise.
     * Responds with the experiment settings the client needs.
     */
    async set_table(req, res, next) {
        try {
            const name = assertSafeIdentifier(req.body.names, 'participant name');
            this.stuck_count[name] = 0;
            await pool.query(
                `INSERT INTO participants (participant) VALUES ($1)`,
                [name]
            );
            // create tables and insert the starting point for each chain
            for (let i = 1; i <= this.n_chain; i++) {
                const table_name = `${name}_blockwise_no${i}`;
                await pool.query(`CREATE TABLE IF NOT EXISTS ${table_name} (
                    id SERIAL PRIMARY KEY,
                    stimulus JSON NOT NULL,
                    category TEXT NOT NULL,
                    for_prior BOOLEAN,
                    gatekeeper BOOLEAN
                );`);
                const current_class = pickRandom(this.classes);
                const current_state = this.gatekeeper
                    ? transformer.limit_array_in_range(this.gatekeeper[current_class].sampling(), this.lower_bound, this.upper_bound)
                    : sampling.uniform_array(this.dim, this.lower_bound, this.upper_bound);
                await pool.query(
                    `INSERT INTO ${table_name} (stimulus, category, for_prior) VALUES ($1, $2, $3)`,
                    [JSON.stringify(current_state), current_class, true]
                );
            }
            res.status(200).json({
                classes: this.classes,
                class_questions: this.class_questions,
                n_rest: this.n_rest,
                mode: this.mode,
            });
        } catch (error) {
            next(error);
        }
    }

    /**
     * Build the next trial for a chain without gatekeeper-driven proposals.
     * - If the latest row is flagged `for_prior`: a likelihood trial. A state
     *   is proposed via `sampling.gaussian_array` with `this.proposal_cov`
     *   and clipped to [lower_bound, upper_bound]. With a gatekeeper, the
     *   centre is the current state reflected through the class mean;
     *   otherwise it is the current state.
     * - Otherwise: a prior trial proposing a class other than the current one.
     * @param {string} table_name - Chain table (safe identifier).
     * @throws {Error} for a prior trial when only one class is configured.
     */
    async generate_stimulus(table_name) {
        const row = await fetchLatestRow(table_name);
        const current_state = row.stimulus;
        const current_class = row.category;

        if (row.for_prior) {
            let proposal_center;
            if (this.gatekeeper) {
                const mean = this.gatekeeper[current_class].mean;
                proposal_center = new Array(this.dim).fill(0);
                for (let i = 0; i < current_state.length; i++) {
                    proposal_center[i] = mean[i] * 2 - current_state[i];
                }
            } else {
                proposal_center = current_state;
            }
            const proposal = transformer.limit_array_in_range(
                sampling.gaussian_array(proposal_center, this.proposal_cov),
                this.lower_bound,
                this.upper_bound
            );
            return { current_state, current_class, proposal, trial_type: 'likelihood' };
        }

        const proposal = pickOtherClass(this.classes, current_class);
        if (proposal === undefined) {
            // The original re-draw loop would spin forever here and block the server.
            throw new Error('Prior trials require at least two configured classes');
        }
        const pcx = await this.stimuli_processing(current_state);
        return { current_state, current_class, stimulus: pcx.image, proposal, trial_type: 'prior' };
    }

    /**
     * Build the next trial for a chain when gatekeepers are enabled.
     *
     * Likelihood step (latest row `for_prior`):
     * - The proposal is drawn from the class gatekeeper, independent of the
     *   current state.
     * - If the participant's stuck counter exceeds `stuck_patience`, the
     *   class is first switched to a random other class, when one exists.
     * - A proposal within `upper_bound * 0.1` (Euclidean distance) of the
     *   current state is recorded directly as the new state, without asking
     *   the participant, and the next trial is generated.
     *
     * Prior step (otherwise):
     * - `pcx.posterior` from stimuli_processing is the proposed class.
     * - If it equals the current class, the state is re-recorded with that
     *   class, the stuck counter is incremented, and the next trial is
     *   generated.
     *
     * @param {string} table_name - `<name>_blockwise_no<i>` (safe identifier).
     */
    async generate_stimulus_independence_gatekeeper(table_name) {
        const name = table_name.split('_blockwise_')[0];
        const row = await fetchLatestRow(table_name);
        const current_state = row.stimulus;
        let current_class = row.category;

        if (row.for_prior) {
            if (this.stuck_count[name] > this.stuck_patience) {
                // forced switch to another class
                this.stuck_count[name] = 0;
                const alternative = pickOtherClass(this.classes, row.category);
                if (alternative !== undefined) {
                    current_class = alternative;
                    console.log(`Participant ${name} is stuck in ${row.category}, switching to another class: ${current_class}`);
                }
            }
            const proposal = transformer.limit_array_in_range(
                this.gatekeeper[current_class].sampling(),
                this.lower_bound,
                this.upper_bound
            );
            const squared_distance = current_state.reduce((acc, value, i) => {
                const difference = value - proposal[i];
                return acc + difference * difference;
            }, 0);
            if (Math.sqrt(squared_distance) <= this.upper_bound * 0.1) {
                // Too close to tell apart: record the proposal as the new
                // state (flagged gatekeeper=true) and move on to the next trial.
                await pool.query(
                    `INSERT INTO ${table_name} (stimulus, category, for_prior, gatekeeper) VALUES ($1, $2, $3, $4)`,
                    [JSON.stringify(proposal), current_class, false, true]
                );
                return this.generate_stimulus_independence_gatekeeper(table_name);
            }
            return { current_state, current_class, proposal, trial_type: 'likelihood' };
        }

        const pcx = await this.stimuli_processing(current_state);
        const proposal = pcx.posterior;
        if (proposal === current_class) {
            // `|| 0` guards against a missing counter (e.g. after a restart),
            // which would otherwise become NaN and disable the stuck switch.
            this.stuck_count[name] = (this.stuck_count[name] || 0) + 1;
            // The predicted class agrees with the current one: record it
            // without asking and continue with a likelihood step.
            await pool.query(
                `INSERT INTO ${table_name} (stimulus, category, for_prior, gatekeeper) VALUES ($1, $2, $3, $4)`,
                [JSON.stringify(current_state), current_class, true, true]
            );
            return this.generate_stimulus_independence_gatekeeper(table_name);
        }
        this.stuck_count[name] = 0;
        return { current_state, current_class, stimulus: pcx.image, proposal, trial_type: 'prior' };
    }

    /**
     * Automatically accept or reject a generated trial using the gatekeepers.
     * On rejection, the current state is re-recorded (gatekeeper=true) and 0
     * is returned. On acceptance, `new_stimuli` is returned unchanged.
     * Not called from within this class.
     * @param {object} new_stimuli - Result of a generate_stimulus* method.
     * @param {string} table_name - Chain table (safe identifier).
     * @returns {Promise<object|0>}
     */
    async bw_gatekeeper(new_stimuli, table_name) {
        const { trial_type, current_state, current_class, proposal } = new_stimuli;
        if (trial_type === 'likelihood') {
            const acceptance_prob = this.gatekeeper[current_class].acceptance(current_state, proposal, this.temperature);
            if (Math.random() > acceptance_prob) {
                await pool.query(
                    `INSERT INTO ${table_name} (stimulus, category, for_prior, gatekeeper) VALUES ($1, $2, $3, $4)`,
                    [JSON.stringify(current_state), current_class, false, true]
                );
            } else {
                return new_stimuli;
            }
        } else if (trial_type === 'prior') {
            const density_current = this.gatekeeper[current_class].density(current_state);
            const density_proposal = this.gatekeeper[proposal].density(current_state);
            // Logistic form of exp(p/T) / (exp(c/T) + exp(p/T)). It is
            // algebraically identical but cannot overflow to Infinity/Infinity.
            const acceptance_prob = 1 / (1 + Math.exp((density_current - density_proposal) / this.temperature));
            if (Math.random() > acceptance_prob) {
                // reject the proposal
                await pool.query(
                    `INSERT INTO ${table_name} (stimulus, category, for_prior, gatekeeper) VALUES ($1, $2, $3, $4)`,
                    [JSON.stringify(current_state), current_class, true, true]
                );
            } else {
                return new_stimuli;
            }
        }
        return 0;
    }

    /**
     * GET handler that serves the next trial for the participant in the `ID`
     * header, on a randomly chosen chain.
     * With probability `attention_check_rate` (when attention checks are
     * enabled), an attention-check trial is served instead.
     */
    async get_choices(req, res, next) {
        const name = req.header('ID');
        const current_chain = Math.floor(Math.random() * this.n_chain) + 1;
        const table_name = `${name}_blockwise_no${current_chain}`;
        const attention_check_trial = Math.random() < this.attention_check_rate;
        try {
            assertSafeIdentifier(name, 'participant ID');

            if (attention_check_trial && this.attention_check) {
                const { category: current_class } = await fetchLatestRow(table_name);
                const attentionDir = path.join(this.experimentPath, this.attention_check_dir);
                const attention_stimuli = transformer.get_attention_stimuli_path(attentionDir, current_class);
                res.status(200).json({
                    trial_type: 'attention_check',
                    current_class: current_class,
                    current: transformer.grab_image(attention_stimuli[0]),
                    proposal: transformer.grab_image(attention_stimuli[1]),
                    attention_check: [attention_stimuli[2][0], attention_stimuli[2][1]],
                });
                return;
            }

            // bw_gatekeeper is not wired in here; proposals are served as generated.
            const new_stimuli = this.gatekeeper
                ? await this.generate_stimulus_independence_gatekeeper(table_name)
                : await this.generate_stimulus(table_name);

            if (new_stimuli.trial_type === 'likelihood') {
                const stimuli_list_processed = await this.stimuli_processing_batch([
                    new_stimuli.current_state,
                    new_stimuli.proposal,
                ]);
                res.status(200).json({
                    trial_type: new_stimuli.trial_type,
                    current_class: new_stimuli.current_class,
                    current_chain: current_chain,
                    current_position: new_stimuli.current_state,
                    proposal_position: new_stimuli.proposal,
                    current: stimuli_list_processed[0],
                    proposal: stimuli_list_processed[1],
                });
            } else if (new_stimuli.trial_type === 'prior') {
                res.status(200).json({
                    trial_type: new_stimuli.trial_type,
                    stimulus_position: new_stimuli.current_state,
                    current_stimulus: new_stimuli.stimulus,
                    current_chain: current_chain,
                    current: new_stimuli.current_class,
                    proposal: new_stimuli.proposal,
                });
            }
        } catch (error) {
            next(error);
        }
    }

    /**
     * POST handler that records a participant's choice into a chain table.
     * Here the `ID` header is used directly as the chain table name, and the
     * `name` header identifies the participant row.
     * - likelihood: the chosen state is stored with `for_prior = false`, and
     *   the participant's `<current_class>_ss` counter is incremented.
     * - prior: `req.body.current_position` is stored with the chosen class
     *   and `for_prior = true`.
     * An unknown `trial_type` header is rejected with HTTP 400.
     */
    async register_choices(req, res, next) {
        try {
            const name = req.header('ID');
            const pid = req.header('name');
            const n_trial = Number(req.header('n_trial'));
            const trial_type = req.header('trial_type');
            const selected = req.body.choice;
            console.log(`Participant ${pid} made a choice: ${selected}`);

            assertSafeIdentifier(name, 'chain table');

            if (trial_type === 'likelihood') {
                const current_class = req.header('current_class');
                // current_class is spliced into a quoted column name below.
                if (!this.classes.includes(current_class)) {
                    throw badRequest(`Unknown class: ${JSON.stringify(current_class)}`);
                }
                await pool.query(
                    `INSERT INTO ${name} (stimulus, category, for_prior) VALUES ($1, $2, $3)`,
                    [JSON.stringify(selected), current_class, false]
                );
                await pool.query(
                    `UPDATE participants SET "${current_class}_ss" = "${current_class}_ss" + 1 WHERE participant = $1`,
                    [pid]
                );
            } else if (trial_type === 'prior') {
                const current_stimulus = req.body.current_position;
                await pool.query(
                    `INSERT INTO ${name} (stimulus, category, for_prior) VALUES ($1, $2, $3)`,
                    [JSON.stringify(current_stimulus), selected, true]
                );
            } else {
                // Previously no response was sent, leaving the request hanging.
                throw badRequest(`Unknown trial_type: ${JSON.stringify(trial_type)}`);
            }

            res.status(200).json(progressPayload(n_trial, this.max_trial));
        } catch (error) {
            next(error);
        }
    }
}

module.exports = { Controller };