UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

163 lines (162 loc) 6.96 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-core/dist/io/io_utils" />
import { NamedTensor, NamedTensorMap } from '../tensor_types';
import { TypedArray } from '../types';
import { ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightData, WeightGroup, WeightsManifestConfig, WeightsManifestEntry } from './types';
/**
 * Encode a map from names to weight values as an ArrayBuffer, along with an
 * `Array` of `WeightsManifestEntry` as specification of the encoded weights.
 *
 * This function does not perform sharding.
 *
 * This function is the reverse of `decodeWeights`.
 *
 * @param tensors A map ("dict") from names to tensors, or an array of
 *   `NamedTensor`s.
 * @param group Group to which the weights belong (optional).
 * @returns A `Promise` of
 *   - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
 *     concatenated.
 *   - An `Array` of `WeightsManifestEntry`s, carrying information including
 *     tensor names, `dtype`s and shapes.
 * @throws Error: on unsupported tensor `dtype`.
 */
export declare function encodeWeights(tensors: NamedTensorMap | NamedTensor[], group?: WeightGroup): Promise<{
    data: ArrayBuffer;
    specs: WeightsManifestEntry[];
}>;
/**
 * Decode flat ArrayBuffer as weights.
 *
 * This function does not handle sharding.
 *
 * This function is the reverse of `encodeWeights`.
 *
 * @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the
 *   binary values of the tensors concatenated in the order specified in
 *   `specs`.
 * @param specs Specifications of the names, dtypes and shapes of the tensors
 *   whose values are encoded by `weightData`.
 * @returns A map from tensor name to tensor value, with the names corresponding
 *   to names in `specs`.
 * @throws Error, if any of the tensors has unsupported dtype.
 */
export declare function decodeWeights(weightData: WeightData, specs: WeightsManifestEntry[]): NamedTensorMap;
/**
 * Decode weights delivered as a stream of `ArrayBuffer` chunks.
 *
 * NOTE(review): this declaration originally carried no documentation; the
 * description here is derived from the signature alone. It is presumably the
 * streaming counterpart of `decodeWeights` -- confirm against the
 * implementation.
 *
 * @param weightStream A `ReadableStream` yielding `ArrayBuffer`s.
 * @param specs Specifications of the weights to decode (presumably names,
 *   dtypes and shapes, as for `decodeWeights` -- TODO confirm).
 * @returns A `Promise` of a `NamedTensorMap`.
 */
export declare function decodeWeightsStream(weightStream: ReadableStream<ArrayBuffer>, specs: WeightsManifestEntry[]): Promise<NamedTensorMap>;
/**
 * Concatenate TypedArrays into an ArrayBuffer.
 *
 * @param xs The `TypedArray`s to concatenate.
 * @returns The resulting `ArrayBuffer`.
 */
export declare function concatenateTypedArrays(xs: TypedArray[]): ArrayBuffer;
/**
 * Calculate the byte length of a JavaScript string.
 *
 * Note that a JavaScript string can contain wide characters, therefore the
 * length of the string is not necessarily equal to the byte length.
 *
 * @param str Input string.
 * @returns Byte length.
 */
export declare function stringByteLength(str: string): number;
/**
 * Encode an ArrayBuffer as a base64 encoded string.
 *
 * @param buffer `ArrayBuffer` to be converted.
 * @returns A string that base64-encodes `buffer`.
 */
export declare function arrayBufferToBase64String(buffer: ArrayBuffer): string;
/**
 * Decode a base64 string as an ArrayBuffer.
 *
 * @param str Base64 string.
 * @returns Decoded `ArrayBuffer`.
 */
export declare function base64StringToArrayBuffer(str: string): ArrayBuffer;
/**
 * Concatenate a number of ArrayBuffers into one.
 *
 * @param buffers An array of ArrayBuffers to concatenate, or a single
 *   ArrayBuffer.
 * @returns Result of concatenating `buffers` in order.
 *
 * @deprecated Use tf.io.CompositeArrayBuffer.join() instead.
 */
export declare function concatenateArrayBuffers(buffers: ArrayBuffer[] | ArrayBuffer): ArrayBuffer;
/**
 * Get the basename of a path.
 *
 * Behaves in a way analogous to Linux's basename command.
 *
 * @param path The path whose basename is requested.
 * @returns The basename of `path`.
 */
export declare function basename(path: string): string;
/**
 * Create `ModelJSON` from `ModelArtifacts`.
 *
 * @param artifacts Model artifacts, describing the model and its weights.
 * @param manifest Weight manifest, describing where the weights of the
 *   `ModelArtifacts` are stored, and some metadata about them.
 * @returns Object representing the `model.json` file describing the model
 *   artifacts and weights.
 */
export declare function getModelJSONForModelArtifacts(artifacts: ModelArtifacts, manifest: WeightsManifestConfig): ModelJSON;
/**
 * Create `ModelArtifacts` from a JSON file and weights.
 *
 * Unlike `getModelArtifactsForJSON`, this variant is declared to return the
 * artifacts directly rather than a `Promise`.
 *
 * @param modelJSON Object containing the parsed JSON of `model.json`
 * @param weightSpecs The list of WeightsManifestEntry for the model. Must be
 *   passed if the modelJSON has a weightsManifest.
 * @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for
 *   the model corresponding to the weights in weightSpecs. Must be passed if
 *   the modelJSON has a weightsManifest.
 * @returns The `ModelArtifacts`, as described by the JSON file.
 */
export declare function getModelArtifactsForJSONSync(modelJSON: ModelJSON, weightSpecs?: WeightsManifestEntry[], weightData?: WeightData): ModelArtifacts;
/**
 * Create `ModelArtifacts` from a JSON file.
 *
 * @param modelJSON Object containing the parsed JSON of `model.json`
 * @param loadWeights Function that takes the JSON file's weights manifest,
 *   reads weights from the listed path(s), and returns a Promise of the
 *   weight manifest entries along with the weights data.
 * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
 */
export declare function getModelArtifactsForJSON(modelJSON: ModelJSON, loadWeights: (weightsManifest: WeightsManifestConfig) => Promise<[
    WeightsManifestEntry[],
    WeightData
]>): Promise<ModelArtifacts>;
/**
 * Populate ModelArtifactsInfo fields for a model with JSON topology.
 *
 * @param modelArtifacts The model artifacts from which the info fields are
 *   populated.
 * @returns A ModelArtifactsInfo object.
 */
export declare function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts): ModelArtifactsInfo;
/**
 * Concatenate the weights stored in a WeightsManifestConfig into a list of
 * WeightsManifestEntry.
 *
 * @param weightsManifest The WeightsManifestConfig to extract weights from.
 * @returns A list of WeightsManifestEntry of the weights in the weightsManifest.
 */
export declare function getWeightSpecs(weightsManifest: WeightsManifestConfig): WeightsManifestEntry[];
/**
 * Retrieve a Float16 decoder which will decode a `Uint16Array` of Float16
 * values to a `Float32Array`.
 *
 * @returns Function (buffer: Uint16Array) => Float32Array which decodes
 *   the Uint16Array of Float16 bytes to a Float32Array.
 */
export declare function getFloat16Decoder(): (buffer: Uint16Array) => Float32Array;