/*
 * UNPKG listing: @tensorflow/tfjs-layers
 * TensorFlow layers API in JavaScript
 * Version: not recorded in this copy
 * Listed size: 122 lines, 15.5 kB
 */
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { tensorScatterUpdate, tidy } from '@tensorflow/tfjs-core';
import { LayersModel } from '../../engine/training';
import { NotImplementedError } from '../../errors';

/**
 * Copies the values of a tensor into a plain JavaScript array.
 *
 * The values are read synchronously with `dataSync()`, which returns a flat
 * typed array, so the tensor's shape is not preserved.
 *
 * @param input Tensor to read.
 * @returns A flat array of the tensor's values.
 */
export function tensorToArr(input) {
  return Array.from(input.dataSync());
}

/**
 * Converts every tensor in `inputs` to a flat array using `tensorToArr`.
 *
 * @param inputs Array of tensors.
 * @returns One flat array per input tensor, in the same order as `inputs`.
 */
export function tensorArrTo2DArr(inputs) {
  return inputs.map((input) => tensorToArr(input));
}

/**
 * Returns a new Tensor with `updates` inserted into `inputs` starting at the
 * index `startIndices`.
 *
 * @param inputs Tensor to "modify". It is not changed in place.
 * @param startIndices The starting index at which to insert the slice.
 *     Length must be equal to `inputs.rank`.
 * @param updates The update tensor. Shape must fit within `inputs` shape.
 * @returns A new tensor with the modification. When `updates` has no
 *     elements, a copy of `inputs` is returned.
 */
export function sliceUpdate(inputs, startIndices, updates) {
  return tidy(() => {
    // An empty `updates` produces an empty `indices` list, which
    // `tensorScatterUpdate`'s update-shape validation rejects for
    // multi-dimensional inputs. Nothing needs to change, so return a copy.
    if (updates.size === 0) {
      return inputs.clone();
    }

    const indices = [];
    /**
     * Computes the update indices by iterating through all indices from
     * `startIndices` to `startIndices + updates.shape`, pushing one complete
     * index (one coordinate per entry of `startIndices`) into `indices`.
     */
    function createIndices(idx, curr) {
      if (curr.length === startIndices.length) {
        indices.push(curr.slice());
        return;
      }
      const start = startIndices[idx];
      const end = start + updates.shape[idx];
      for (let i = start; i < end; i++) {
        curr.push(i);
        createIndices(idx + 1, curr);
        curr.pop();
      }
    }
    createIndices(0, []);

    // Flatten the updates to match the length of their update indices.
    // A new binding is used instead of reassigning the `updates` parameter.
    const flatUpdates = updates.reshape([updates.size]);
    return tensorScatterUpdate(inputs, indices, flatUpdates);
  });
}

// The helpers below are not implemented yet; each one always throws
// `NotImplementedError`.

function packXYSampleWeight(x, y, sampleWeight) {
  throw new NotImplementedError();
}

function unPackXYSampleWeight(data) {
  throw new NotImplementedError();
}

// TODO(pforderique): Figure out a workaround for `tf.data.Dataset`.
function convertInputsToDataset(x, y, sampleWeight, batchSize) {
  throw new NotImplementedError();
}

function trainValidationSplit(arrays, validationSplit) {
  throw new NotImplementedError();
}

/**
 * A `LayersModel` subclass with overridable preprocessing hooks
 * (`preprocessFeatures`, `preprocessLabels`, `preprocessSamples`).
 *
 * In this version only the constructor and the identity `preprocessFeatures`
 * / `preprocessLabels` hooks are implemented. `preprocessSamples` and the
 * `LayersModel` overrides (`fit`, `evaluate`, `predict`, `trainOnBatch`,
 * `predictOnBatch`) throw `NotImplementedError`.
 */
class PipelineModel extends LayersModel {
  /**
   * @param args Arguments forwarded to the `LayersModel` constructor.
   *     `args.includePreprocessing` is stored on the instance and defaults to
   *     `true` when it is `null` or `undefined`.
   */
  constructor(args) {
    super(args);
    const { includePreprocessing } = args;
    this.includePreprocessing =
        includePreprocessing === null || includePreprocessing === undefined ?
        true :
        includePreprocessing;
  }

  /**
   * An overridable function which preprocesses features.
   * The default implementation returns `x` unchanged.
   */
  preprocessFeatures(x) {
    return x;
  }

  /**
   * An overridable function which preprocesses labels.
   * The default implementation returns `y` unchanged.
   */
  preprocessLabels(y) {
    return y;
  }

  /**
   * An overridable function which preprocesses entire samples.
   * Not implemented yet; always throws `NotImplementedError`.
   */
  preprocessSamples(x, y, sampleWeight) {
    throw new NotImplementedError();
  }

  // ---------------------------------------------------------------------------
  // Below are overrides to LayersModel methods to apply the functions above.
  // ---------------------------------------------------------------------------

  fit(x, y, args = {}) {
    // The message names the helper functions this method depends on.
    // `.name` gives each function's name. Interpolating the function objects
    // directly would paste their entire source text into the message.
    throw new NotImplementedError(
        `Uses ${convertInputsToDataset.name}, ${trainValidationSplit.name}, ` +
        `${packXYSampleWeight.name}, and ${unPackXYSampleWeight.name}`);
  }

  evaluate(x, y, args) {
    throw new NotImplementedError();
  }

  predict(x, args) {
    throw new NotImplementedError();
  }

  trainOnBatch(x, y, sampleWeight) {
    throw new NotImplementedError();
  }

  predictOnBatch(x) {
    throw new NotImplementedError();
  }
}
/** @nocollapse */
PipelineModel.className = 'PipelineModel';
export { PipelineModel };
//# sourceMappingURL=data:application/json;base64,