@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
105 lines • 15.2 kB
JavaScript
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
import { serialization, unstack, stack, tensor, tidy, range, image } from '@tensorflow/tfjs-core';
import { getExactlyOneShape, getExactlyOneTensor } from '../../utils/types_utils';
import { Layer } from '../../engine/topology';
import * as K from '../../backend/tfjs_backend';
const { resizeBilinear, cropAndResize } = image;
/**
 * Preprocessing layer that crops the central `height` x `width` window out of
 * an image tensor.
 *
 * Inputs may be rank 3 or rank 4. The spatial axes are taken to be the third-
 * and second-to-last dimensions. That is presumably a channels-last layout,
 * `[h, w, c]` or `[batch, h, w, c]` — confirm against callers.
 *
 * If the input is smaller than the target along either spatial axis, the
 * whole input is instead resized bilinearly to `height` x `width`. In every
 * case the result is cast back to the input's dtype.
 */
class CenterCrop extends Layer {
  /**
   * @param args Layer arguments. `args.height` and `args.width` are the target
   *     output height and width.
   */
  constructor(args) {
    super(args);
    this.height = args.height;
    this.width = args.width;
  }

  /**
   * Extracts a `height` x `width` window whose top-left corner is at row
   * `hBuffer`, column `wBuffer`. It uses `image.cropAndResize` with 'nearest'
   * sampling.
   *
   * @param inputs Rank-3 or rank-4 image tensor.
   * @param hBuffer Row offset of the window's top edge.
   * @param wBuffer Column offset of the window's left edge.
   * @param height Window (output) height.
   * @param width Window (output) width.
   * @param inputHeight Height of `inputs`.
   * @param inputWidth Width of `inputs`.
   * @param dtype Dtype the result is cast to.
   * @returns The cropped tensor, with the same rank as `inputs`.
   */
  centerCrop(inputs, hBuffer, wBuffer, height, width, inputHeight, inputWidth, dtype) {
    return tidy(() => {
      let input;
      let isRank3 = false;
      // Box corners are expressed as fractions of the input size, in
      // [top, left, bottom, right] order.
      const top = hBuffer / inputHeight;
      const left = wBuffer / inputWidth;
      const bottom = (height + hBuffer) / inputHeight;
      const right = (width + wBuffer) / inputWidth;
      const bound = [top, left, bottom, right];
      const boxesArr = [];
      if (inputs.rank === 3) {
        // Give rank-3 input a leading batch axis of size 1 so it can go through
        // the batched crop below.
        isRank3 = true;
        input = stack([inputs]);
      }
      else {
        input = inputs;
      }
      // Use the same box for every element along the leading axis. The
      // matching box indices are 0..n-1.
      for (let i = 0; i < input.shape[0]; i++) {
        boxesArr.push(bound);
      }
      const boxes = tensor(boxesArr, [boxesArr.length, 4]);
      const boxInd = range(0, boxesArr.length, 1, 'int32');
      const cropSize = [height, width];
      const cropped = cropAndResize(input, boxes, boxInd, cropSize, 'nearest');
      if (isRank3) {
        // Remove the batch axis that was added above.
        return K.cast(getExactlyOneTensor(unstack(cropped)), dtype);
      }
      return K.cast(cropped, dtype);
    });
  }

  /**
   * Resizes `inputs` to `height` x `width` with `image.resizeBilinear`, then
   * casts the result to `dtype`.
   *
   * @param inputs Rank-3 or rank-4 image tensor.
   * @param height Output height.
   * @param width Output width.
   * @param dtype Dtype the result is cast to.
   * @returns The resized tensor.
   */
  upsize(inputs, height, width, dtype) {
    return tidy(() => {
      const outputs = resizeBilinear(inputs, [height, width]);
      return K.cast(outputs, dtype);
    });
  }

  /**
   * Center-crops the input to `this.height` x `this.width`. If the input is
   * smaller than the target along either spatial axis, it resizes the whole
   * input bilinearly instead.
   *
   * @param inputs A single image tensor, or a wrapper holding exactly one
   *     (unwrapped with `getExactlyOneTensor`).
   * @param kwargs Unused.
   * @returns The cropped or resized tensor, in the input's dtype.
   */
  call(inputs, kwargs) {
    return tidy(() => {
      const rankedInputs = getExactlyOneTensor(inputs);
      const dtype = rankedInputs.dtype;
      const inputShape = rankedInputs.shape;
      const inputHeight = inputShape[inputShape.length - 3];
      const inputWidth = inputShape[inputShape.length - 2];
      // Offsets of the crop window. An offset is negative when the input is
      // smaller than the target along that axis.
      let hBuffer = 0;
      if (inputHeight !== this.height) {
        hBuffer = Math.floor((inputHeight - this.height) / 2);
      }
      let wBuffer = 0;
      if (inputWidth !== this.width) {
        wBuffer = Math.floor((inputWidth - this.width) / 2);
        // NOTE(review): only the width offset gets this bump. It applies when
        // the input is exactly one pixel wider than the target. The height
        // offset has no equivalent, so a one-pixel excess drops the last row
        // but the first column. Kept as-is to preserve existing outputs;
        // confirm whether the asymmetry is intended.
        if (wBuffer === 0) {
          wBuffer = 1;
        }
      }
      if (hBuffer >= 0 && wBuffer >= 0) {
        return this.centerCrop(rankedInputs, hBuffer, wBuffer, this.height, this.width, inputHeight, inputWidth, dtype);
      }
      // At least one axis is too small to crop, so resize the whole image.
      // Pass the unwrapped tensor: `inputs` itself may be the array wrapper
      // that getExactlyOneTensor unwrapped above.
      return this.upsize(rankedInputs, this.height, this.width, dtype);
    });
  }

  /**
   * @returns The layer config: `height` and `width`, merged with the base
   *     layer config.
   */
  getConfig() {
    const config = {
      'height': this.height,
      'width': this.width
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  /**
   * Returns the input shape with its height and width axes (third- and
   * second-to-last) replaced by the target size.
   *
   * @param inputShape A single shape, or a list holding exactly one shape.
   * @returns A new shape array. The argument is left unmodified.
   */
  computeOutputShape(inputShape) {
    // Work on a copy. getExactlyOneShape may hand back the caller's own shape
    // array (NOTE(review): confirm), and writing into it would silently change
    // the caller's shape.
    const outputShape = getExactlyOneShape(inputShape).slice();
    outputShape[outputShape.length - 3] = this.height;
    outputShape[outputShape.length - 2] = this.width;
    return outputShape;
  }
}
/** @nocollapse */
CenterCrop.className = 'CenterCrop';
export { CenterCrop };
serialization.registerClass(CenterCrop);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"center_crop.js","sourceRoot":"","sources":["../../../../../../../tfjs-layers/src/layers/preprocessing/center_crop.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,OAAO,EAAC,aAAa,EAAU,OAAO,EAAC,KAAK,EAAC,MAAM,EAA+C,IAAI,EAAE,KAAK,EAAE,KAAK,EAAC,MAAM,uBAAuB,CAAC;AACnJ,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,yBAAyB,CAAC;AAChF,OAAO,EAAY,KAAK,EAAC,MAAM,uBAAuB,CAAC;AAGvD,OAAO,KAAK,CAAC,MAAM,4BAA4B,CAAC;AAEhD,MAAM,EAAC,cAAc,EAAE,aAAa,EAAC,GAAG,KAAK,CAAC;AAO9C,MAAa,UAAW,SAAQ,KAAK;IAKnC,YAAY,IAAoB;QAC9B,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC;QAC1B,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC;IAC1B,CAAC;IAED,UAAU,CAAC,MAA2B,EAAE,OAAe,EAAE,OAAe,EAC9D,MAAc,EAAE,KAAa,EAAE,WAAmB,EAClD,UAAkB,EAAE,KAAe;QAE3C,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,KAAe,CAAC;YACpB,IAAI,OAAO,GAAQ,KAAK,CAAC;YACzB,MAAM,GAAG,GAAQ,OAAO,GAAG,WAAW,CAAC;YACvC,MAAM,IAAI,GAAO,OAAO,GAAG,UAAU,CAAC;YACtC,MAAM,MAAM,GAAK,CAAC,CAAC,MAAM,CAAC,GAAG,OAAO,CAAC,GAAG,WAAW,CAAC;YACpD,MAAM,KAAK,GAAM,CAAC,CAAC,KAAK,CAAC,GAAG,OAAO,CAAC,GAAG,UAAU,CAAC;YAClD,MAAM,KAAK,GAAM,CAAC,GAAG,EAAE,IAAI,EAAE,MAAM,EAAE,KAAK,CAAC,CAAC;YAC5C,MAAM,QAAQ,GAAG,EAAE,CAAC;YAEpB,IAAG,MAAM,CAAC,IAAI,KAAK,CAAC,EAAE;gBACpB,OAAO,GAAI,IAAI,CAAC;gBAChB,KAAK,GAAI,KAAK,CAAC,CAAC,MAAM,CAAC,CAAa,CAAC;aACtC;iBAAM;gBACL,KAAK,GAAG,MAAkB,CAAC;aAC5B;YAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE;gBACvC,QAAQ,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;aACtB;YAED,MAAM,KAAK,GAAc,MAAM,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC;YAChE,MAAM,MAAM,GAAa,KAAK,CAAC,CAAC,EAAE,QAAQ,CAAC,MAAM,EAAE,CAAC,EAAE,OAAO,CAAC,CAAC;YAE/D,MAAM,QAAQ,GAAqB,CAAC,MAAM,EAAE,KAAK,CAAC,CAAC;YACnD,MAAM,OAAO,GAAG,aAAa,CAAC,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,QAAQ,EAAE,SAAS,CAAC,CAAC;YAEzE,IAAG,OAAO,EAAE;gBACV,OAAO,CAAC,CAAC,IAAI,CAAC,mBAAmB,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;aAC7D;YACD,OAAO,CAAC,CAAC,IAAI,CAAC,OAAO,EAAE,KAAK,CAAC,CAAC;QACjC,CAAC,C
AAC,CAAC;IAEJ,CAAC;IAED,MAAM,CAAC,MAA4B,EAAE,MAAc,EAC5C,KAAa,EAAE,KAAe;QAEnC,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,OAAO,GAAG,cAAc,CAAC,MAAM,EAAE,CAAC,MAAM,EAAE,KAAK,CAAC,CAAC,CAAC;YACxD,OAAO,CAAC,CAAC,IAAI,CAAC,OAAO,EAAE,KAAK,CAAC,CAAC;QAClC,CAAC,CAAC,CAAC;IAEL,CAAC;IAEU,IAAI,CAAC,MAA2B,EAAG,MAAc;QAExD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,YAAY,GAAG,mBAAmB,CAAC,MAAM,CAAwB,CAAC;YACxE,MAAM,KAAK,GAAS,YAAY,CAAC,KAAK,CAAC;YACvC,MAAM,UAAU,GAAI,YAAY,CAAC,KAAK,CAAC;YACvC,MAAM,WAAW,GAAG,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YACtD,MAAM,UAAU,GAAK,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAEvD,IAAI,OAAO,GAAG,CAAC,CAAC;YAChB,IAAI,WAAW,KAAK,IAAI,CAAC,MAAM,EAAE;gBAC/B,OAAO,GAAI,IAAI,CAAC,KAAK,CAAC,CAAC,WAAW,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC;aACxD;YAED,IAAI,OAAO,GAAG,CAAC,CAAC;YAChB,IAAI,UAAU,KAAK,IAAI,CAAC,KAAK,EAAE;gBAC7B,OAAO,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;gBAEpD,IAAI,OAAO,KAAK,CAAC,EAAE;oBACjB,OAAO,GAAG,CAAC,CAAC;iBACb;aACF;YAED,IAAG,OAAO,IAAI,CAAC,IAAI,OAAO,IAAI,CAAC,EAAE;gBAC/B,OAAO,IAAI,CAAC,UAAU,CAAC,YAAY,EAAE,OAAO,EAAE,OAAO,EAC/B,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,KAAK,EAAE,WAAW,EACpC,UAAU,EAAE,KAAK,CAAC,CAAC;aAC1C;iBAAM;gBACL,OAAO,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC;aAC5D;QACJ,CAAC,CAAC,CAAC;IAEJ,CAAC;IAEQ,SAAS;QAEhB,MAAM,MAAM,GAA6B;YACvC,QAAQ,EAAG,IAAI,CAAC,MAAM;YACtB,OAAO,EAAG,IAAI,CAAC,KAAK;SACrB,CAAC;QAEF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAA2B;QACrD,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,MAAM,KAAK,GAAG,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC;QACpC,MAAM,KAAK,GAAG,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC;QACpC,UAAU,CAAC,KAAK,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QAChC,UAAU,CAAC,KAAK,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC;QAC/B,OAAO,UAAU,CAAC;IACpB,CAAC;;AAhHD,kBAAkB;AACX,oBAAS,GAAG,YAAY,CAAC;SAFrB,UAAU;AAoHvB,aAAa,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC","sourcesContent":["/**\n * @license\n * 
Copyright 2022 CodeSmith LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\nimport {serialization,DataType,unstack,stack,tensor,Tensor,Tensor1D,Tensor2D, Tensor3D, Tensor4D, tidy, range, image} from '@tensorflow/tfjs-core';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../../utils/types_utils';\nimport {LayerArgs, Layer} from '../../engine/topology';\nimport {Kwargs} from '../../types';\nimport {Shape} from '../../keras_format/common';\nimport * as K from '../../backend/tfjs_backend';\n\nconst {resizeBilinear, cropAndResize} = image;\n\nexport declare interface CenterCropArgs extends LayerArgs{\n  height: number;\n  width: number;\n}\n\nexport class CenterCrop extends Layer {\n  /** @nocollapse */\n  static className = 'CenterCrop';\n  private readonly height: number;\n  private readonly width: number;\n  constructor(args: CenterCropArgs) {\n    super(args);\n    this.height = args.height;\n    this.width = args.width;\n  }\n\n  centerCrop(inputs: Tensor3D | Tensor4D, hBuffer: number, wBuffer: number,\n            height: number, width: number, inputHeight: number,\n            inputWidth: number, dtype: DataType): Tensor | Tensor[] {\n\n    return tidy(() => {\n      let input: Tensor4D;\n      let isRank3      = false;\n      const top      = hBuffer / inputHeight;\n      const left     = wBuffer / inputWidth;\n      const bottom   = ((height) + hBuffer) / inputHeight;\n      const right    = ((width) + wBuffer) / inputWidth;\n      const bound    = [top, left, bottom, right];\n      const boxesArr = [];\n\n      if(inputs.rank === 3) {\n        isRank3  = true;\n        input  = stack([inputs]) as Tensor4D;\n      } else {\n        input = inputs as Tensor4D;\n      }\n\n      for (let i = 0; i < input.shape[0]; i++) {\n        boxesArr.push(bound);\n      
}\n\n      const boxes: Tensor2D  = tensor(boxesArr, [boxesArr.length, 4]);\n      const boxInd: Tensor1D = range(0, boxesArr.length, 1, 'int32');\n\n      const cropSize: [number, number] = [height, width];\n      const cropped = cropAndResize(input, boxes, boxInd, cropSize, 'nearest');\n\n      if(isRank3) {\n        return K.cast(getExactlyOneTensor(unstack(cropped)), dtype);\n      }\n      return K.cast(cropped, dtype);\n   });\n\n  }\n\n  upsize(inputs : Tensor3D | Tensor4D, height: number,\n         width: number, dtype: DataType): Tensor | Tensor[] {\n\n    return tidy(() => {\n      const outputs = resizeBilinear(inputs, [height, width]);\n      return K.cast(outputs, dtype);\n  });\n\n}\n\n  override call(inputs: Tensor3D | Tensor4D , kwargs: Kwargs):\n      Tensor[] | Tensor {\n    return tidy(() => {\n      const rankedInputs = getExactlyOneTensor(inputs) as Tensor3D | Tensor4D;\n      const dtype       = rankedInputs.dtype;\n      const inputShape  = rankedInputs.shape;\n      const inputHeight = inputShape[inputShape.length - 3];\n      const inputWidth  =  inputShape[inputShape.length - 2];\n\n      let hBuffer = 0;\n      if (inputHeight !== this.height) {\n        hBuffer =  Math.floor((inputHeight - this.height) / 2);\n      }\n\n      let wBuffer = 0;\n      if (inputWidth !== this.width) {\n        wBuffer = Math.floor((inputWidth - this.width) / 2);\n\n        if (wBuffer === 0) {\n          wBuffer = 1;\n        }\n      }\n\n      if(hBuffer >= 0 && wBuffer >= 0) {\n        return this.centerCrop(rankedInputs, hBuffer, wBuffer,\n                              this.height, this.width, inputHeight,\n                              inputWidth, dtype);\n      } else {\n        return this.upsize(inputs, this.height, this.width, dtype);\n      }\n   });\n\n  }\n\n  override getConfig(): serialization.ConfigDict{\n\n    const config: serialization.ConfigDict = {\n      'height' : this.height,\n      'width' : this.width\n    };\n\n    const baseConfig 
= super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[] {\n    inputShape = getExactlyOneShape(inputShape);\n    const hAxis = inputShape.length - 3;\n    const wAxis = inputShape.length - 2;\n    inputShape[hAxis] = this.height;\n    inputShape[wAxis] = this.width;\n    return inputShape;\n  }\n}\n\nserialization.registerClass(CenterCrop);\n"]}