// @tensorflow/tfjs-core
// Version:
// Hardware-accelerated JavaScript library for machine intelligence
// 163 lines (162 loc) • 6.96 kB
// TypeScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-core/dist/io/io_utils" />
import { NamedTensor, NamedTensorMap } from '../tensor_types';
import { TypedArray } from '../types';
import { ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightData, WeightGroup, WeightsManifestConfig, WeightsManifestEntry } from './types';
/**
* Encode a map from names to weight values as an ArrayBuffer, along with an
* `Array` of `WeightsManifestEntry` as specification of the encoded weights.
*
* This function does not perform sharding.
*
* This function is the reverse of `decodeWeights`.
*
* @param tensors A map ("dict") from names to tensors, or an array of
* `NamedTensor`s.
* @param group Group to which the weights belong (optional).
* @returns A `Promise` of an object with two fields:
* - `data`: a flat `ArrayBuffer` with all the binary values of the `Tensor`s
* concatenated.
* - `specs`: an `Array` of `WeightsManifestEntry`s, carrying information
* including tensor names, `dtype`s and shapes.
* @throws Error: on unsupported tensor `dtype`.
*/
export declare function encodeWeights(tensors: NamedTensorMap | NamedTensor[], group?: WeightGroup): Promise<{
data: ArrayBuffer;
specs: WeightsManifestEntry[];
}>;
/**
* Decode flat ArrayBuffer as weights.
*
* This function does not handle sharding.
*
* This function is the reverse of `encodeWeights`.
*
* @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the
* binary values of the tensors concatenated in the order specified in
* `specs`.
* @param specs Specifications of the names, dtypes and shapes of the tensors
* whose values are encoded by `weightData`.
* @returns A map from tensor name to tensor value, with the names corresponding
* to names in `specs`.
* @throws Error, if any of the tensors has unsupported dtype.
*/
export declare function decodeWeights(weightData: WeightData, specs: WeightsManifestEntry[]): NamedTensorMap;
/**
* Decode weights whose binary data is delivered as a stream of `ArrayBuffer`
* chunks.
*
* This appears to be the streaming counterpart of `decodeWeights`: instead of
* a single `WeightData` value, the binary weight values are read from
* `weightStream`. NOTE(review): presumably the streamed bytes are laid out in
* the order given by `specs`, as for `decodeWeights` — confirm against the
* implementation.
*
* @param weightStream A `ReadableStream` yielding `ArrayBuffer` chunks of
* weight data.
* @param specs Specifications of the names, dtypes and shapes of the tensors
* whose values are carried by `weightStream`.
* @returns A `Promise` of a map from tensor name to tensor value.
*/
export declare function decodeWeightsStream(weightStream: ReadableStream<ArrayBuffer>, specs: WeightsManifestEntry[]): Promise<NamedTensorMap>;
/**
* Concatenate TypedArrays into an ArrayBuffer.
*
* @param xs The TypedArrays to concatenate, in order.
* @returns A single `ArrayBuffer` holding the concatenated contents of `xs`.
*/
export declare function concatenateTypedArrays(xs: TypedArray[]): ArrayBuffer;
/**
* Calculate the byte length of a JavaScript string.
*
* Note that a JavaScript string can contain characters that occupy more than
* one byte when encoded, therefore the length of the string is not
* necessarily equal to the byte length. NOTE(review): the byte encoding used
* for the count is not visible in this declaration (presumably UTF-8) —
* confirm against the implementation.
*
* @param str Input string.
* @returns Byte length.
*/
export declare function stringByteLength(str: string): number;
/**
* Encode an ArrayBuffer as a base64-encoded string.
*
* This is the inverse of `base64StringToArrayBuffer`.
*
* @param buffer `ArrayBuffer` to be converted.
* @returns A string that base64-encodes `buffer`.
*/
export declare function arrayBufferToBase64String(buffer: ArrayBuffer): string;
/**
* Decode a base64 string as an ArrayBuffer.
*
* This is the inverse of `arrayBufferToBase64String`.
*
* @param str Base64-encoded string.
* @returns Decoded `ArrayBuffer`.
*/
export declare function base64StringToArrayBuffer(str: string): ArrayBuffer;
/**
* Concatenate a number of ArrayBuffers into one.
*
* @param buffers An array of ArrayBuffers to concatenate, or a single
* ArrayBuffer.
* @returns Result of concatenating `buffers` in order. NOTE(review): whether a
* single ArrayBuffer argument is returned as-is or copied is not visible in
* this declaration — confirm before relying on either.
*
* @deprecated Use tf.io.CompositeArrayBuffer.join() instead.
*/
export declare function concatenateArrayBuffers(buffers: ArrayBuffer[] | ArrayBuffer): ArrayBuffer;
/**
* Get the basename of a path.
*
* Behaves in a way analogous to Linux's basename command.
* NOTE(review): handling of trailing separators and of `\` separators is not
* visible here — confirm against the implementation.
*
* @param path The path to take the basename of.
* @returns The final component of `path`.
*/
export declare function basename(path: string): string;
/**
* Create `ModelJSON` from `ModelArtifacts`.
*
* @param artifacts Model artifacts, describing the model and its weights.
* @param manifest Weight manifest, describing where the weights of the
* `ModelArtifacts` are stored, and some metadata about them.
* @returns Object representing the `model.json` file describing the model
* artifacts and weights.
*/
export declare function getModelJSONForModelArtifacts(artifacts: ModelArtifacts, manifest: WeightsManifestConfig): ModelJSON;
/**
* Create `ModelArtifacts` from a JSON file and weights, synchronously.
*
* @param modelJSON Object containing the parsed JSON of `model.json`
* @param weightSpecs The list of WeightsManifestEntry for the model. Must be
* passed if the modelJSON has a weightsManifest.
* @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for
* the model corresponding to the weights in weightSpecs. Must be passed if
* the modelJSON has a weightsManifest.
* @returns The `ModelArtifacts`, as described by the JSON file (returned
* directly, not wrapped in a Promise).
*/
export declare function getModelArtifactsForJSONSync(modelJSON: ModelJSON, weightSpecs?: WeightsManifestEntry[], weightData?: WeightData): ModelArtifacts;
/**
* Create `ModelArtifacts` from a JSON file.
*
* @param modelJSON Object containing the parsed JSON of `model.json`
* @param loadWeights Function that takes the JSON file's weights manifest,
* reads weights from the listed path(s), and returns a Promise of a
* `[weightSpecs, weightData]` tuple: the weight manifest entries followed by
* the weights data.
* @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
*/
export declare function getModelArtifactsForJSON(modelJSON: ModelJSON, loadWeights: (weightsManifest: WeightsManifestConfig) => Promise<[
WeightsManifestEntry[],
WeightData
]>): Promise<ModelArtifacts>;
/**
* Populate ModelArtifactsInfo fields for a model with JSON topology.
*
* @param modelArtifacts The model artifacts to summarize. NOTE(review): the
* doc states the model has a JSON topology; whether other topology kinds are
* rejected is not visible here — confirm.
* @returns A ModelArtifactsInfo object.
*/
export declare function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts): ModelArtifactsInfo;
/**
* Concatenate the weights stored in a WeightsManifestConfig into a single
* flat list of WeightsManifestEntry.
*
* @param weightsManifest The WeightsManifestConfig to extract weights from.
* @returns A list of WeightsManifestEntry of the weights in the weightsManifest.
*/
export declare function getWeightSpecs(weightsManifest: WeightsManifestConfig): WeightsManifestEntry[];
/**
* Retrieve a Float16 decoder which will decode a Uint16Array of Float16 values
* to a Float32Array.
*
* @returns Function (buffer: Uint16Array) => Float32Array which decodes a
* Uint16Array whose elements each hold the 16 bits of one Float16 value into
* a Float32Array.
*/
export declare function getFloat16Decoder(): (buffer: Uint16Array) => Float32Array;