@tensorflow/tfjs-converter
Version:
TensorFlow model converter for JavaScript
61 lines • 10.4 kB
JavaScript
/**
 * Builds a literal numeric attribute param for executor tests.
 *
 * @param value The number stored on the attribute.
 * @returns An object of the form `{value, type: 'number'}`.
 */
export function createNumberAttr(value) {
  const numberAttr = {value: value, type: 'number'};
  return numberAttr;
}
/**
 * Builds a numeric param whose value is read from a node input.
 *
 * @param inputIndex Index stored as `inputIndexStart`.
 * @returns An object of the form `{inputIndexStart, type: 'number'}`.
 */
export function createNumberAttrFromIndex(inputIndex) {
  const type = 'number';
  return {inputIndexStart: inputIndex, type};
}
/**
 * Builds a literal string attribute param for executor tests.
 *
 * @param str The string stored as the attribute's `value`.
 * @returns An object of the form `{value: str, type: 'string'}`.
 */
export function createStrAttr(str) {
  const stringAttr = {value: str, type: 'string'};
  return stringAttr;
}
/**
 * Builds a literal string-array attribute param for executor tests.
 *
 * The array is stored by reference, not copied.
 *
 * @param strs The strings stored as the attribute's `value`.
 * @returns An object of the form `{value: strs, type: 'string[]'}`.
 */
export function createStrArrayAttr(strs) {
  const type = 'string[]';
  return {value: strs, type};
}
/**
 * Builds a literal boolean attribute param for executor tests.
 *
 * @param value The boolean stored on the attribute.
 * @returns An object of the form `{value, type: 'bool'}`.
 */
export function createBoolAttr(value) {
  const boolAttr = {value: value, type: 'bool'};
  return boolAttr;
}
/**
 * Builds a literal shape attribute param for executor tests.
 *
 * The shape array is stored by reference, not copied.
 *
 * @param value The shape stored on the attribute.
 * @returns An object of the form `{value, type: 'shape'}`.
 */
export function createTensorShapeAttr(value) {
  const type = 'shape';
  return {value, type};
}
/**
 * Builds a shape param whose value is read from a node input.
 *
 * @param inputIndex Index stored as `inputIndexStart`.
 * @returns An object of the form `{inputIndexStart, type: 'shape'}`.
 */
export function createShapeAttrFromIndex(inputIndex) {
  const shapeParam = {inputIndexStart: inputIndex, type: 'shape'};
  return shapeParam;
}
/**
 * Builds a literal numeric-array attribute param for executor tests.
 *
 * The array is stored by reference, not copied.
 *
 * @param value The numbers stored on the attribute.
 * @returns An object of the form `{value, type: 'number[]'}`.
 */
export function createNumericArrayAttr(value) {
  const type = 'number[]';
  return {value, type};
}
/**
 * Builds a numeric-array param whose value is read from a node input.
 *
 * @param inputIndex Index stored as `inputIndexStart`.
 * @returns An object of the form `{inputIndexStart, type: 'number[]'}`.
 */
export function createNumericArrayAttrFromIndex(inputIndex) {
  const numericArrayParam = {inputIndexStart: inputIndex, type: 'number[]'};
  return numericArrayParam;
}
/**
 * Builds a boolean-array param whose value is read from a node input.
 *
 * @param inputIndex Index stored as `inputIndexStart`.
 * @returns An object of the form `{inputIndexStart, type: 'bool[]'}`.
 */
export function createBooleanArrayAttrFromIndex(inputIndex) {
  const type = 'bool[]';
  return {inputIndexStart: inputIndex, type};
}
/**
 * Builds a single-tensor input param pointing at one node input.
 *
 * @param index Index stored as `inputIndexStart`.
 * @returns An object of the form `{inputIndexStart: index, type: 'tensor'}`.
 */
export function createTensorAttr(index) {
  const tensorParam = {inputIndexStart: index, type: 'tensor'};
  return tensorParam;
}
/**
 * Builds a multi-tensor input param covering a range of node inputs.
 *
 * Note: despite its name, `paramLength` is stored verbatim as the range's
 * `inputIndexEnd`, not as a count of inputs.
 *
 * @param index Index stored as `inputIndexStart`.
 * @param paramLength Value stored as `inputIndexEnd`.
 * @returns An object of the form
 *     `{inputIndexStart: index, inputIndexEnd: paramLength, type: 'tensors'}`.
 */
export function createTensorsAttr(index, paramLength) {
  const inputIndexStart = index;
  const inputIndexEnd = paramLength;
  return {inputIndexStart, inputIndexEnd, type: 'tensors'};
}
/**
 * Builds a literal dtype attribute param for executor tests.
 *
 * @param dtype The dtype name stored as the attribute's `value`.
 * @returns An object of the form `{value: dtype, type: 'dtype'}`.
 */
export function createDtypeAttr(dtype) {
  const type = 'dtype';
  return {value: dtype, type};
}
/**
 * Checks that a node's parsed params agree with the op-mapper definition for
 * its op. Both objects are logged on any mismatch.
 *
 * The mapper is looked up by `tfOpName` when one is given, otherwise by
 * `node.op`. The checks are:
 * - Every entry of `node.inputParams` must have an input definition with the
 *   same name, `type`, `start` (vs `inputIndexStart`) and `end`
 *   (vs `inputIndexEnd`).
 * - Every entry of `node.attrParams` must have an attr definition with the
 *   same name and `type`.
 *
 * @param node The converted graph node to validate.
 * @param opMappers The op-mapper definitions to search.
 * @param tfOpName Optional op name to look up instead of `node.op`.
 * @returns true when every param matches its definition. Returns false
 *     otherwise, including when no mapper exists for the op.
 */
export function validateParam(node, opMappers, tfOpName) {
  const opName = tfOpName != null ? tfOpName : node.op;
  const opMapper = opMappers.find(mapper => mapper.tfOpName === opName);

  // Without a mapper there is nothing to validate against. Report a mismatch
  // rather than throwing a TypeError on `opMapper.inputs`, or vacuously
  // passing a node that happens to have no params.
  let matched = false;
  if (opMapper != null) {
    // Guard against a mapper that omits `inputs`/`attrs`, so a node that has
    // such params reports a mismatch instead of throwing.
    const inputDefs = opMapper.inputs || [];
    const attrDefs = opMapper.attrs || [];

    const inputsMatch = Object.keys(node.inputParams).every(key => {
      const value = node.inputParams[key];
      const def = inputDefs.find(param => param.name === key);
      return def != null && def.type === value.type &&
          def.start === value.inputIndexStart &&
          def.end === value.inputIndexEnd;
    });
    const attrsMatch = inputsMatch &&
        Object.keys(node.attrParams).every(key => {
          const value = node.attrParams[key];
          const def = attrDefs.find(param => param.name === key);
          return def != null && def.type === value.type;
        });
    matched = attrsMatch;
  }

  if (!matched) {
    console.log('node = ', node);
    console.log('opMapper = ', opMapper);
  }
  return matched;
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"test_helper.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/test_helper.ts"],"names":[],"mappings":"AAmBA,MAAM,UAAU,gBAAgB,CAAC,KAAa;IAC5C,OAAO,EAAC,KAAK,EAAE,IAAI,EAAE,QAAQ,EAAC,CAAC;AACjC,CAAC;AAED,MAAM,UAAU,yBAAyB,CAAC,UAAkB;IAC1D,OAAO,EAAC,eAAe,EAAE,UAAU,EAAE,IAAI,EAAE,QAAQ,EAAC,CAAC;AACvD,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,GAAW;IACvC,OAAO,EAAC,KAAK,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAC,CAAC;AACtC,CAAC;AAED,MAAM,UAAU,kBAAkB,CAAC,IAAc;IAC/C,OAAO,EAAC,KAAK,EAAE,IAAI,EAAE,IAAI,EAAE,UAAU,EAAC,CAAC;AACzC,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,KAAc;IAC3C,OAAO,EAAC,KAAK,EAAE,IAAI,EAAE,MAAM,EAAC,CAAC;AAC/B,CAAC;AAED,MAAM,UAAU,qBAAqB,CAAC,KAAe;IACnD,OAAO,EAAC,KAAK,EAAE,IAAI,EAAE,OAAO,EAAC,CAAC;AAChC,CAAC;AAED,MAAM,UAAU,wBAAwB,CAAC,UAAkB;IACzD,OAAO,EAAC,eAAe,EAAE,UAAU,EAAE,IAAI,EAAE,OAAO,EAAC,CAAC;AACtD,CAAC;AAED,MAAM,UAAU,sBAAsB,CAAC,KAAe;IACpD,OAAO,EAAC,KAAK,EAAE,IAAI,EAAE,UAAU,EAAC,CAAC;AACnC,CAAC;AAED,MAAM,UAAU,+BAA+B,CAAC,UAAkB;IAEhE,OAAO,EAAC,eAAe,EAAE,UAAU,EAAE,IAAI,EAAE,UAAU,EAAC,CAAC;AACzD,CAAC;AAED,MAAM,UAAU,+BAA+B,CAAC,UAAkB;IAEhE,OAAO,EAAC,eAAe,EAAE,UAAU,EAAE,IAAI,EAAE,QAAQ,EAAC,CAAC;AACvD,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,KAAa;IAC5C,OAAO,EAAC,eAAe,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,EAAC,CAAC;AAClD,CAAC;AAED,MAAM,UAAU,iBAAiB,CAC7B,KAAa,EAAE,WAAmB;IACpC,OAAO,EAAC,eAAe,EAAE,KAAK,EAAE,aAAa,EAAE,WAAW,EAAE,IAAI,EAAE,SAAS,EAAC,CAAC;AAC/E,CAAC;AAED,MAAM,UAAU,eAAe,CAAC,KAAa;IAC3C,OAAO,EAAC,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,OAAO,EAAC,CAAC;AACvC,CAAC;AAED,MAAM,UAAU,aAAa,CACzB,IAAU,EAAE,SAAqB,EAAE,QAAiB;IACtD,MAAM,QAAQ,GAAG,QAAQ,IAAI,IAAI,CAAC,CAAC;QAC/B,SAAS,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAC,MAAM,CAAC,QAAQ,KAAK,QAAQ,CAAC,CAAC,CAAC;QACxD,SAAS,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAC,MAAM,CAAC,QAAQ,KAAK,IAAI,CAAC,EAAE,CAAC,CAAC;IAC1D,MAAM,OAAO,GAAG,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,KAAK,CAAC,GAAG,CAAC,EAAE;QACxD,MAAM,KAAK,GAAG,IAAI,CAAC,WAAW,CAAC,GAAG,CAAC,CAAC;QACpC,MAAM,GAAG,GA
AG,QAAQ,CAAC,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,IAAI,KAAK,GAAG,CAAC,CAAC;QAC9D,OAAO,GAAG,IAAI,GAAG,CAAC,IAAI,KAAK,KAAK,CAAC,IAAI;YACjC,GAAG,CAAC,KAAK,KAAK,KAAK,CAAC,eAAe,IAAI,GAAG,CAAC,GAAG,KAAK,KAAK,CAAC,aAAa,CAAC;IAC7E,CAAC,CAAC;QACE,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,KAAK,CAAC,GAAG,CAAC,EAAE;YACvC,MAAM,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC;YACnC,MAAM,GAAG,GAAG,QAAQ,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,IAAI,KAAK,GAAG,CAAC,CAAC;YAC7D,OAAO,GAAG,IAAI,GAAG,CAAC,IAAI,KAAK,KAAK,CAAC,IAAI,CAAC;QACxC,CAAC,CAAC,CAAC;IACP,IAAI,CAAC,OAAO,EAAE;QACZ,OAAO,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,CAAC;QAC7B,OAAO,CAAC,GAAG,CAAC,aAAa,EAAE,QAAQ,CAAC,CAAC;KACtC;IACD,OAAO,OAAO,CAAC;AACjB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {InputParamValue, OpMapper, ParamValue} from '../types';\nimport {Node} from '../types';\n\nexport function createNumberAttr(value: number): ParamValue {\n  return {value, type: 'number'};\n}\n\nexport function createNumberAttrFromIndex(inputIndex: number): InputParamValue {\n  return {inputIndexStart: inputIndex, type: 'number'};\n}\n\nexport function createStrAttr(str: string): ParamValue {\n  return {value: str, type: 'string'};\n}\n\nexport function createStrArrayAttr(strs: string[]): ParamValue {\n  return 
{value: strs, type: 'string[]'};\n}\n\nexport function createBoolAttr(value: boolean): ParamValue {\n  return {value, type: 'bool'};\n}\n\nexport function createTensorShapeAttr(value: number[]): ParamValue {\n  return {value, type: 'shape'};\n}\n\nexport function createShapeAttrFromIndex(inputIndex: number): InputParamValue {\n  return {inputIndexStart: inputIndex, type: 'shape'};\n}\n\nexport function createNumericArrayAttr(value: number[]): ParamValue {\n  return {value, type: 'number[]'};\n}\n\nexport function createNumericArrayAttrFromIndex(inputIndex: number):\n    InputParamValue {\n  return {inputIndexStart: inputIndex, type: 'number[]'};\n}\n\nexport function createBooleanArrayAttrFromIndex(inputIndex: number):\n    InputParamValue {\n  return {inputIndexStart: inputIndex, type: 'bool[]'};\n}\n\nexport function createTensorAttr(index: number): InputParamValue {\n  return {inputIndexStart: index, type: 'tensor'};\n}\n\nexport function createTensorsAttr(\n    index: number, paramLength: number): InputParamValue {\n  return {inputIndexStart: index, inputIndexEnd: paramLength, type: 'tensors'};\n}\n\nexport function createDtypeAttr(dtype: string): ParamValue {\n  return {value: dtype, type: 'dtype'};\n}\n\nexport function validateParam(\n    node: Node, opMappers: OpMapper[], tfOpName?: string) {\n  const opMapper = tfOpName != null ?\n      opMappers.find(mapper => mapper.tfOpName === tfOpName) :\n      opMappers.find(mapper => mapper.tfOpName === node.op);\n  const matched = Object.keys(node.inputParams).every(key => {\n    const value = node.inputParams[key];\n    const def = opMapper.inputs.find(param => param.name === key);\n    return def && def.type === value.type &&\n        def.start === value.inputIndexStart && def.end === value.inputIndexEnd;\n  }) &&\n      Object.keys(node.attrParams).every(key => {\n        const value = node.attrParams[key];\n        const def = opMapper.attrs.find(param => param.name === key);\n        return def && def.type === 
value.type;\n      });\n  if (!matched) {\n    console.log('node = ', node);\n    console.log('opMapper = ', opMapper);\n  }\n  return matched;\n}\n"]}