UNPKG

ppljs-ppl-core

Version:

ppljs network inference framework core module

89 lines (82 loc) 3.65 kB
import {ModelConfig,Model} from './interface/interface'

/**
 * Loads a network model.
 *
 * A model consists of a JSON file describing the model structure and a binary
 * file holding the model data (weights, biases, ...) as raw float32 values.
 * Both files must be fetched and then combined.
 *
 * Process:
 *  1. Determine the paths of the model file and the data file (constructor).
 *  2. Fetch and parse the JSON model file (ParseJsonModel).
 *  3. Fetch the binary model data (ReadModelData) and hand each op its slice
 *     of it (AssignData).
 *
 * @author xusiyu@sensetime.com
 */
export default class ModelLoader {
  modelConfig_: ModelConfig;
  /**
   * True when modelPath does not start with "http". The model path may be a
   * network address or a local path; local files are intended for testing only.
   */
  isLocalPath: boolean;

  /**
   * @param modelConfig Location of the model files. If modelName or
   *   binaryDataName is missing, "model.json" / "model.dat" are used, and a
   *   trailing "/" is appended to modelPath when absent. NOTE: these defaults
   *   are written into the passed-in object itself (it is stored, not copied).
   */
  constructor(modelConfig: ModelConfig) {
    this.modelConfig_ = modelConfig;
    // If the config does not provide file names, fall back to the defaults.
    if (modelConfig.modelName == undefined) this.modelConfig_.modelName = "model.json";
    if (modelConfig.binaryDataName == undefined) this.modelConfig_.binaryDataName = "model.dat";
    if (modelConfig.modelPath.charAt(modelConfig.modelPath.length - 1) !== '/') {
      this.modelConfig_.modelPath = `${modelConfig.modelPath}/`;
    }
    this.isLocalPath = modelConfig.modelPath.indexOf('http') !== 0;
  }

  /**
   * Fetches the JSON model and its binary data in parallel, then attaches the
   * data to the model's ops.
   *
   * @returns The parsed model with each op's `data[*].hostData` filled in.
   * @throws (rejects) if either fetch fails, returns an HTTP error status, or
   *   the data cannot be decoded. Previously such failures left the promise
   *   pending forever.
   */
  async LoadModel(): Promise<Model> {
    const [model, data] = await Promise.all([this.ParseJsonModel(), this.ReadModelData()]);
    this.AssignData(model, data);
    return model;
  }

  /**
   * Fetches `url`, rejecting on network failure or on an HTTP error status.
   * Status 0 is still accepted because some non-HTTP loads report it.
   */
  private async fetchOk(url: string) {
    const response = await fetch(url, {
      method: 'get',
    });
    if (!response.ok && response.status !== 0) {
      throw new Error(`ModelLoader: failed to fetch ${url}: HTTP ${response.status} ${response.statusText}`);
    }
    return response;
  }

  /**
   * Fetches `modelPath` and resolves with the raw response body.
   * (The body could also be read as a Blob; an ArrayBuffer is used here.)
   */
  readData(modelPath: string): Promise<ArrayBuffer> {
    return this.fetchOk(modelPath).then(response => response.arrayBuffer());
  }

  /** Fetches `<modelPath><modelName>` and resolves with the parsed JSON model. */
  async ParseJsonModel(): Promise<Model> {
    const modelPath = this.modelConfig_.modelPath + this.modelConfig_.modelName;
    const response = await this.fetchOk(modelPath);
    return response.json();
  }

  /**
   * Fetches `<modelPath><binaryDataName>` and views its bytes as float32 values.
   * Rejects with a RangeError if the byte length is not a multiple of 4.
   */
  async ReadModelData(): Promise<Float32Array> {
    const binaryPath = this.modelConfig_.modelPath + this.modelConfig_.binaryDataName;
    const buffer = await this.readData(binaryPath);
    return new Float32Array(buffer);
  }

  /**
   * Copies each op's slice of the binary data into `op.data[item].hostData`.
   *
   * For every op that has `data`, reading starts at `op.hostDataOffset` and
   * consumes `size` values per entry, in the entry iteration order.
   * Offsets and sizes are counted in float32 elements (slice positions), not
   * bytes. Slicing past the end of `dataArray` yields a shorter array.
   */
  AssignData(opModel_: Model, dataArray: Float32Array) {
    opModel_.ops.forEach((op) => {
      if (op.data == undefined) return;
      // Read the corresponding data starting at this op's offset.
      let offset: number = op.hostDataOffset;
      for (const item in op.data) {
        const size_: number = op.data[item].size;
        op.data[item].hostData = dataArray.slice(offset, offset + size_);
        offset += size_;
      }
    });
  }
}