@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
961 lines (859 loc) • 31 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
import {Profiler} from './profiler';
import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, Tensor3D, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {BackendValues, DataType, PixelData} from './types';
import * as util from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';
/**
* A function that computes an output. It receives the backend to run on and,
* optionally, a save function used to keep tensors computed in the forward
* pass that are needed again in the backward pass.
*/
export type ForwardFunc<T> = (backend: KernelBackend, save?: GradSaveFunc) => T;
/**
* Signature of the function passed to `customGrad`. It is called with the input
* tensors followed by a save function, and returns the forward `value` together
* with a `gradFunc` that maps the upstream gradient `dy` (and the tensors saved
* during the forward pass) to one gradient per input tensor.
*
* @docalias (a: Tensor, b: Tensor,..., save?: Function) => {
* value: Tensor,
* gradFunc: (dy: Tensor, saved: Tensor[]) => Tensor | Tensor[]
* }
*/
export type CustomGradientFunc<T extends Tensor> =
(...inputs: Array<Tensor|GradSaveFunc>) => {
value: T;
gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[];
};
// Memory statistics returned by `Engine.memory()`. `unreliable` is set, with
// human-readable explanations in `reasons`, when the numbers are approximate
// (for example while string tensors are alive).
export type MemoryInfo = {
numTensors: number; numDataBuffers: number; numBytes: number;
unreliable?: boolean; reasons: string[];
};
// One record per kernel executed while profiling is enabled (filled in by
// `runKernel`). `bytesAdded`/`tensorsAdded` are deltas across the kernel call;
// the `*Snapshot` fields are the engine totals right after the call.
type KernelProfile = {
name: string; bytesAdded: number; totalBytesSnapshot: number;
tensorsAdded: number;
totalTensorsSnapshot: number;
inputShapes: number[][];
// A single shape for single-output kernels, one shape per output otherwise.
outputShape: number[] | number[][];
};
// Result of `Engine.profile()`: bytes/tensors created by the profiled
// function, the peak byte count observed during profiling, the per-kernel
// records, and the value the profiled function returned.
export type ProfileInfo = {
newBytes: number; newTensors: number; peakBytes: number;
kernels: KernelProfile[];
result: TensorContainer;
};
// Timing returned by `Engine.time()`: the backend-reported timings plus
// `wallMs`, the wall-clock milliseconds measured around the backend call.
export interface TimingInfo extends BackendTimingInfo {
wallMs: number;
}
/**
* A function run by `tidy()` between `startScope()` and `endScope()`.
* @docalias Function
*/
export type ScopeFn<T extends TensorContainer> = () => T;
// Engine methods used to register/dispose tensors and variables and to report
// memory usage.
export interface TensorManager {
registerTensor(a: Tensor, backend?: KernelBackend): void;
registerVariable(v: Variable): void;
disposeTensor(a: Tensor): void;
memory(): {numDataBuffers: number; numBytes: number;};
}
// Bookkeeping for one `startScope()`/`endScope()` pair.
interface ScopeState {
// Tensors registered while this scope was active; disposed by `endScope()`
// unless they were kept or are part of the scope's result.
track: Tensor[];
// Name passed to `startScope()`, or 'unnamed scope'.
name: string;
// Unique scope id; tensors tracked by this scope store it as `scopeId`.
id: number;
}
/**
 * Mutable bookkeeping owned by an `Engine`. A fresh instance is created on
 * engine construction and on `Engine.reset()`.
 */
class EngineState {
  // Variables keyed by name. Public since optimizers will use it.
  registeredVariables: NamedVariableMap = {};
  // Id assigned to the next node pushed onto the tape.
  nextTapeNodeId = 0;

  // Running totals maintained by the engine as tensors come and go.
  numBytes = 0;
  numTensors = 0;
  numStringTensors = 0;
  numDataBuffers = 0;

  // Recorded operations used for backprop.
  activeTape: TapeNode[];
  // How many tf.grad() calls are currently nested: `1` while computing a
  // first-order gradient, `2` for a second-order one, and so on. Decides
  // whether the tape can be dropped once backprop finishes.
  gradientDepth = 0;
  // How many kernel calls are currently nested. The tape is switched off
  // whenever this is greater than zero.
  kernelDepth = 0;

  // Scope stack whose tracked tensors parallel the tapes.
  activeScope: ScopeState;
  scopeStack: ScopeState[] = [];
  nextScopeId = 0;

  // Per-data-buffer metadata, keyed by data id.
  tensorInfo = new WeakMap<DataId, {
    backend: KernelBackend,
    bytes: number,
    dtype: DataType,
    shape: number[],
    refCount: number
  }>();

  // Profiling switch and the record being filled while it is on.
  profiling = false;
  activeProfile: ProfileInfo =
      {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};

  /** Disposes every registered variable. */
  dispose() {
    const variables = this.registeredVariables;
    Object.keys(variables).forEach(name => {
      variables[name].dispose();
    });
  }
}
/**
 * Central runtime: selects and initializes backends, tracks tensors and
 * variables, manages scopes (`tidy`), records gradient tapes and collects
 * profiling data. Implements `TensorTracker` and `DataMover`.
 */
export class Engine implements TensorManager, TensorTracker, DataMover {
  // Mutable engine state; replaced wholesale by `reset()`.
  state: EngineState;
  // Name of the active backend (or the one being initialized).
  backendName: string;
  // Initialized backend instances, keyed by backend name.
  registry: {[id: string]: KernelBackend} = {};
  // Registered backend factories and their priorities, keyed by backend name.
  registryFactory: {
    [id: string]: {
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority: number
    }
  } = {};

  private profiler: Profiler;
  private backendInstance: KernelBackend;
  // In-flight asynchronous backend initialization, if any.
  private pendingBackendInit: Promise<boolean>;
  // Bumped to make in-flight asynchronous initializations obsolete.
  private pendingBackendInitId = 0;

  constructor(public ENV: Environment) {
    this.state = new EngineState();
  }

  /**
   * Resolves once a backend is ready. If an async initialization is pending
   * it waits for it; otherwise, when no backend is active, it initializes the
   * registered backends in priority order and activates the first that
   * succeeds.
   *
   * @throws Error if every backend fails to initialize.
   */
  async ready(): Promise<void> {
    if (this.pendingBackendInit != null) {
      return this.pendingBackendInit.then(() => {});
    }
    if (this.backendInstance != null) {
      return;
    }
    const sortedBackends = this.getSortedBackends();
    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const success = await this.initializeBackend(backendName).success;
      if (success) {
        await this.setBackend(backendName);
        return;
      }
    }
    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }

  /**
   * The active backend. When none is active yet, synchronously initializes
   * the highest-priority backend that can be initialized.
   *
   * @throws Error if a backend is still initializing asynchronously; callers
   *     must `await tf.ready()` first.
   */
  get backend(): KernelBackend {
    if (this.pendingBackendInit != null) {
      throw new Error(
          `Backend '${this.backendName}' has not yet been initialized. Make ` +
          `sure to await tf.ready() before calling other methods`);
    }
    if (this.backendInstance == null) {
      const {name, asyncInit} = this.initializeBackendsAndReturnBest();
      if (asyncInit) {
        throw new Error(
            `The highest priority backend '${name}' has not yet been ` +
            `initialized. Make sure to await tf.ready() before calling ` +
            `other methods`);
      }
      // `name` is already in the registry, so setBackend() assigns
      // `backendInstance` synchronously, before its first `await`.
      this.setBackend(name);
    }
    return this.backendInstance;
  }

  /** Names of all registered backend factories. */
  backendNames(): string[] {
    return Object.keys(this.registryFactory);
  }

  /**
   * Returns the initialized backend called `backendName`, initializing it on
   * demand when a factory is registered. Returns null when no such backend is
   * registered, when its initialization is asynchronous (not ready yet), or
   * when its synchronous initialization failed.
   */
  findBackend(backendName: string): KernelBackend {
    if (!(backendName in this.registry)) {
      if (!(backendName in this.registryFactory)) {
        return null;
      }
      const {success, asyncInit} = this.initializeBackend(backendName);
      if (asyncInit || !success) {
        // Either still initializing or failed: nothing usable to return.
        return null;
      }
    }
    return this.registry[backendName];
  }

  /** Returns the registered factory for `backendName`, or null. */
  findBackendFactory(backendName: string):
      () => KernelBackend | Promise<KernelBackend> {
    if (!(backendName in this.registryFactory)) {
      return null;
    }
    return this.registryFactory[backendName].factory;
  }

  /**
   * Registers a backend factory. Returns false (and keeps the existing
   * factory) when the name is already registered.
   */
  registerBackend(
      backendName: string,
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority = 1): boolean {
    if (backendName in this.registryFactory) {
      console.warn(
          `${backendName} backend was already registered. ` +
          `Reusing existing backend factory.`);
      return false;
    }
    this.registryFactory[backendName] = {factory, priority};
    return true;
  }

  /**
   * Makes `backendName` the active backend, initializing it first if needed.
   * Resolves to false when initialization fails.
   *
   * @throws Error if no factory is registered under `backendName`.
   */
  async setBackend(backendName: string): Promise<boolean> {
    if (this.registryFactory[backendName] == null) {
      throw new Error(`Backend name '${backendName}' not found in registry`);
    }
    this.backendName = backendName;
    if (this.registry[backendName] == null) {
      this.backendInstance = null;
      const {success, asyncInit} = this.initializeBackend(backendName);
      const result = asyncInit ? await success : success;
      if (!result) {
        return false;
      }
    }
    this.backendInstance = this.registry[backendName];
    // Reset the profiler.
    this.profiler = new Profiler(this.backendInstance);
    return true;
  }

  /**
   * Initializes a backend by looking up the backend name in the factory
   * registry and calling the factory method. Returns whether the
   * initialization succeeded (a promise of that when the factory is async)
   * and whether it was async. Throws an error if there is no backend in the
   * factory registry; factory failures are logged, not thrown.
   */
  private initializeBackend(backendName: string):
      {success: boolean|Promise<boolean>, asyncInit: boolean} {
    const registryFactoryEntry = this.registryFactory[backendName];
    if (registryFactoryEntry == null) {
      throw new Error(
          `Cannot initialize backend ${backendName}, no registration found.`);
    }
    try {
      const backend = registryFactoryEntry.factory();
      // Test if the factory returns a promise.
      if (Promise.resolve(backend) === backend) {
        const promiseId = ++this.pendingBackendInitId;
        const success =
            backend
                .then(backendInstance => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.registry[backendName] = backendInstance;
                  this.pendingBackendInit = null;
                  return true;
                })
                .catch(err => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.pendingBackendInit = null;
                  this.warnBackendInitFailure(backendName, err);
                  return false;
                });
        this.pendingBackendInit = success;
        return {success, asyncInit: true};
      } else {
        this.registry[backendName] = backend as KernelBackend;
        return {success: true, asyncInit: false};
      }
    } catch (err) {
      this.warnBackendInitFailure(backendName, err);
      return {success: false, asyncInit: false};
    }
  }

  /**
   * Logs a failed backend initialization. Accepts any thrown/rejected value
   * (including non-Error values such as strings or `undefined`) and never
   * throws itself.
   */
  // tslint:disable-next-line:no-any
  private warnBackendInitFailure(backendName: string, err: any): void {
    console.warn(`Initialization of backend ${backendName} failed`);
    const details = err != null ? (err.stack || err.message) : null;
    console.warn(details != null ? details : err);
  }

  /**
   * Disposes and unregisters a backend and its factory. If it is the active
   * backend, the engine is left without an active backend.
   *
   * @throws Error if no factory is registered under `backendName`.
   */
  removeBackend(backendName: string): void {
    if (!(backendName in this.registryFactory)) {
      throw new Error(`${backendName} backend not found in registry`);
    }
    if (this.backendName === backendName && this.pendingBackendInit != null) {
      // There is a pending promise of the backend we want to remove. Make it
      // obsolete.
      this.pendingBackendInitId++;
    }
    if (backendName in this.registry) {
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }
    delete this.registryFactory[backendName];
    // Unset the backend if it is active.
    if (this.backendName === backendName) {
      this.pendingBackendInit = null;
      this.backendName = null;
      this.backendInstance = null;
    }
  }

  /** Registered backend names, highest priority first. */
  private getSortedBackends(): string[] {
    if (Object.keys(this.registryFactory).length === 0) {
      throw new Error('No backend found in registry.');
    }
    return Object.keys(this.registryFactory).sort((a: string, b: string) => {
      // Highest priority comes first.
      return this.registryFactory[b].priority -
          this.registryFactory[a].priority;
    });
  }

  /**
   * Initializes backends in priority order and returns the first one that
   * either succeeded synchronously or started an async initialization.
   */
  private initializeBackendsAndReturnBest():
      {name: string, asyncInit: boolean} {
    const sortedBackends = this.getSortedBackends();
    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const {success, asyncInit} = this.initializeBackend(backendName);
      if (asyncInit || success) {
        return {name: backendName, asyncInit};
      }
    }
    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }

  /** Copies the data behind `dataId` into `destBackend`. */
  moveData(destBackend: KernelBackend, dataId: DataId) {
    this.write(destBackend, dataId, this.readSync(dataId));
  }

  /**
   * Runs `fn` in a new scope and disposes every tensor created inside it,
   * except tensors that were kept or are part of the returned value.
   * Accepts either `(fn)` or `(name, fn)`.
   */
  tidy<T extends TensorContainer>(nameOrFn: string|ScopeFn<T>, fn?: ScopeFn<T>):
      T {
    let name: string = null;
    if (fn == null) {
      // Called with only 1 argument.
      if (typeof nameOrFn !== 'function') {
        throw new Error('Please provide a function to tidy()');
      }
      fn = nameOrFn;
    } else {
      // Called with 2 arguments.
      if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
        throw new Error(
            'When calling with two arguments, the first argument ' +
            'to tidy() must be a string');
      }
      if (typeof fn !== 'function') {
        throw new Error(
            'When calling with two arguments, the 2nd argument ' +
            'to tidy() must be a function');
      }
      name = nameOrFn as string;
      // TODO(nsthorat,smilkov): Do operation logging and performance
      // profiling.
    }
    let result: T;
    return this.scopedRun(
        () => this.startScope(name), () => this.endScope(result), () => {
          result = fn();
          if (result instanceof Promise) {
            console.error('Cannot return a Promise inside of tidy.');
          }
          return result;
        });
  }

  /**
   * Runs `f` between `start()` and `end()`. `end()` runs exactly once,
   * whether `f` returns or throws.
   */
  private scopedRun<T>(start: () => void, end: () => void, f: () => T): T {
    start();
    try {
      return f();
    } finally {
      end();
    }
  }

  private static nextTensorId = 0;
  nextTensorId(): number {
    return Engine.nextTensorId++;
  }

  private static nextVariableId = 0;
  nextVariableId(): number {
    return Engine.nextVariableId++;
  }

  /**
   * This method is called instead of the public-facing tensor.clone() when
   * saving a tensor for backwards pass. It makes sure to add the clone
   * operation to the tape regardless of being called inside a kernel
   * execution.
   */
  private clone(x: Tensor): Tensor {
    const y = Tensor.make(x.shape, {dataId: x.dataId}, x.dtype);
    this.addTapeNode([x], y, dy => [dy.toFloat()]);
    return y;
  }

  /**
   * Executes `forwardFunc` on the active backend with the tape turned off,
   * then (when the tape is on) records a tape node whose gradient is
   * `backwardsFunc`, and (when profiling) records a kernel profile.
   */
  runKernel<T extends Tensor|Tensor[], I extends NamedTensorMap>(
      forwardFunc: ForwardFunc<T>,
      inputs: I,
      backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
      ): T {
    let result: T;
    let saved: Tensor[] = [];
    const isTapeOn = this.isTapeOn();
    const scopeName =
        this.state.activeScope != null ? this.state.activeScope.name : '';
    const saveFunc: GradSaveFunc = (tensors) => {
      // Do not save unless we are recording to the tape. Otherwise it would
      // cause a mem leak since we would never run backprop, which disposes
      // the kept tensors.
      if (!isTapeOn) {
        return;
      }
      saved = tensors.map(tensor => this.keep(this.clone(tensor)));
    };
    const startingBytecount = this.state.numBytes;
    const startingNumTensors = this.state.numTensors;
    // Stop recording to a tape when running a kernel.
    this.scopedRun(
        () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
          // Resolve the backend before touching `this.profiler`: the getter
          // may lazily initialize the backend, which is what creates the
          // profiler (see setBackend()).
          const backend = this.backend;
          if (!this.ENV.getBool('DEBUG')) {
            result = forwardFunc(backend, saveFunc);
          } else {
            result = this.profiler.profileKernel(
                scopeName, () => forwardFunc(backend, saveFunc));
          }
        });
    if (isTapeOn) {
      const tapeNode: TapeNode = {
        id: this.state.nextTapeNodeId++,
        name: scopeName,
        inputs,
        outputs: Array.isArray(result) ? result : [result] as Tensor[],
        saved
      };
      if (backwardsFunc != null) {
        tapeNode.gradient = (dy: T) => backwardsFunc(dy, saved);
      }
      this.state.activeTape.push(tapeNode);
    }
    if (this.state.profiling) {
      this.state.activeProfile.kernels.push({
        name: scopeName,
        bytesAdded: this.state.numBytes - startingBytecount,
        totalBytesSnapshot: this.state.numBytes,
        tensorsAdded: this.state.numTensors - startingNumTensors,
        totalTensorsSnapshot: this.state.numTensors,
        inputShapes: Object.keys(inputs).map(key => inputs[key].shape),
        outputShape: Array.isArray(result) ?
            (result as Tensor[]).map(item => (item as Tensor).shape) :
            (result as Tensor).shape
      });
    }
    return result;
  }

  // TensorManager implementation.

  /**
   * Registers a tensor (or variable). The first reference to a data id also
   * registers the data with its backend (`backend` if given, else the active
   * one). Non-variable tensors are tracked in the current scope.
   */
  registerTensor(a: Tensor|Variable, backend?: KernelBackend): void {
    const refCount = this.state.tensorInfo.has(a.dataId) ?
        this.state.tensorInfo.get(a.dataId).refCount :
        0;
    this.state.numTensors++;
    if (a.dtype === 'string') {
      this.state.numStringTensors++;
    }
    if (refCount === 0) {
      this.state.numDataBuffers++;
      // Bytes for complex numbers are counted by their components. Bytes for
      // string tensors are counted when writing values.
      let bytes = 0;
      if (a.dtype !== 'complex64' && a.dtype !== 'string') {
        bytes = a.size * util.bytesPerElement(a.dtype);
      }
      const owner = backend != null ? backend : this.backend;
      this.state.tensorInfo.set(a.dataId, {
        backend: owner,
        dtype: a.dtype,
        shape: a.shape,
        bytes,
        refCount: 0
      });
      this.state.numBytes += bytes;
      owner.register(a.dataId, a.shape, a.dtype);
    }
    this.state.tensorInfo.get(a.dataId).refCount++;
    if (!(a instanceof Variable)) {
      this.track(a);
    }
  }

  /** @throws Error if a variable with the same name is already registered. */
  registerVariable(v: Variable) {
    if (this.state.registeredVariables[v.name] != null) {
      throw new Error(`Variable with name ${v.name} was already registered`);
    }
    this.state.registeredVariables[v.name] = v;
  }

  /**
   * Drops one reference to the tensor's data; the last reference releases
   * the data from its backend. No-op for data the engine does not know.
   */
  disposeTensor(a: Tensor): void {
    if (!this.state.tensorInfo.has(a.dataId)) {
      return;
    }
    this.state.numTensors--;
    if (a.dtype === 'string') {
      this.state.numStringTensors--;
    }
    const info = this.state.tensorInfo.get(a.dataId);
    if (info.refCount <= 1) {
      // Don't count bytes for complex numbers as they are counted by their
      // components.
      if (a.dtype !== 'complex64') {
        this.state.numBytes -= info.bytes;
      }
      this.state.numDataBuffers--;
      info.backend.disposeData(a.dataId);
      this.state.tensorInfo.delete(a.dataId);
    } else {
      info.refCount--;
    }
    // TODO(nsthorat): Construct an error and save the stack trace for
    // debugging when in debug mode. Creating a stack trace is too expensive
    // to do unconditionally.
  }

  /** Disposes and unregisters every registered variable. */
  disposeVariables(): void {
    for (const varName in this.state.registeredVariables) {
      const v = this.state.registeredVariables[varName];
      this.disposeVariable(v);
    }
  }

  /** Disposes `v` and removes it from the variable registry. */
  disposeVariable(v: Variable): void {
    this.disposeTensor(v);
    if (this.state.registeredVariables[v.name] != null) {
      delete this.state.registeredVariables[v.name];
    }
  }

  /**
   * Backend memory info overlaid with the engine's tensor/buffer/byte
   * counters. Marked unreliable while string tensors exist.
   */
  memory(): MemoryInfo {
    const info = this.backend.memory() as MemoryInfo;
    info.numTensors = this.state.numTensors;
    info.numDataBuffers = this.state.numDataBuffers;
    info.numBytes = this.state.numBytes;
    if (this.state.numStringTensors > 0) {
      info.unreliable = true;
      if (info.reasons == null) {
        info.reasons = [];
      }
      info.reasons.push(
          'Memory usage by string tensors is approximate ' +
          '(2 bytes per character)');
    }
    return info;
  }

  /**
   * Runs `query` with profiling enabled and reports the kernels it executed,
   * the bytes/tensors it created and the peak byte count observed (never
   * lower than the byte count when profiling started). Each call returns a
   * fresh `ProfileInfo`, so earlier results are not overwritten.
   */
  async profile(query: () => TensorContainer): Promise<ProfileInfo> {
    const startBytes = this.state.numBytes;
    const startNumTensors = this.state.numTensors;
    const profileInfo: ProfileInfo =
        {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};
    // runKernel() appends kernel records to this object while profiling.
    this.state.activeProfile = profileInfo;
    this.state.profiling = true;
    try {
      profileInfo.result = query();
    } finally {
      // Stop profiling even when `query` throws.
      this.state.profiling = false;
    }
    // reduce() avoids Math.max(...[]) === -Infinity when no kernel ran and
    // the argument-count limit of spreading a very long array.
    profileInfo.peakBytes = profileInfo.kernels.reduce(
        (peak, kernel) => Math.max(peak, kernel.totalBytesSnapshot),
        startBytes);
    profileInfo.newBytes = this.state.numBytes - startBytes;
    profileInfo.newTensors = this.state.numTensors - startNumTensors;
    return profileInfo;
  }

  /** True inside a gradient computation but outside any kernel call. */
  isTapeOn(): boolean {
    return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
  }

  /**
   * Pushes a tape node for `result` with positional inputs and a gradient
   * function returning one gradient per input.
   */
  private addTapeNode(
      inputs: Tensor[], result: Tensor,
      gradientsFunc: (dy: Tensor) => Tensor[]): void {
    const inputsMap: NamedTensorMap = {};
    inputs.forEach((input, idx) => {
      inputsMap[idx] = input;
    });
    const gradient = (dy: Tensor) => {
      const res = gradientsFunc(dy);
      const resMap: NamedGradientMap = {};
      res.forEach((r, idx) => {
        resMap[idx] = () => r;
      });
      return resMap;
    };
    const tapeNode: TapeNode = {
      id: this.state.nextTapeNodeId++,
      name: this.state.activeScope.name,
      inputs: inputsMap,
      outputs: [result],
      gradient
    };
    this.state.activeTape.push(tapeNode);
  }

  /** Marks `result` so that no scope disposes it automatically. */
  keep<T extends Tensor>(result: T): T {
    result.kept = true;
    return result;
  }

  private startTape() {
    if (this.state.gradientDepth === 0) {
      this.state.activeTape = [];
    }
    this.state.gradientDepth++;
  }

  private endTape() {
    this.state.gradientDepth--;
  }

  /**
   * Start a scope. Use this with endScope() to achieve the same functionality
   * as scope() without the need for a function closure.
   */
  startScope(name?: string) {
    const scopeInfo: ScopeState = {
      track: [],
      name: 'unnamed scope',
      id: this.state.nextScopeId++
    };
    if (name) {
      scopeInfo.name = name;
    }
    this.state.scopeStack.push(scopeInfo);
    this.state.activeScope = scopeInfo;
  }

  /**
   * End a scope. Use this with startScope() to achieve the same functionality
   * as scope() without the need for a function closure.
   */
  endScope(result?: TensorContainer) {
    const tensorsToTrackInParent = getTensorsInContainer(result);
    const tensorsToTrackInParentSet =
        new Set(tensorsToTrackInParent.map(t => t.id));

    // Dispose the arrays tracked in this scope.
    for (let i = 0; i < this.state.activeScope.track.length; i++) {
      const tensor = this.state.activeScope.track[i];
      if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
        tensor.dispose();
      }
    }

    const oldScope = this.state.scopeStack.pop();
    this.state.activeScope = this.state.scopeStack.length === 0 ?
        null :
        this.state.scopeStack[this.state.scopeStack.length - 1];

    // Track the current result in the parent scope.
    tensorsToTrackInParent.forEach(tensor => {
      // Only track the tensor if was allocated in the inner scope and is not
      // globally kept.
      if (!tensor.kept && tensor.scopeId === oldScope.id) {
        this.track(tensor);
      }
    });
  }

  /**
   * Returns gradients of `f` with respect to each of the `xs`. The gradients
   * returned are of the same length as `xs`, but some might be null if `f`
   * was not a function of that `x`. It also takes optional dy to multiply the
   * gradient, which defaults to `1`.
   */
  gradients<T extends Tensor>(
      f: () => T, xs: Tensor[], dy?: T,
      allowNoGradients = false): {value: T, grads: Tensor[]} {
    util.assert(
        xs.length > 0, () => 'gradients() received an empty list of xs.');
    if (dy != null && dy.dtype !== 'float32') {
      throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
    }

    const y = this.scopedRun(
        () => this.startTape(), () => this.endTape(),
        () => this.tidy('forward', f));

    util.assert(
        y instanceof Tensor,
        () => 'The result y returned by f() must be a tensor.');
    // Filter out the nodes that don't connect x => y.
    const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
    if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
      throw new Error(
          'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
          'that the f you passed encloses all operations that lead from x ' +
          'to y.');
    }

    return this.tidy('backward', () => {
      const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};
      accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;

      // Backprop gradients through the filtered nodes.
      backpropagateGradients(
          accumulatedGradientMap, filteredTape,
          // Pass the tidy function to avoid circular dep with `tape.ts`.
          f => this.tidy(f as ScopeFn<Tensor>));
      const grads = xs.map(x => accumulatedGradientMap[x.id]);

      if (this.state.gradientDepth === 0) {
        // This means that we are not computing higher-order gradients
        // and can clean up the tape. Nodes created by addTapeNode() have no
        // `saved` list, hence the null check.
        this.state.activeTape.forEach(node => {
          if (node.saved != null) {
            node.saved.forEach(tensor => tensor.dispose());
          }
        });
        this.state.activeTape = null;
      }
      return {value: y, grads};
    });
  }

  /**
   * Wraps `f` into a function whose gradient is computed by the `gradFunc`
   * that `f` returns, instead of by differentiating through `f`'s ops.
   */
  customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
      (...args: Array<Tensor|GradSaveFunc>) => T {
    util.assert(
        util.isFunction(f),
        () => 'The f passed in customGrad(f) must be a function.');
    return (...inputs: Tensor[]): T => {
      util.assert(
          inputs.every(t => t instanceof Tensor),
          () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
              'tensors');

      let res: {
        value: T,
        gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[],
      };
      const inputMap: NamedTensorMap = {};
      inputs.forEach((input, i) => {
        inputMap[i] = input;
      });
      return this.runKernel(
          (_, save) => {
            res = f(...[...inputs, save]);
            util.assert(
                res.value instanceof Tensor,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.value` is a tensor');
            util.assert(
                util.isFunction(res.gradFunc),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function.');
            return res.value;
          },
          inputMap,
          (dy: T, saved: Tensor[]) => {
            const gradRes = res.gradFunc(dy, saved);
            const grads: Tensor[] =
                Array.isArray(gradRes) ? gradRes : [gradRes];
            util.assert(
                grads.length === inputs.length,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'the same number of tensors as inputs passed to f(...).');
            util.assert(
                grads.every(t => t instanceof Tensor),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'a list of only tensors.');
            const gradMap: {[key: string]: () => Tensor} = {};
            grads.forEach((grad, i) => {
              gradMap[i] = () => grad;
            });
            return gradMap;
          });
    };
  }

  // Forwarding to backend.

  /**
   * Writes `values` for `dataId` into `destBackend` (the active backend when
   * null), moving the data there if it currently lives in another backend.
   */
  write(destBackend: KernelBackend, dataId: DataId, values: BackendValues):
      void {
    const info = this.state.tensorInfo.get(dataId);
    const srcBackend = info.backend;
    destBackend = destBackend || this.backend;
    // Bytes for string tensors are counted when writing.
    if (info.dtype === 'string') {
      const newBytes = bytesFromStringArray(values as Uint8Array[]);
      this.state.numBytes += newBytes - info.bytes;
      info.bytes = newBytes;
    }
    if (destBackend !== srcBackend) {
      // Delete the tensor from the old backend and move it to the new
      // backend.
      srcBackend.disposeData(dataId);
      info.backend = destBackend;
      destBackend.register(dataId, info.shape, info.dtype);
    }
    destBackend.write(dataId, values);
  }

  readSync(dataId: DataId): BackendValues {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.readSync(dataId);
  }

  read(dataId: DataId): Promise<BackendValues> {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.read(dataId);
  }

  fromPixels(
      pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
      HTMLVideoElement,
      numChannels: number): Tensor3D {
    return this.backend.fromPixels(pixels, numChannels);
  }

  /** Backend timing of `query`, plus the wall-clock time around it. */
  async time(query: () => void): Promise<TimingInfo> {
    const start = now();
    const timingInfo = await this.backend.time(query) as TimingInfo;
    timingInfo.wallMs = now() - start;
    return timingInfo;
  }

  /**
   * Tracks a Tensor in the current scope to be automatically cleaned up
   * when the current scope ends, and returns the value.
   *
   * @param result The Tensor to track in the current scope.
   */
  private track<T extends Tensor>(result: T): T {
    if (this.state.activeScope != null) {
      result.scopeId = this.state.activeScope.id;
      this.state.activeScope.track.push(result);
    }
    return result;
  }

  get registeredVariables(): NamedVariableMap {
    return this.state.registeredVariables;
  }

  /**
   * Resets the engine state. Removes all backends but does not remove
   * registered backend factories.
   */
  reset(): void {
    // Make any pending promise obsolete.
    this.pendingBackendInitId++;

    this.state.dispose();
    this.ENV.reset();
    this.state = new EngineState();

    for (const backendName in this.registry) {
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }
    this.backendName = null;
    this.backendInstance = null;
    this.pendingBackendInit = null;
  }
}
/**
 * Builds a tensor of the given shape whose values are all one, backed by a
 * float32 typed array. Used as the default upstream gradient `dy`.
 */
function ones(shape: number[]): Tensor {
  const size = sizeFromShape(shape);
  const data = makeOnesTypedArray(size, 'float32');
  return Tensor.make(shape, {values: data});
}
let GLOBAL: {_tfengine: Engine};

/**
 * Returns the global object, probing `window`, `global`, `process` and `self`
 * in that order. A non-null result is cached in `GLOBAL` for later calls.
 *
 * @throws Error if none of those globals is defined.
 */
function getGlobalNamespace(): {_tfengine: Engine} {
  if (GLOBAL != null) {
    return GLOBAL;
  }
  // tslint:disable-next-line:no-any
  const found: any = typeof window !== 'undefined' ? window :
      typeof global !== 'undefined'                ? global :
      typeof process !== 'undefined'               ? process :
      typeof self !== 'undefined'                  ? self :
                                                     undefined;
  if (found === undefined) {
    throw new Error('Could not find a global object');
  }
  GLOBAL = found;
  return GLOBAL;
}
/**
 * Returns the engine stored on the global namespace, creating it (with a new
 * Environment over that namespace) on first use. Also installs the engine's
 * environment as the global environment and makes the global engine the
 * tensor tracker.
 */
function getOrMakeEngine(): Engine {
  const ns = getGlobalNamespace();
  if (ns._tfengine == null) {
    ns._tfengine = new Engine(new Environment(ns));
  }
  const engine = ns._tfengine;
  setEnvironmentGlobal(engine.ENV);
  // The tracker reads the global slot lazily, so it always resolves to
  // whichever engine is currently stored there.
  setTensorTracker(() => ns._tfengine);
  return ns._tfengine;
}
export let ENGINE = getOrMakeEngine();