UNPKG

arc-agents

Version:

A library for creating and deploying gaming agents at scale

84 lines (73 loc) 2.38 kB
const {
  FrequencyTable,
  NeuralNetworkMultihead,
  // HierarchicalNeuralNetwork,
  DataCollector,
} = require('arc-ml')
const ProbabilisticAgentWrapper = require('../agent-wrappers/probabilistic-agent-wrapper')

/**
 * Maps the `config.modelType` string accepted by `AgentCore#createAgent`
 * to the arc-ml model class that is instantiated for it.
 */
const modelsMapping = {
  "simple": FrequencyTable,
  "neural-network": NeuralNetworkMultihead,
  // "hierarchical": HierarchicalNeuralNetwork,
}

/**
 * Returns true when `modelType` is registered directly on `modelsMapping`.
 * The own-property check stops inherited keys such as "constructor" or
 * "toString" from being treated as model types.
 *
 * @param {string} modelType
 * @returns {boolean}
 */
function isSupportedModelType(modelType) {
  return Object.prototype.hasOwnProperty.call(modelsMapping, modelType)
    && typeof modelsMapping[modelType] === "function"
}

/**
 * Instantiates the model class registered for `modelType` and wraps it in a
 * ProbabilisticAgentWrapper. This runs synchronously, so any error thrown by
 * either constructor propagates straight to the caller.
 *
 * @param {string} modelType - A key already validated by isSupportedModelType.
 * @param {object} modelData - Passed unchanged to the model constructor.
 * @param {object} agentConfig - Passed unchanged to the wrapper constructor.
 * @returns {{ model: object, agent: ProbabilisticAgentWrapper }}
 */
function buildModelAndAgent(modelType, modelData, agentConfig) {
  const ModelClass = modelsMapping[modelType]
  const model = new ModelClass(modelData)
  const agent = new ProbabilisticAgentWrapper(model, agentConfig)
  return { model, agent }
}

/**
 * Holds the model behind an agent, a ProbabilisticAgentWrapper around that
 * model, and a DataCollector for gathered data instances.
 */
class AgentCore {
  /**
   * @param {object} [agentConfig={ useAgentWrapper: false }] - Passed to every
   *   ProbabilisticAgentWrapper this instance creates. Its `useAgentWrapper`
   *   flag decides whether selectAction() goes through the wrapper or calls
   *   the model directly.
   */
  constructor(agentConfig = { useAgentWrapper: false }) {
    this.agentConfig = agentConfig
    this.dataCollector = new DataCollector()
    this.initializedBool = false
  }

  /**
   * Builds the model (and its wrapper) described by `modelData`.
   *
   * If `modelData.parameters` is undefined, a fresh model is built from
   * `{ config: modelData.config }`, the same construction reset() uses.
   * Otherwise the whole `modelData` object goes to the model constructor.
   * Instance state is updated only after the model and wrapper have been
   * built, so a failed call leaves any previously created agent untouched.
   *
   * @param {{ config: { modelType: string }, parameters?: * }} modelData
   * @returns {boolean} true if a new model was created from the config alone,
   *   false if the model was built from the supplied `modelData`.
   * @throws {Error} "Invalid model type" if `config.modelType` is not a key of
   *   modelsMapping. Errors from the model or wrapper constructors propagate.
   */
  createAgent(modelData) {
    const { config } = modelData
    const { modelType } = config
    if (!isSupportedModelType(modelType)) {
      throw new Error("Invalid model type")
    }

    const createdNewModel = modelData.parameters === undefined
    // Build synchronously instead of calling the async reset(). Otherwise a
    // constructor failure would become an unhandled rejection while this
    // method still reported success.
    const { model, agent } = buildModelAndAgent(
      modelType,
      createdNewModel ? { config } : modelData,
      this.agentConfig,
    )

    this.modelConfig = config
    this.modelType = modelType
    this.model = model
    this.agent = agent
    this.initializedBool = true
    return createdNewModel
  }

  /**
   * Forwards a data instance to this agent's DataCollector.
   *
   * @param {*} dataInstance
   * @returns {*} Whatever DataCollector#collect returns.
   */
  collect(dataInstance) {
    return this.dataCollector.collect(dataInstance)
  }

  /**
   * @returns {*} The DataCollector's `trainingData` property.
   */
  getTrainingData() {
    return this.dataCollector.trainingData
  }

  /**
   * Returns the model's probabilities for `inputs`.
   *
   * NOTE(review): `trainedModel` is never assigned in this class. It is
   * presumably set by callers or subclasses; confirm.
   *
   * @param {*} inputs
   * @param {boolean} [postTrainingBool=false] - Use `this.trainedModel` instead
   *   of `this.model` when one is set.
   * @returns {*} Whatever the chosen model's getProbabilities returns.
   */
  getProbabilities(inputs, postTrainingBool = false) {
    if (postTrainingBool && this.trainedModel) {
      return this.trainedModel.getProbabilities(inputs)
    }
    return this.model.getProbabilities(inputs)
  }

  /**
   * Chooses an action for `inputs`.
   *
   * When `postTrainingBool` is true and `this.trainedModel` is set, the trained
   * objects are used: `this.trainedAgent` if the wrapper is enabled and one
   * exists, otherwise `this.trainedModel`. In all other cases the current
   * wrapper (`this.agent`) or model (`this.model`) is used, depending on
   * `agentConfig.useAgentWrapper`.
   *
   * @param {*} inputs
   * @param {boolean} [postTrainingBool=false]
   * @returns {*} Whatever the chosen object's selectAction returns.
   */
  selectAction(inputs, postTrainingBool = false) {
    const useWrapper = this.agentConfig?.useAgentWrapper
    if (postTrainingBool && this.trainedModel) {
      if (useWrapper && this.trainedAgent) {
        return this.trainedAgent.selectAction(inputs)
      }
      return this.trainedModel.selectAction(inputs)
    }
    if (useWrapper) {
      return this.agent.selectAction(inputs)
    }
    return this.model.selectAction(inputs)
  }

  /**
   * Replaces the current model and wrapper with fresh ones built from the
   * stored `modelConfig` and `modelType`.
   *
   * Kept `async` for backward compatibility. The returned promise rejects if no
   * model config has been set, or if the model or wrapper constructor throws.
   *
   * @returns {Promise<void>}
   */
  async reset() {
    if (this.modelConfig === undefined) {
      throw new Error("Model config is not defined")
    }
    const { model, agent } = buildModelAndAgent(
      this.modelType,
      { config: this.modelConfig },
      this.agentConfig,
    )
    this.model = model
    this.agent = agent
  }
}

module.exports = { AgentCore, modelsMapping }