ppljs-ppl-core
Version:
ppljs network inference framework core module
89 lines (82 loc) • 3.65 kB
text/typescript
import {ModelConfig,Model} from './interface/interface'
/*
*@author xusiyu@sensetime.com
*@remark asynchronous parsing will be implemented later;
*@brief a model consists of a model-structure JSON file and a binary model-data file (weights, biases, ...),
* so we should parse the JSON model and read the model data
*@process
*@1. determine the paths of the model file and the data file. We do this work in constructor().
*@2. parse the model JSON file according to the path. We do this work in ParseJsonModel().
*@3. read the model data according to the path. We do this work in ReadModelData().
*/
/**
 * Loads a model made of two files that live under the same base path:
 * a JSON file describing the model structure and a binary file holding the
 * raw float32 data referenced by the ops in that structure.
 *
 * Typical use: `new ModelLoader(config).LoadModel()`.
 */
export default class ModelLoader {
  /**
   * The configuration passed to the constructor (same object reference),
   * with default file names filled in and `modelPath` ending in "/".
   */
  modelConfig_: ModelConfig;
  /** True when `modelPath` does not start with "http". */
  isLocalPath: boolean;

  /**
   * The model path may be a network address or a local path; local files are
   * intended for testing only.
   *
   * Note: the given config object is stored by reference and normalised in
   * place (missing names receive defaults, a trailing "/" is appended to
   * `modelPath` when absent).
   *
   * @param modelConfig base path plus optional model / binary file names
   */
  constructor(modelConfig: ModelConfig) {
    this.modelConfig_ = modelConfig;
    // Fall back to the default file names when the config does not provide them.
    if (modelConfig.modelName == null) {
      this.modelConfig_.modelName = "model.json";
    }
    if (modelConfig.binaryDataName == null) {
      this.modelConfig_.binaryDataName = "model.dat";
    }
    // File names are appended directly to modelPath, so it must end with "/".
    if (!modelConfig.modelPath.endsWith('/')) {
      this.modelConfig_.modelPath = `${modelConfig.modelPath}/`;
    }
    this.isLocalPath = !modelConfig.modelPath.startsWith('http');
  }

  /**
   * Fetches the model JSON and the binary data in parallel, attaches the
   * binary data to the ops of the parsed model, and returns the model.
   *
   * @returns the parsed model with `hostData` filled in on every data entry
   * @throws rejects when either download fails, the server answers with a
   *         non-2xx status, or the JSON / binary payload cannot be decoded
   */
  async LoadModel(): Promise<Model> {
    // Both requests are started before awaiting so they run concurrently.
    const [model, data] = await Promise.all([
      this.ParseJsonModel(),
      this.ReadModelData(),
    ]);
    this.AssignData(model, data);
    return model;
  }

  /**
   * Downloads `modelPath` and returns its body as an ArrayBuffer.
   *
   * @param modelPath URL (or path understood by `fetch`) of the file to read
   * @returns the raw bytes of the response body
   * @throws rejects on network failure or a non-2xx HTTP status
   */
  async readData(modelPath: string): Promise<ArrayBuffer> {
    const response = await this.fetchChecked(modelPath);
    // The body could be read as a blob as well; an ArrayBuffer is what the
    // Float32Array view in ReadModelData needs.
    return response.arrayBuffer();
  }

  /**
   * Downloads and parses the model structure JSON file
   * (`modelPath + modelName`).
   *
   * @returns the parsed JSON, typed as Model (the payload is not validated)
   * @throws rejects on network failure, non-2xx status, or invalid JSON
   */
  async ParseJsonModel(): Promise<Model> {
    const modelPath = this.modelConfig_.modelPath + this.modelConfig_.modelName;
    const response = await this.fetchChecked(modelPath);
    return response.json();
  }

  /**
   * Downloads the binary data file (`modelPath + binaryDataName`) and views
   * it as 32-bit floats.
   *
   * @returns a Float32Array over the downloaded bytes
   * @throws rejects on network failure, non-2xx status, or when the byte
   *         length is not a multiple of 4 (Float32Array constructor error)
   */
  async ReadModelData(): Promise<Float32Array> {
    const binaryPath = this.modelConfig_.modelPath + this.modelConfig_.binaryDataName;
    const buffer = await this.readData(binaryPath);
    return new Float32Array(buffer);
  }

  /**
   * Distributes the binary data over the ops of the model.
   *
   * For every op that has a `data` member, entries are visited in key order
   * starting at `op.hostDataOffset`; each entry receives a copy of the next
   * `size` elements of `dataArray` in its `hostData` property.
   *
   * @param opModel_ parsed model whose ops are updated in place
   * @param dataArray the complete binary data of the model
   */
  AssignData(opModel_: Model, dataArray: Float32Array): void {
    for (const op of opModel_.ops) {
      if (op.data == null) {
        continue;
      }
      let offset: number = op.hostDataOffset;
      // Walk the entries and read each one's data starting from the running offset.
      for (const key of Object.keys(op.data)) {
        const size_: number = op.data[key].size;
        // slice() indexes by element position, not by byte, so both the
        // offset and the size are counted in float32 elements.
        op.data[key].hostData = dataArray.slice(offset, offset + size_);
        offset += size_;
      }
    }
  }

  /** Issues a GET request and rejects unless the response status is 2xx. */
  private async fetchChecked(url: string): Promise<Response> {
    const response = await fetch(url, { method: 'get' });
    if (!response.ok) {
      throw new Error(`Failed to fetch ${url}: HTTP ${response.status} ${response.statusText}`);
    }
    return response;
  }
}