arc-agents
Version:
A library for creating and deploying gaming agents at scale
84 lines (73 loc) • 2.38 kB
JavaScript
const {
FrequencyTable,
NeuralNetworkMultihead,
// HierarchicalNeuralNetwork,
DataCollector
} = require('arc-ml')
const ProbabilisticAgentWrapper = require('../agent-wrappers/probabilistic-agent-wrapper')
// Registry of supported model types: the `config.modelType` string a caller
// supplies is looked up here to find the arc-ml class to instantiate.
const modelsMapping = {
  simple: FrequencyTable,
  'neural-network': NeuralNetworkMultihead,
  // 'hierarchical': HierarchicalNeuralNetwork, (currently disabled)
}
/**
 * Owns one model from `modelsMapping`, a ProbabilisticAgentWrapper around it,
 * and a DataCollector for recording samples.
 *
 * `trainedModel` / `trainedAgent` are read by getProbabilities/selectAction but
 * never assigned in this class — presumably set externally (verify with caller).
 */
class AgentCore {
  /**
   * @param {object} [agentConfig={ useAgentWrapper: false }] - Forwarded to
   *   ProbabilisticAgentWrapper. When `useAgentWrapper` is truthy, selectAction
   *   routes through the wrapper instead of the bare model.
   */
  constructor(agentConfig = { useAgentWrapper: false }) {
    this.agentConfig = agentConfig
    this.dataCollector = new DataCollector()
    this.initializedBool = false
  }

  /**
   * Instantiates the model described by `modelData` together with its wrapper.
   *
   * @param {object} modelData - Must have a `config` whose `modelType` is an own
   *   key of `modelsMapping`. If `parameters` is undefined a fresh model is built
   *   from `{ config }`; otherwise the whole `modelData` is passed to the model class.
   * @returns {boolean} true when a fresh (parameterless) model was created,
   *   false when it was built from the supplied `parameters`.
   * @throws {TypeError} if `modelData` or `modelData.config` is missing.
   * @throws {Error} "Invalid model type" for an unknown `modelType`. Errors from
   *   the model or wrapper constructors propagate synchronously; in every
   *   failure case the previously stored model/config is left untouched.
   */
  createAgent(modelData) {
    const config = modelData?.config
    if (config == null) {
      throw new TypeError("Model data must include a config")
    }
    const { modelType } = config
    // Own-property lookup: a plain `modelsMapping[modelType]` would accept
    // inherited keys such as "constructor" or "toString".
    if (!Object.prototype.hasOwnProperty.call(modelsMapping, modelType)) {
      throw new Error("Invalid model type")
    }
    const createdNewModel = modelData.parameters === undefined
    // Build synchronously instead of calling the async reset(): a throwing
    // constructor must reach the caller, not become an unhandled rejection.
    this._buildModel(modelType, createdNewModel ? { config } : modelData)
    // Commit the configuration only once the model was built successfully.
    this.modelConfig = config
    this.modelType = modelType
    this.initializedBool = true
    return createdNewModel
  }

  /** Records one sample via the DataCollector and returns its result. */
  collect(dataInstance) {
    return this.dataCollector.collect(dataInstance)
  }

  /** @returns the DataCollector's `trainingData`. */
  getTrainingData() {
    return this.dataCollector.trainingData
  }

  /**
   * Probabilities for `inputs`, taken from `trainedModel` when
   * `postTrainingBool` is set and a trained model exists, else from `model`.
   */
  getProbabilities(inputs, postTrainingBool = false) {
    if (postTrainingBool && this.trainedModel) {
      return this.trainedModel.getProbabilities(inputs)
    }
    return this.model.getProbabilities(inputs)
  }

  /**
   * Chooses an action for `inputs`.
   *
   * Post-training (flag set and `trainedModel` present): uses `trainedAgent`
   * when the wrapper is enabled and one exists, otherwise `trainedModel`.
   * Otherwise: uses the wrapper `agent` when enabled, else the bare `model`.
   */
  selectAction(inputs, postTrainingBool = false) {
    const useWrapper = this.agentConfig?.useAgentWrapper
    if (postTrainingBool && this.trainedModel) {
      if (useWrapper && this.trainedAgent) {
        return this.trainedAgent.selectAction(inputs)
      }
      return this.trainedModel.selectAction(inputs)
    }
    if (useWrapper) {
      return this.agent.selectAction(inputs)
    }
    return this.model.selectAction(inputs)
  }

  /**
   * Rebuilds a fresh model (and wrapper) from the stored config.
   *
   * Kept `async` for backward compatibility: it returns a Promise, and the
   * missing-config error is delivered as a rejection.
   * @throws {Error} "Model config is not defined" (as a rejection) when no
   *   config has been stored by a successful createAgent call.
   */
  async reset() {
    if (this.modelConfig === undefined) {
      throw new Error("Model config is not defined")
    }
    this._buildModel(this.modelType, { config: this.modelConfig })
  }

  /**
   * Internal helper: constructs the model class registered under `modelType`
   * with `modelArgs` and wraps it. Both are assigned only after both
   * constructors succeed, so `model` and `agent` never go out of sync.
   */
  _buildModel(modelType, modelArgs) {
    const model = new modelsMapping[modelType](modelArgs)
    const agent = new ProbabilisticAgentWrapper(model, this.agentConfig)
    this.model = model
    this.agent = agent
  }
}
module.exports = { AgentCore, modelsMapping }