@ai-on-browser/data-analysis-models
Version:
Data analysis model package without any dependencies
270 lines (242 loc) • 7.03 kB
JavaScript
import { RLEnvironmentBase, RLRealRange } from '../rl/base.js'
import NeuralNetwork from './neuralnetwork.js'
/**
* @ignore
* @typedef {import("./nns/graph").LayerObject} LayerObject
*/
/**
 * Returns the index of the first largest element of a numeric array.
 *
 * Implemented as a single pass instead of `Math.max(...arr)` so that very
 * large arrays do not exceed the engine's argument-count limit.
 * @param {number[]} arr Array of numbers
 * @returns {number} Index of the first maximum, or -1 if the array is empty or contains NaN
 */
const argmax = function (arr) {
	if (arr.length === 0) {
		return -1
	}
	let best = 0
	for (let i = 0; i < arr.length; i++) {
		// A NaN makes the maximum undefined; report "not found" as the
		// Math.max/indexOf combination did.
		if (Number.isNaN(arr[i])) {
			return -1
		}
		// Strict comparison keeps the first of several equal maxima.
		if (arr[i] > arr[best]) {
			best = i
		}
	}
	return best
}
/**
 * Deep Q-Network learner with experience replay.
 *
 * Supports the plain 'DQN' target and the Double DQN ('DDQN') target.
 */
class DQN {
	// https://qiita.com/sugulu/items/bc7c70e6658f204f85f9
	/**
	 * @param {RLEnvironmentBase} env Environment; its `states` and `actions` definitions are read
	 * @param {number} [resolution] Size used for every action definition that is not an array
	 * @param {LayerObject[]} [layers] Layers inserted between the input layer and the Q-value layer
	 * @param {string} [optimizer] Optimizer name handed to the network
	 */
	constructor(env, resolution = 20, layers = [], optimizer = 'sgd') {
		this._resolution = resolution
		this._states = env.states
		this._actions = env.actions
		// Number of choices per action dimension: array length for listed
		// actions, `resolution` for everything else.
		this._action_sizes = env.actions.map(a => {
			if (Array.isArray(a)) {
				return a.length
			} else {
				return resolution
			}
		})
		// Discount factor applied to the next state's value.
		this._gamma = 0.99
		// Counts calls to `update` made once the memory holds at least one batch.
		this._epoch = 0
		this._method = 'DQN'
		// Replay memory of [action, state input, next state input, reward, done].
		this._memory = []
		this._max_memory_size = 100000
		this._batch_size = 1000
		// The network is trained once every `_do_update_step` epochs.
		this._do_update_step = 10
		// The DDQN target network is refreshed once every `_fix_param_update_step` epochs.
		this._fix_param_update_step = 1000
		this._layers = [{ type: 'input' }]
		this._layers.push(...layers)
		// One output per combination of action choices.
		this._layers.push(
			{ type: 'full', out_size: this._action_sizes.reduce((s, v) => s * v, 1) },
			{ type: 'output', name: 'output' },
			{ type: 'huber' }
		)
		this._target = null
		this._net = NeuralNetwork.fromObject(this._layers, null, optimizer)
	}

	/**
	 * Learning method. Switching to 'DQN' drops the target network.
	 * @param {'DQN' | 'DDQN'} value New method name
	 */
	set method(value) {
		this._method = value
		if (value === 'DQN' && this._target) {
			this._target = null
		}
	}

	/**
	 * Returns the action whose network output is the largest for the state.
	 * @param {*[]} state Current states
	 * @returns {*[]} Action
	 */
	get_best_action(state) {
		state = this._state_to_input(state)
		const data = this._net.calc([state])
		return this._pos_action(argmax(data.toArray()[0]))
	}

	/**
	 * Converts an environment state to a network input vector.
	 *
	 * Components whose state definition is an array are one-hot encoded over
	 * that array; other components are passed through unchanged.
	 * @param {*[]} s State
	 * @returns {number[]} Network input
	 */
	_state_to_input(s) {
		const state = []
		for (let i = 0; i < s.length; i++) {
			if (Array.isArray(this._states[i])) {
				for (let k = 0; k < this._states[i].length; k++) {
					state.push(this._states[i][k] === s[i] ? 1 : 0)
				}
			} else {
				state.push(s[i])
			}
		}
		return state
	}

	/**
	 * Evaluates the network over a grid of state indices.
	 *
	 * The grid is built on the first call and cached. Each axis has as many
	 * points as `toArray(resolution)` of the matching state definition returns.
	 * NOTE(review): the index vectors themselves are fed to the network, not
	 * the values produced by `toArray` - confirm this is intended. This also
	 * assumes every state definition provides `toArray`.
	 * @returns {Array} Nested arrays indexed by state index, holding the network output row at each leaf
	 */
	get_score() {
		if (!this._states_data) {
			const state_sizes = this._states.map(s => s.toArray(this._resolution).length)
			this._states_data = []
			// Odometer-style increment; returns false once every combination was visited.
			const next_idx = n => {
				for (let i = 0; i < n.length; i++) {
					n[i]++
					if (n[i] < state_sizes[i]) return true
					n[i] = 0
				}
				return false
			}
			const state = Array(this._states.length).fill(0)
			do {
				this._states_data.push([].concat(state))
			} while (next_idx(state))
		}
		const a = this._net.calc(this._states_data).toArray()
		const d = []
		const m = this._states.length
		for (let i = 0; i < this._states_data.length; i++) {
			// Walk (and create on demand) the nested arrays down to the last axis.
			let di = d
			for (let k = 0; k < m - 1; k++) {
				if (!di[this._states_data[i][k]]) di[this._states_data[i][k]] = []
				di = di[this._states_data[i][k]]
			}
			di[this._states_data[i][m - 1]] = a[i]
		}
		return d
	}

	/**
	 * Encodes an action as the index of its network output.
	 *
	 * The per-dimension choice indices are combined as a mixed-radix number
	 * whose radices are `_action_sizes`, the first dimension being the most significant.
	 * @param {*[]} action Action
	 * @returns {number} Output index
	 */
	_action_pos(action) {
		let i = 0
		for (let k = 0; k < action.length; k++) {
			i = i * this._action_sizes[k]
			if (Array.isArray(this._actions[k])) {
				i += this._actions[k].indexOf(action[k])
			} else if (this._actions[k] instanceof RLRealRange) {
				i += this._actions[k].indexOf(action[k], this._resolution)
			} else {
				throw 'Not implemented'
			}
		}
		return i
	}

	/**
	 * Decodes a network output index back to an action (inverse of `_action_pos`).
	 * @param {number} i Output index
	 * @returns {*[]} Action
	 */
	_pos_action(i) {
		const action = []
		// Peel off the least significant dimension first and prepend it.
		for (let k = this._action_sizes.length - 1; k >= 0; k--) {
			const p = i % this._action_sizes[k]
			i = Math.floor(i / this._action_sizes[k])
			if (Array.isArray(this._actions[k])) {
				action.unshift(this._actions[k][p])
			} else if (this._actions[k] instanceof RLRealRange) {
				action.unshift(this._actions[k].toArray(this._resolution)[p])
			} else {
				throw 'Not implemented'
			}
		}
		return action
	}

	/**
	 * Stores a transition and, every `_do_update_step` epochs, trains the
	 * network on a random batch of stored transitions.
	 * @param {*[]} action Action
	 * @param {*[]} state Current states
	 * @param {*[]} next_state Next states
	 * @param {number} reward Reward
	 * @param {boolean} done Whether the transition ended the episode
	 * @param {number} learning_rate Learning rate
	 * @param {number} batch Batch size passed to the network's `fit`
	 * @returns {number=} Loss value, or undefined when no training happened
	 */
	update(action, state, next_state, reward, done, learning_rate, batch) {
		const rstate = this._state_to_input(state)
		const rnstate = this._state_to_input(next_state)
		// `done` is kept so that terminal transitions are not bootstrapped.
		this._memory.push([action, rstate, rnstate, reward, !!done])
		if (this._memory.length < this._batch_size) {
			return
		} else if (this._memory.length > this._max_memory_size) {
			this._memory.shift()
		}
		if (++this._epoch % this._do_update_step > 0) {
			return
		}
		// Draw `_batch_size` distinct indices without replacement. `idx` stays
		// sorted: each draw is taken from the remaining count and shifted past
		// every already chosen index that is not greater than it.
		const idx = []
		for (let i = 0; i < this._batch_size; i++) {
			let r = Math.floor(Math.random() * (this._memory.length - i))
			let j = 0
			for (; j < idx.length && idx[j] <= r; j++, r++);
			idx.splice(j, 0, r)
		}
		const select_data = idx.map(i => this._memory[i])
		if (this._method === 'DDQN') {
			return this._update_ddqn(select_data, learning_rate, batch)
		} else {
			return this._update_dqn(select_data, learning_rate, batch)
		}
	}

	/**
	 * Trains on a batch with the DQN target `reward + gamma * max(Q(next))`.
	 * For terminal transitions the target is the reward alone.
	 * @param {Array[]} data Sampled transitions
	 * @param {number} learning_rate Learning rate
	 * @param {number} batch Batch size passed to the network's `fit`
	 * @returns {number} Loss value
	 */
	_update_dqn(data, learning_rate, batch) {
		const x = data.map(d => d[1])
		const next_x = data.map(d => d[2])
		// Evaluate current and next states in a single forward pass.
		const xx = [].concat(x, next_x)
		const est_q = this._net.calc(xx).toArray()
		const q = est_q.slice(0, x.length)
		const next_q = est_q.slice(x.length)
		for (let i = 0; i < q.length; i++) {
			const a_idx = this._action_pos(data[i][0])
			// Only the taken action's value is replaced; the other outputs keep
			// their current estimate and so contribute no error.
			const future = data[i][4] ? 0 : this._gamma * Math.max(...next_q[i])
			q[i][a_idx] = data[i][3] + future
		}
		const loss = this._net.fit(x, q, 1, learning_rate, batch)
		return loss[0]
	}

	/**
	 * Trains on a batch with the Double DQN target: the online network picks
	 * the next action and the target network (the online network until the
	 * first copy exists) supplies its value. For terminal transitions the
	 * target is the reward alone.
	 * @param {Array[]} data Sampled transitions
	 * @param {number} learning_rate Learning rate
	 * @param {number} batch Batch size passed to the network's `fit`
	 * @returns {number} Loss value
	 */
	_update_ddqn(data, learning_rate, batch) {
		const x = data.map(d => d[1])
		const next_x = data.map(d => d[2])
		const xx = [].concat(x, next_x)
		const est_q = this._net.calc(xx).toArray()
		const q = est_q.slice(0, x.length)
		const next_q = est_q.slice(x.length)
		const next_t_q = (this._target || this._net).calc(next_x).toArray()
		for (let i = 0; i < q.length; i++) {
			const a_idx = this._action_pos(data[i][0])
			const future = data[i][4] ? 0 : this._gamma * next_t_q[i][argmax(next_q[i])]
			q[i][a_idx] = data[i][3] + future
		}
		const loss = this._net.fit(x, q, 1, learning_rate, batch)
		// Refresh the target network only when the epoch reaches a multiple of
		// the interval, so that it stays fixed in between.
		if (this._epoch % this._fix_param_update_step === 0) {
			this._target = this._net.copy()
		}
		return loss[0]
	}
}
/**
 * Agent that learns with a Deep Q-Network
 */
export default class DQNAgent {
	/**
	 * @param {RLEnvironmentBase} env Environment the agent acts in
	 * @param {number} resolution Resolution of actions
	 * @param {LayerObject[]} layers Layers of the network
	 * @param {string} optimizer Optimizer used by the network
	 */
	constructor(env, resolution, layers, optimizer) {
		this._env = env
		this._net = new DQN(env, resolution, layers, optimizer)
	}

	/**
	 * Learning method of the underlying network
	 * @param {'DQN' | 'DDQN'} value New method name
	 */
	set method(value) {
		this._net.method = value
	}

	/**
	 * Picks an action: the network's best action, or a sampled one.
	 * @param {*[]} state Current states
	 * @param {number} greedy_rate Threshold compared with a random number; the best action is used when the random number is greater
	 * @returns {*[]} Action
	 */
	get_action(state, greedy_rate = 0.002) {
		const use_best = Math.random() > greedy_rate
		return use_best ? this._net.get_best_action(state) : this._env.sample_action(this)
	}

	/**
	 * Returns the score values calculated by the network.
	 * @returns {Array<Array<Array<number>>>} Score values
	 */
	get_score() {
		const score = this._net.get_score()
		return score
	}

	/**
	 * Passes one transition to the network for learning.
	 * @param {*[]} action Action
	 * @param {*[]} state Current states
	 * @param {*[]} next_state Next states
	 * @param {number} reward Reward
	 * @param {boolean} done Done epoch or not
	 * @param {number} learning_rate Learning rate
	 * @param {number} batch Batch size
	 * @returns {number=} Loss value
	 */
	update(action, state, next_state, reward, done, learning_rate, batch) {
		const loss = this._net.update(action, state, next_state, reward, done, learning_rate, batch)
		return loss
	}

	/**
	 * Does nothing.
	 */
	terminate() {}
}