UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

122 lines 15.5 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { tensorScatterUpdate, tidy } from '@tensorflow/tfjs-core';
import { LayersModel } from '../../engine/training';
import { NotImplementedError, ValueError } from '../../errors';

/**
 * Synchronously copies a tensor's values into a plain JavaScript array.
 *
 * Uses `dataSync()`, so the values come back flattened and the call blocks
 * until the tensor's data is available.
 *
 * @param input Tensor to read.
 * @returns The tensor's values as a flat array.
 */
export function tensorToArr(input) {
  return Array.from(input.dataSync());
}

/**
 * Converts every tensor in `inputs` with `tensorToArr`.
 *
 * @param inputs Tensors to read.
 * @returns One flat value array per input tensor, in the same order.
 */
export function tensorArrTo2DArr(inputs) {
  return inputs.map(input => tensorToArr(input));
}

/**
 * Checks that a slice of extent `updates.shape`, placed at `startIndices`,
 * lies entirely inside `inputs`.
 *
 * Without this check a mismatch surfaces later as an opaque shape error from
 * `tensorScatterUpdate`, as an empty index list (when `updates` has too few
 * dimensions, the loop bound becomes `NaN`), or as out-of-range scatter
 * indices.
 *
 * @throws ValueError describing the first violated constraint.
 */
function validateSliceUpdateArgs(inputs, startIndices, updates) {
  if (startIndices.length !== inputs.rank) {
    throw new ValueError(
        `sliceUpdate: startIndices has length ${startIndices.length}, but ` +
        `inputs has rank ${inputs.rank}.`);
  }
  if (updates.rank < inputs.rank) {
    throw new ValueError(
        `sliceUpdate: updates has rank ${updates.rank}, which is less than ` +
        `the rank of inputs (${inputs.rank}).`);
  }
  for (let axis = 0; axis < inputs.rank; axis++) {
    const start = startIndices[axis];
    const end = start + updates.shape[axis];
    if (!Number.isInteger(start) || start < 0 || end > inputs.shape[axis]) {
      throw new ValueError(
          `sliceUpdate: updates of shape [${updates.shape}] starting at ` +
          `[${startIndices}] does not fit in inputs of shape ` +
          `[${inputs.shape}] (axis ${axis}).`);
    }
  }
}

/**
 * Returns a new Tensor with `updates` inserted into `inputs` starting at the
 * index `startIndices`. `inputs` itself is not modified.
 *
 * @param inputs Tensor to "modify".
 * @param startIndices The starting index of the slice to overwrite: one
 *     non-negative integer per axis. Length must be equal to `inputs.rank`.
 * @param updates The update tensor. Its first `inputs.rank` dimensions give
 *     the extent of the slice, which must fit within `inputs` when placed at
 *     `startIndices`.
 * @returns A new tensor with the modification.
 * @throws ValueError if `startIndices` or `updates` do not fit `inputs`.
 */
export function sliceUpdate(inputs, startIndices, updates) {
  validateSliceUpdateArgs(inputs, startIndices, updates);
  return tidy(() => {
    const indices = [];
    /**
     * Computes the update indices by iterating through all indices from
     * `startIndices` to `startIndices + updates.shape`. The last axis varies
     * fastest, which matches the row-major order in which `updates` is
     * flattened below.
     */
    function createIndices(idx, curr) {
      if (curr.length === startIndices.length) {
        indices.push(curr.slice());
        return;
      }
      const start = startIndices[idx];
      const end = start + updates.shape[idx];
      for (let i = start; i < end; i++) {
        curr.push(i);
        createIndices(idx + 1, curr);
        curr.pop();
      }
    }
    createIndices(0, []);
    // Flatten the updates so there is exactly one value per update index.
    const flatUpdates = updates.reshape([updates.size]);
    return tensorScatterUpdate(inputs, indices, flatUpdates);
  });
}

/**
 * Not implemented yet; always throws `NotImplementedError`.
 * NOTE(review): presumably meant to pack `x`, `y` and `sampleWeight` into a
 * single value/tuple — confirm against the TypeScript source.
 */
function packXYSampleWeight(x, y, sampleWeight) {
  throw new NotImplementedError();
}

/**
 * Not implemented yet; always throws `NotImplementedError`.
 * NOTE(review): presumably the inverse of `packXYSampleWeight` — confirm.
 */
function unPackXYSampleWeight(data) {
  throw new NotImplementedError();
}

// TODO(pforderique): Figure out a workaround for `tf.data.Dataset`.
/** Not implemented yet; always throws `NotImplementedError`. */
function convertInputsToDataset(x, y, sampleWeight, batchSize) {
  throw new NotImplementedError();
}

/** Not implemented yet; always throws `NotImplementedError`. */
function trainValidationSplit(arrays, validationSplit) {
  throw new NotImplementedError();
}

/**
 * A `LayersModel` with overridable preprocessing hooks
 * (`preprocessFeatures`, `preprocessLabels`, `preprocessSamples`).
 *
 * The training/inference overrides below are not implemented yet and
 * currently always throw `NotImplementedError`.
 */
class PipelineModel extends LayersModel {
  /**
   * @param args Model construction args, forwarded to `LayersModel`.
   *     `args.includePreprocessing` defaults to `true` when it is `null` or
   *     `undefined`. The stored flag is not read by any method in this file.
   */
  constructor(args) {
    super(args);
    const includePreprocessing = args.includePreprocessing;
    this.includePreprocessing =
        includePreprocessing == null ? true : includePreprocessing;
  }

  /**
   * An overridable function which preprocesses features.
   * The default implementation returns `x` unchanged.
   */
  preprocessFeatures(x) {
    return x;
  }

  /**
   * An overridable function which preprocesses labels.
   * The default implementation returns `y` unchanged.
   */
  preprocessLabels(y) {
    return y;
  }

  /**
   * An overridable function which preprocesses entire samples.
   * Not implemented yet; always throws `NotImplementedError`.
   */
  preprocessSamples(x, y, sampleWeight) {
    throw new NotImplementedError();
  }

  // ---------------------------------------------------------------------------
  // Below are overrides to LayersModel methods to apply the functions above.
  // ---------------------------------------------------------------------------

  /** Not implemented yet; always throws `NotImplementedError`. */
  fit(x, y, args = {}) {
    // Interpolate the helpers' names. Interpolating the functions themselves
    // would stringify their entire source code into the error message.
    throw new NotImplementedError(
        `Uses ${convertInputsToDataset.name}, ${trainValidationSplit.name}, ` +
        `${packXYSampleWeight.name}, and ${unPackXYSampleWeight.name}`);
  }

  /** Not implemented yet; always throws `NotImplementedError`. */
  evaluate(x, y, args) {
    throw new NotImplementedError();
  }

  /** Not implemented yet; always throws `NotImplementedError`. */
  predict(x, args) {
    throw new NotImplementedError();
  }

  /** Not implemented yet; always throws `NotImplementedError`. */
  trainOnBatch(x, y, sampleWeight) {
    throw new NotImplementedError();
  }

  /** Not implemented yet; always throws `NotImplementedError`. */
  predictOnBatch(x) {
    throw new NotImplementedError();
  }
}
/** @nocollapse */
PipelineModel.className = 'PipelineModel';
export { PipelineModel };
// This compiled file was edited by hand, so it carries no inline source map.
// Apply the same changes to tfjs-layers/src/layers/nlp/utils.ts and rebuild
// to regenerate one.