UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

175 lines 23 kB
import * as util from '../util'; /** * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating * a large ArrayBuffer. * * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads * its weights as a list of (usually) 4MB ArrayBuffers and then slices the * weight tensors out of them. For small models, it's safe to concatenate all * the weight buffers into a single ArrayBuffer and then slice the weight * tensors out of it, but for large models, a different approach is needed. */ export class CompositeArrayBuffer { /** * Concatenate a number of ArrayBuffers into one. * * @param buffers An array of ArrayBuffers to concatenate, or a single * ArrayBuffer. * @returns Result of concatenating `buffers` in order. */ static join(buffers) { return new CompositeArrayBuffer(buffers).slice(); } constructor(buffers) { this.shards = []; this.previousShardIndex = 0; if (buffers == null) { return; } // Normalize the `buffers` input to be `ArrayBuffer[]`. if (!(buffers instanceof Array)) { buffers = [buffers]; } buffers = buffers.map((bufferOrTypedArray) => { if (util.isTypedArray(bufferOrTypedArray)) { return bufferOrTypedArray.buffer; } return bufferOrTypedArray; }); // Skip setting up shards if there are no buffers. if (buffers.length === 0) { return; } this.bufferUniformSize = buffers[0].byteLength; let start = 0; for (let i = 0; i < buffers.length; i++) { const buffer = buffers[i]; // Check that all buffers except the last one have the same length. if (i !== buffers.length - 1 && buffer.byteLength !== this.bufferUniformSize) { // Unset the buffer uniform size, since the buffer sizes are not // uniform. this.bufferUniformSize = undefined; } // Create the shards, including their start and end points. const end = start + buffer.byteLength; this.shards.push({ buffer, start, end }); start = end; } // Set the byteLength if (this.shards.length === 0) { this.byteLength = 0; } this.byteLength = this.shards[this.shards.length - 1].end; } slice(start = 0, end = this.byteLength) { // If there are no shards, then the CompositeArrayBuffer was initialized // with no data. if (this.shards.length === 0) { return new ArrayBuffer(0); } // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. start = isNaN(Number(start)) ? 0 : start; end = isNaN(Number(end)) ? 0 : end; // Fix the bounds to within the array. start = Math.max(0, start); end = Math.min(this.byteLength, end); if (end <= start) { return new ArrayBuffer(0); } const startShardIndex = this.findShardForByte(start); if (startShardIndex === -1) { // This should not happen since the start and end indices are always // within 0 and the composite array's length. throw new Error(`Could not find start shard for byte ${start}`); } const size = end - start; const outputBuffer = new ArrayBuffer(size); const outputArray = new Uint8Array(outputBuffer); let sliced = 0; for (let i = startShardIndex; i < this.shards.length; i++) { const shard = this.shards[i]; const globalStart = start + sliced; const localStart = globalStart - shard.start; const outputStart = sliced; const globalEnd = Math.min(end, shard.end); const localEnd = globalEnd - shard.start; const outputSlice = new Uint8Array(shard.buffer, localStart, localEnd - localStart); outputArray.set(outputSlice, outputStart); sliced += outputSlice.length; if (end < shard.end) { break; } } return outputBuffer; } /** * Get the index of the shard that contains the byte at `byteIndex`. */ findShardForByte(byteIndex) { if (this.shards.length === 0 || byteIndex < 0 || byteIndex >= this.byteLength) { return -1; } // If the buffers have a uniform size, compute the shard directly. if (this.bufferUniformSize != null) { this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize); return this.previousShardIndex; } // If the buffers don't have a uniform size, we need to search for the // shard. That means we need a function to check where the byteIndex lies // relative to a given shard. function check(shard) { if (byteIndex < shard.start) { return -1; } if (byteIndex >= shard.end) { return 1; } return 0; } // For efficiency, try the previous shard first. if (check(this.shards[this.previousShardIndex]) === 0) { return this.previousShardIndex; } // Otherwise, use a generic search function. // This should almost never end up being used in practice since the weight // entries should always be in order. const index = search(this.shards, check); if (index === -1) { return -1; } this.previousShardIndex = index; return this.previousShardIndex; } } /** * Search for an element of a sorted array. * * @param sortedArray The sorted array to search * @param compare A function to compare the current value against the searched * value. Return 0 on a match, negative if the searched value is less than * the value passed to the function, and positive if the searched value is * greater than the value passed to the function. * @returns The index of the element, or -1 if it's not in the array. */ export function search(sortedArray, compare) { // Binary search let min = 0; let max = sortedArray.length; while (min <= max) { const middle = Math.floor((max - min) / 2) + min; const side = compare(sortedArray[middle]); if (side === 0) { return middle; } else if (side < 0) { max = middle; } else { min = middle + 1; } } return -1; } //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"composite_array_buffer.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/io/composite_array_buffer.ts"],"names":[],"mappings":"AAiBA,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAQhC;;;;;;;;;GASG;AAEH,MAAM,OAAO,oBAAoB;IAM/B;;;;;;OAMG;IACH,MAAM,CAAC,IAAI,CAAC,OAAqC;QAC/C,OAAO,IAAI,oBAAoB,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,CAAC;IACnD,CAAC;IAED,YAAY,OACE;QAjBN,WAAM,GAAkB,EAAE,CAAC;QAC3B,uBAAkB,GAAG,CAAC,CAAC;QAiB7B,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO;SACR;QACD,uDAAuD;QACvD,IAAI,CAAC,CAAC,OAAO,YAAY,KAAK,CAAC,EAAE;YAC/B,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC;SACrB;QACD,OAAO,GAAG,OAAO,CAAC,GAAG,CAAC,CAAC,kBAAkB,EAAE,EAAE;YAC3C,IAAI,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,EAAE;gBACzC,OAAO,kBAAkB,CAAC,MAAM,CAAC;aAClC;YACD,OAAO,kBAAkB,CAAC;QAC5B,CAAC,CAAC,CAAC;QAEH,kDAAkD;QAClD,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,OAAO;SACR;QAED,IAAI,CAAC,iBAAiB,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;QAC/C,IAAI,KAAK,GAAG,CAAC,CAAC;QAEd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YACvC,MAAM,MAAM,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;YAC1B,mEAAmE;YACnE,IAAI,CAAC,KAAK,OAAO,CAAC,MAAM,GAAG,CAAC;gBAC1B,MAAM,CAAC,UAAU,KAAK,IAAI,CAAC,iBAAiB,EAAE;gBAC9C,gEAAgE;gBAChE,WAAW;gBACX,IAAI,CAAC,iBAAiB,GAAG,SAAS,CAAC;aACpC;YAED,2DAA2D;YAC3D,MAAM,GAAG,GAAG,KAAK,GAAG,MAAM,CAAC,UAAU,CAAC;YACtC,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,CAAC,CAAC;YACzC,KAAK,GAAG,GAAG,CAAC;SACb;QAED,qBAAqB;QACrB,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;YAC5B,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;SACrB;QACD,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,GAAG,CAAC;IAC5D,CAAC;IAED,KAAK,CAAC,KAAK,GAAG,CAAC,EAAE,GAAG,GAAG,IAAI,CAAC,UAAU;QACpC,wEAAwE;QACxE,gBAAgB;QAChB,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;YAC5B,OAAO,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;SAC3B;QAED,2EAA2E;QAC3E,KAAK,GAAG,KAAK,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC;QACzC,GAAG,GAAG,KAAK,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC;QAEnC,sCAAsC;QACtC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;QAC3B,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,UAAU,EAAE,GAAG,CAAC,CAAC;QACrC,IAAI,GAAG,IAAI,KAAK,EAAE;YAChB,OAAO,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;SAC3B;QAED,MAAM,eAAe,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CAAC,CAAC;QACrD,IAAI,eAAe,KAAK,CAAC,CAAC,EAAE;YAC1B,oEAAoE;YACpE,6CAA6C;YAC7C,MAAM,IAAI,KAAK,CAAC,uCAAuC,KAAK,EAAE,CAAC,CAAC;SACjE;QAED,MAAM,IAAI,GAAG,GAAG,GAAG,KAAK,CAAC;QACzB,MAAM,YAAY,GAAG,IAAI,WAAW,CAAC,IAAI,CAAC,CAAC;QAC3C,MAAM,WAAW,GAAG,IAAI,UAAU,CAAC,YAAY,CAAC,CAAC;QACjD,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,KAAK,IAAI,CAAC,GAAG,eAAe,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YACzD,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;YAE7B,MAAM,WAAW,GAAG,KAAK,GAAG,MAAM,CAAC;YACnC,MAAM,UAAU,GAAG,WAAW,GAAG,KAAK,CAAC,KAAK,CAAC;YAC7C,MAAM,WAAW,GAAG,MAAM,CAAC;YAE3B,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,EAAE,KAAK,CAAC,GAAG,CAAC,CAAC;YAC3C,MAAM,QAAQ,GAAG,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC;YAEzC,MAAM,WAAW,GAAG,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,EAAE,UAAU,EACxB,QAAQ,GAAG,UAAU,CAAC,CAAC;YAC1D,WAAW,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,CAAC;YAC1C,MAAM,IAAI,WAAW,CAAC,MAAM,CAAC;YAE7B,IAAI,GAAG,GAAG,KAAK,CAAC,GAAG,EAAE;gBACnB,MAAM;aACP;SACF;QACD,OAAO,YAAY,CAAC;IACtB,CAAC;IAED;;OAEG;IACK,gBAAgB,CAAC,SAAiB;QACxC,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,IAAI,SAAS,GAAG,CAAC;YAC3C,SAAS,IAAI,IAAI,CAAC,UAAU,EAAE;YAC9B,OAAO,CAAC,CAAC,CAAC;SACX;QAED,kEAAkE;QAClE,IAAI,IAAI,CAAC,iBAAiB,IAAI,IAAI,EAAE;YAClC,IAAI,CAAC,kBAAkB,GAAG,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,IAAI,CAAC,iBAAiB,CAAC,CAAC;YACzE,OAAO,IAAI,CAAC,kBAAkB,CAAC;SAChC;QAED,sEAAsE;QACtE,yEAAyE;QACzE,6BAA6B;QAC7B,SAAS,KAAK,CAAC,KAAkB;YAC/B,IAAI,SAAS,GAAG,KAAK,CAAC,KAAK,EAAE;gBAC3B,OAAO,CAAC,CAAC,CAAC;aACX;YACD,IAAI,SAAS,IAAI,KAAK,CAAC,GAAG,EAAE;gBAC1B,OAAO,CAAC,CAAC;aACV;YACD,OAAO,CAAC,CAAC;QACX,CAAC;QAED,gDAAgD;QAChD,IAAI,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,KAAK,CAAC,EAAE;YACrD,OAAO,IAAI,CAAC,kBAAkB,CAAC;SAChC;QAED,4CAA4C;QAC5C,0EAA0E;QAC1E,qCAAqC;QACrC,MAAM,KAAK,GAAG,MAAM,CAAC,IAAI,CAAC,MAAM,EAAE,KAAK,CAAC,CAAC;QACzC,IAAI,KAAK,KAAK,CAAC,CAAC,EAAE;YAChB,OAAO,CAAC,CAAC,CAAC;SACX;QAED,IAAI,CAAC,kBAAkB,GAAG,KAAK,CAAC;QAChC,OAAO,IAAI,CAAC,kBAAkB,CAAC;IACjC,CAAC;CACF;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,MAAM,CAAI,WAAgB,EAAE,OAAyB;IACnE,gBAAgB;IAChB,IAAI,GAAG,GAAG,CAAC,CAAC;IACZ,IAAI,GAAG,GAAG,WAAW,CAAC,MAAM,CAAC;IAE7B,OAAO,GAAG,IAAI,GAAG,EAAE;QACjB,MAAM,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,GAAG,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC;QACjD,MAAM,IAAI,GAAG,OAAO,CAAC,WAAW,CAAC,MAAM,CAAC,CAAC,CAAC;QAE1C,IAAI,IAAI,KAAK,CAAC,EAAE;YACd,OAAO,MAAM,CAAC;SACf;aAAM,IAAI,IAAI,GAAG,CAAC,EAAE;YACnB,GAAG,GAAG,MAAM,CAAC;SACd;aAAM;YACL,GAAG,GAAG,MAAM,GAAG,CAAC,CAAC;SAClB;KACF;IACD,OAAO,CAAC,CAAC,CAAC;AACZ,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {TypedArray} from '../types';\nimport * as util from '../util';\n\ntype BufferShard = {\n  start: number,\n  end: number,\n  buffer: ArrayBuffer,\n};\n\n/**\n * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating\n * a large ArrayBuffer.\n *\n * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads\n * its weights as a list of (usually) 4MB ArrayBuffers and then slices the\n * weight tensors out of them. For small models, it's safe to concatenate all\n * the weight buffers into a single ArrayBuffer and then slice the weight\n * tensors out of it, but for large models, a different approach is needed.\n */\n\nexport class CompositeArrayBuffer {\n  private shards: BufferShard[] = [];\n  private previousShardIndex = 0;\n  private bufferUniformSize?: number;\n  public readonly byteLength: number;\n\n  /**\n   * Concatenate a number of ArrayBuffers into one.\n   *\n   * @param buffers An array of ArrayBuffers to concatenate, or a single\n   *     ArrayBuffer.\n   * @returns Result of concatenating `buffers` in order.\n   */\n  static join(buffers?: ArrayBuffer[] | ArrayBuffer) {\n    return new CompositeArrayBuffer(buffers).slice();\n  }\n\n  constructor(buffers?: ArrayBuffer | ArrayBuffer[] | TypedArray |\n    TypedArray[]) {\n    if (buffers == null) {\n      return;\n    }\n    // Normalize the `buffers` input to be `ArrayBuffer[]`.\n    if (!(buffers instanceof Array)) {\n      buffers = [buffers];\n    }\n    buffers = buffers.map((bufferOrTypedArray) => {\n      if (util.isTypedArray(bufferOrTypedArray)) {\n        return bufferOrTypedArray.buffer;\n      }\n      return bufferOrTypedArray;\n    });\n\n    // Skip setting up shards if there are no buffers.\n    if (buffers.length === 0) {\n      return;\n    }\n\n    this.bufferUniformSize = buffers[0].byteLength;\n    let start = 0;\n\n    for (let i = 0; i < buffers.length; i++) {\n      const buffer = buffers[i];\n      // Check that all buffers except the last one have the same length.\n      if (i !== buffers.length - 1 &&\n        buffer.byteLength !== this.bufferUniformSize) {\n        // Unset the buffer uniform size, since the buffer sizes are not\n        // uniform.\n        this.bufferUniformSize = undefined;\n      }\n\n      // Create the shards, including their start and end points.\n      const end = start + buffer.byteLength;\n      this.shards.push({ buffer, start, end });\n      start = end;\n    }\n\n    // Set the byteLength\n    if (this.shards.length === 0) {\n      this.byteLength = 0;\n    }\n    this.byteLength = this.shards[this.shards.length - 1].end;\n  }\n\n  slice(start = 0, end = this.byteLength): ArrayBuffer {\n    // If there are no shards, then the CompositeArrayBuffer was initialized\n    // with no data.\n    if (this.shards.length === 0) {\n      return new ArrayBuffer(0);\n    }\n\n    // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.\n    start = isNaN(Number(start)) ? 0 : start;\n    end = isNaN(Number(end)) ? 0 : end;\n\n    // Fix the bounds to within the array.\n    start = Math.max(0, start);\n    end = Math.min(this.byteLength, end);\n    if (end <= start) {\n      return new ArrayBuffer(0);\n    }\n\n    const startShardIndex = this.findShardForByte(start);\n    if (startShardIndex === -1) {\n      // This should not happen since the start and end indices are always\n      // within 0 and the composite array's length.\n      throw new Error(`Could not find start shard for byte ${start}`);\n    }\n\n    const size = end - start;\n    const outputBuffer = new ArrayBuffer(size);\n    const outputArray = new Uint8Array(outputBuffer);\n    let sliced = 0;\n    for (let i = startShardIndex; i < this.shards.length; i++) {\n      const shard = this.shards[i];\n\n      const globalStart = start + sliced;\n      const localStart = globalStart - shard.start;\n      const outputStart = sliced;\n\n      const globalEnd = Math.min(end, shard.end);\n      const localEnd = globalEnd - shard.start;\n\n      const outputSlice = new Uint8Array(shard.buffer, localStart,\n                                         localEnd - localStart);\n      outputArray.set(outputSlice, outputStart);\n      sliced += outputSlice.length;\n\n      if (end < shard.end) {\n        break;\n      }\n    }\n    return outputBuffer;\n  }\n\n  /**\n   * Get the index of the shard that contains the byte at `byteIndex`.\n   */\n  private findShardForByte(byteIndex: number): number {\n    if (this.shards.length === 0 || byteIndex < 0 ||\n      byteIndex >= this.byteLength) {\n      return -1;\n    }\n\n    // If the buffers have a uniform size, compute the shard directly.\n    if (this.bufferUniformSize != null) {\n      this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize);\n      return this.previousShardIndex;\n    }\n\n    // If the buffers don't have a uniform size, we need to search for the\n    // shard. That means we need a function to check where the byteIndex lies\n    // relative to a given shard.\n    function check(shard: BufferShard) {\n      if (byteIndex < shard.start) {\n        return -1;\n      }\n      if (byteIndex >= shard.end) {\n        return 1;\n      }\n      return 0;\n    }\n\n    // For efficiency, try the previous shard first.\n    if (check(this.shards[this.previousShardIndex]) === 0) {\n      return this.previousShardIndex;\n    }\n\n    // Otherwise, use a generic search function.\n    // This should almost never end up being used in practice since the weight\n    // entries should always be in order.\n    const index = search(this.shards, check);\n    if (index === -1) {\n      return -1;\n    }\n\n    this.previousShardIndex = index;\n    return this.previousShardIndex;\n  }\n}\n\n/**\n * Search for an element of a sorted array.\n *\n * @param sortedArray The sorted array to search\n * @param compare A function to compare the current value against the searched\n *     value. Return 0 on a match, negative if the searched value is less than\n *     the value passed to the function, and positive if the searched value is\n *     greater than the value passed to the function.\n * @returns The index of the element, or -1 if it's not in the array.\n */\nexport function search<T>(sortedArray: T[], compare: (t: T) => number): number {\n  // Binary search\n  let min = 0;\n  let max = sortedArray.length;\n\n  while (min <= max) {\n    const middle = Math.floor((max - min) / 2) + min;\n    const side = compare(sortedArray[middle]);\n\n    if (side === 0) {\n      return middle;\n    } else if (side < 0) {\n      max = middle;\n    } else {\n      min = middle + 1;\n    }\n  }\n  return -1;\n}\n"]}