UNPKG

@ai-on-browser/data-analysis-models

Version:

Data analysis model package without any dependencies

141 lines (124 loc) 3.08 kB
import { RLEnvironmentBase } from '../rl/base.js'
import { QTableBase } from './q_learning.js'

/**
 * Tabular softmax policy whose action preferences are stored in a QTableBase.
 */
class SoftmaxPolicyGradient {
	// https://book.mynavi.jp/manatee/detail/id=88297
	// https://qiita.com/shionhonda/items/ec05aade07b5bea78081
	/**
	 * @param {RLEnvironmentBase} env Environment
	 * @param {number} [resolution] Resolution passed to the underlying QTableBase
	 */
	constructor(env, resolution = 20) {
		this._params = new QTableBase(env, resolution)
		this._epoch = 0
	}

	get _state_sizes() {
		return this._params._state_sizes
	}

	get _action_sizes() {
		return this._params._action_sizes
	}

	_state_index(state) {
		return this._params._state_index(state)
	}

	_action_index(action) {
		return this._params._action_index(action)
	}

	/**
	 * Softmax over the stored preferences of an already-indexed state.
	 * The row maximum is subtracted before exponentiating so large preferences
	 * do not overflow to Infinity (which would turn every probability into NaN);
	 * the result is mathematically identical to the plain softmax.
	 * @param {*} stateIndex State index as returned by `_state_index`
	 * @returns {number[]} Probability of each flattened action
	 */
	_softmax(stateIndex) {
		const [p] = this._params._q(stateIndex)
		const max = p.reduce((m, v) => Math.max(m, v), -Infinity)
		const expp = p.map(v => Math.exp(v - max))
		const s = expp.reduce((a, v) => a + v, 0)
		return expp.map(v => v / s)
	}

	/**
	 * Returns the action probabilities for a raw state.
	 * @param {*[]} state State value (not yet indexed)
	 * @returns {number[]} Probability of each flattened action
	 */
	probability(state) {
		return this._softmax(this._params._state_index(state))
	}

	toArray() {
		return this._params.toArray()
	}

	/**
	 * Samples an action from the softmax policy.
	 * @param {*[]} state State value
	 * @returns {*[]} Action value
	 */
	get_action(state) {
		const pi = this.probability(state)
		const r = Math.random()
		let cumu = 0
		// Default to the last action: rounding can leave the cumulative sum
		// slightly below 1, so `r < cumu` might never hold for r close to 1.
		let k = pi.length - 1
		for (let i = 0; i < pi.length; i++) {
			cumu += pi[i]
			if (r < cumu) {
				k = i
				break
			}
		}
		return this._params._action_value(this._params._to_index(this._action_sizes, k))
	}

	/**
	 * Applies one policy-gradient step from an episode history:
	 *   θ(s, a) += learning_rate * (N(s, a) - π(a|s) N(s)) / T
	 * where N(s, a) / N(s) are visit counts in the history and T is its length.
	 * Every action of each visited state is updated (untaken actions receive the
	 * negative term only). Rewards stored in the history are not used here.
	 * @param {Array<Array<*>>} actions History entries `[action, state, ...]`
	 * @param {number} learning_rate Learning rate
	 */
	update(actions, learning_rate) {
		const n = actions.length
		// flattened state position -> { state index, visit count, per-action visit counts }
		const visits = new Map()
		for (const [rawAction, rawState] of actions) {
			const state = this._state_index(rawState)
			const act = this._action_index(rawAction)
			const si = this._params._to_position(this._state_sizes, state)[0]
			const ai = this._params._to_position(this._action_sizes, act)[0]
			let entry = visits.get(si)
			if (!entry) {
				entry = { state, count: 0, actionCounts: [] }
				visits.set(si, entry)
			}
			entry.count++
			entry.actionCounts[ai] = (entry.actionCounts[ai] || 0) + 1
		}

		// Collect every delta first so all probabilities come from the pre-update table.
		const deltas = []
		for (const { state, count, actionCounts } of visits.values()) {
			// `state` is already indexed; use the helper to avoid indexing it twice.
			const pi = this._softmax(state)
			for (let k = 0; k < pi.length; k++) {
				const [, i] = this._params._q(state, this._params._to_index(this._action_sizes, k))
				deltas.push([i, ((actionCounts[k] || 0) - pi[k] * count) / n])
			}
		}
		for (const [i, d] of deltas) {
			this._params._table[i] += learning_rate * d
		}
		this._epoch++
	}
}

/**
 * Policy gradient agent
 */
export default class PGAgent {
	/**
	 * @param {RLEnvironmentBase} env Environment
	 * @param {number} [resolution] Resolution
	 */
	constructor(env, resolution = 20) {
		this._table = new SoftmaxPolicyGradient(env, resolution)
		this._history = []
	}

	/**
	 * Reset agent.
	 */
	reset() {
		this._history = []
	}

	/**
	 * Returns a score.
	 * @returns {Array<Array<Array<number>>>} Score values
	 */
	get_score() {
		return this._table.toArray()
	}

	/**
	 * Returns a action.
	 * @param {*[]} state Current states
	 * @returns {*[]} Action
	 */
	get_action(state) {
		return this._table.get_action(state)
	}

	/**
	 * Update model.
	 * Steps are buffered; the policy is updated once per episode when `done` is true.
	 * @param {*[]} action Action
	 * @param {*[]} state Next states
	 * @param {number} reward Reward
	 * @param {boolean} done Done epoch or not
	 * @param {number} learning_rate Learning rate
	 */
	update(action, state, reward, done, learning_rate) {
		// NOTE(review): the gradient treats `state` as the state in which `action`
		// was chosen; confirm callers pass the pre-step state rather than the next one.
		this._history.push([action, state, reward])
		if (done) {
			this._table.update(this._history, learning_rate)
			this._history = []
		}
	}
}