@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
1,178 lines (1,063 loc) • 40.4 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
import {getGradient, getKernel, getKernelsForBackend, GradFunc, NamedAttrMap, TensorInfo} from './kernel_registry';
import {Profiler} from './profiler';
import {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {BackendValues, DataType, DataValues} from './types';
import * as util from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';
/**
* A function that computes the forward-pass output of an op on the given
* backend.
*
* The optional `save` callback is for handing over tensors computed in the
* forward pass that are needed again in the backward pass. The engine only
* retains what is passed to `save` while a gradient tape is recording (see
* `runKernelFunc`).
*/
export type ForwardFunc<T> = (backend: KernelBackend, save?: GradSaveFunc) => T;
/**
* Signature of the user function accepted by `Engine.customGrad()`.
*
* The engine calls it with the input tensors followed by a `save` function
* as the last argument. It must return the forward `value` together with a
* `gradFunc`; `customGrad` asserts that `gradFunc` yields exactly one
* gradient tensor per input.
*
* NOTE(review): the `@docalias` below advertises `saved?: NamedTensorMap`
* while the actual type uses `saved: Tensor[]` — confirm which is intended.
*
* @docalias (a: Tensor, b: Tensor,..., save?: Function) => {
* value: Tensor,
* gradFunc: (dy: Tensor, saved?: NamedTensorMap) => Tensor | Tensor[]
* }
*/
export type CustomGradientFunc<T extends Tensor> =
(...inputs: Array<Tensor|GradSaveFunc>) => {
value: T;
gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[];
};
/**
* Memory statistics returned by `Engine.memory()`.
*
* - `numTensors`: number of tensors currently counted by the engine.
* - `numDataBuffers`: number of distinct data ids with a non-zero ref count.
* - `numBytes`: bytes accounted for by the engine's tracked data.
* - `unreliable`: set to `true` by the engine when string tensors exist,
*   because their byte count is approximate.
* - `reasons`: human-readable explanations of why the numbers are
*   unreliable. `Engine.memory()` treats it as possibly missing on the
*   object returned by the backend and creates it on demand.
*/
export type MemoryInfo = {
numTensors: number; numDataBuffers: number; numBytes: number;
unreliable?: boolean; reasons: string[];
};
/**
* Per-kernel record appended by `runKernelFunc` while `Engine.profile()` is
* active.
*
* - `name`: kernel name (or the active scope name when none was given).
* - `bytesAdded` / `tensorsAdded`: change in the engine's byte / tensor
*   counters across the kernel call.
* - `totalBytesSnapshot` / `totalTensorsSnapshot`: the engine's byte / tensor
*   counters right after the kernel call.
* - `inputShapes` / `outputShapes`: shapes of the kernel's inputs / outputs.
*/
type KernelProfile = {
name: string; bytesAdded: number; totalBytesSnapshot: number;
tensorsAdded: number;
totalTensorsSnapshot: number;
inputShapes: number[][];
outputShapes: number[][];
};
/**
* Result of `Engine.profile()`.
*
* - `newBytes` / `newTensors`: change in the engine's byte / tensor counters
*   between the start and end of the profiled query.
* - `peakBytes`: largest `totalBytesSnapshot` among the recorded kernels.
* - `kernels`: one `KernelProfile` per kernel executed during the query.
* - `result`: whatever the profiled query returned.
*/
export type ProfileInfo = {
newBytes: number; newTensors: number; peakBytes: number;
kernels: KernelProfile[];
result: TensorContainer;
};
/**
* Timing information returned by `Engine.time()`: the backend's own timing
* info plus `wallMs`, the difference between two `now()` readings taken by
* the engine around the awaited `backend.time(query)` call.
*/
export interface TimingInfo extends BackendTimingInfo {
wallMs: number;
}
/**
* A zero-argument function executed inside a `tidy()` scope. Its return value
* becomes the result of the scope.
*
* @docalias Function
*/
export type ScopeFn<T extends TensorContainer> = () => T;
/**
* State of one entry on the engine's scope stack.
*
* - `track`: tensors registered in this scope; `endScope()` disposes those
*   that are neither kept nor part of the scope's result.
* - `name`: scope name ('unnamed scope' when none was supplied).
* - `id`: unique id assigned by `startScope()`; copied onto tracked tensors as
*   their `scopeId`.
*/
interface ScopeState {
track: Tensor[];
name: string;
id: number;
}
/** Bookkeeping record the engine keeps for each data id it knows about. */
interface TensorDataInfo {
  backend: KernelBackend;
  bytes: number;
  dtype: DataType;
  shape: number[];
  refCount: number;
}

/**
 * All mutable state owned by an `Engine`. A fresh instance is created by the
 * engine's constructor and by `Engine.reset()`.
 */
class EngineState {
  /** Variables by name. Public since optimizers will use it. */
  registeredVariables: NamedVariableMap = {};
  /** Id handed to the next tape node. */
  nextTapeNodeId = 0;

  // Memory accounting counters.
  numBytes = 0;
  numTensors = 0;
  numStringTensors = 0;
  numDataBuffers = 0;

  /** Tape nodes recorded while gradients are being computed. */
  activeTape: TapeNode[];
  /**
   * Number of nested tf.grad() statements when computing higher-order
   * gradients. E.g. `1` for first-order gradients and `2` for second-order
   * gradients. Used to track if the tape should be removed after a backprop.
   */
  gradientDepth = 0;
  /**
   * Number of nested kernel calls. When kernel depth is greater than 1, we
   * turn off the tape.
   */
  kernelDepth = 0;

  /** Innermost open scope. Scopes keep Tensors that parallel the tapes. */
  activeScope: ScopeState;
  /** Stack of open scopes; the last element is the active one. */
  scopeStack: ScopeState[] = [];
  /**
   * Keeps track of the number of data moves during a kernel execution. We
   * maintain a stack since kernels can call other kernels, recursively.
   */
  numDataMovesStack: number[] = [];
  /** Id handed to the next scope. */
  nextScopeId = 0;

  /** Per-data-id bookkeeping (owning backend, size, ref count). */
  tensorInfo = new WeakMap<DataId, TensorDataInfo>();

  /** True while `Engine.profile()` is running its query. */
  profiling = false;
  /** Profile being filled in (or last filled in) by `Engine.profile()`. */
  activeProfile: ProfileInfo =
      {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};

  /** Disposes every registered variable. */
  dispose() {
    for (const registeredName in this.registeredVariables) {
      const variable = this.registeredVariables[registeredName];
      variable.dispose();
    }
  }
}
export class Engine implements TensorTracker, DataMover {
/** Mutable engine bookkeeping; replaced by a fresh instance in `reset()`. */
state: EngineState;
/**
* Name most recently passed to `setBackend()`. Set to null by `reset()` and
* by `removeBackend()` when the removed backend is this one.
*/
backendName: string;
/** Backend instances that have been initialized, keyed by backend name. */
registry: {[id: string]: KernelBackend} = {};
/**
* Registered backend factories with their priority, keyed by backend name.
* Higher priority backends are tried first (see `getSortedBackends()`).
*/
registryFactory: {
[id: string]: {
factory: () => KernelBackend | Promise<KernelBackend>,
priority: number
}
} = {};
/** Kernel profiler; recreated for the new backend in `setBackend()`. */
private profiler: Profiler;
/** The active backend instance, or null/undefined when none is set. */
private backendInstance: KernelBackend;
/**
* Promise of an in-flight asynchronous backend initialization, or null when
* no such initialization is pending.
*/
private pendingBackendInit: Promise<boolean>;
/**
* Counter used to detect outdated async initializations: each async init
* remembers the value it was started with and ignores its result if the
* counter has advanced in the meantime.
*/
private pendingBackendInitId = 0;
/** Creates an engine bound to the given environment with empty state. */
constructor(public ENV: Environment) {
this.state = new EngineState();
}
/**
 * Resolves once a backend is available.
 *
 * If an asynchronous backend initialization is pending, waits for it. If a
 * backend instance is already set, returns right away. Otherwise tries the
 * registered backends from highest to lowest priority and activates the
 * first one that initializes successfully.
 *
 * Throws if no registered backend could be initialized.
 */
async ready(): Promise<void> {
  if (this.pendingBackendInit != null) {
    return this.pendingBackendInit.then(() => {});
  }
  if (this.backendInstance != null) {
    return;
  }
  for (const candidate of this.getSortedBackends()) {
    const initialized = await this.initializeBackend(candidate).success;
    if (!initialized) {
      continue;
    }
    await this.setBackend(candidate);
    return;
  }
  throw new Error(
      `Could not initialize any backends, all backend initializations ` +
      `failed.`);
}
/**
 * The active backend. When none is set yet, synchronously initializes the
 * best available backend and activates it.
 *
 * Throws if an asynchronous initialization is still pending, or if the best
 * available backend can only be initialized asynchronously.
 */
get backend(): KernelBackend {
  if (this.pendingBackendInit != null) {
    throw new Error(
        `Backend '${this.backendName}' has not yet been initialized. Make ` +
        `sure to await tf.ready() or await tf.setBackend() before calling ` +
        `other methods`);
  }
  if (this.backendInstance != null) {
    return this.backendInstance;
  }
  const {name, asyncInit} = this.initializeBackendsAndReturnBest();
  if (asyncInit) {
    throw new Error(
        `The highest priority backend '${name}' has not yet been ` +
        `initialized. Make sure to await tf.ready() or ` +
        `await tf.setBackend() before calling other methods`);
  }
  // The returned promise is intentionally not awaited: the backend instance
  // is already in the registry at this point.
  this.setBackend(name);
  return this.backendInstance;
}
/** Returns the names of all backends that have a registered factory. */
backendNames(): string[] {
  const registeredNames = Object.keys(this.registryFactory);
  return registeredNames;
}
/**
 * Looks up a backend instance by name, initializing it from its factory when
 * it has not been initialized yet.
 *
 * Returns null when no factory is registered under that name or when the
 * backend initializes asynchronously and is therefore not ready yet.
 */
findBackend(backendName: string): KernelBackend {
  if (backendName in this.registry) {
    return this.registry[backendName];
  }
  if (!(backendName in this.registryFactory)) {
    return null;
  }
  // Not initialized yet but a factory exists: initialize it now.
  const {asyncInit} = this.initializeBackend(backendName);
  if (asyncInit) {
    // Backend is not ready yet.
    return null;
  }
  return this.registry[backendName];
}
/**
 * Returns the factory registered under `backendName`, or null when there is
 * no such registration.
 */
findBackendFactory(backendName: string):
    () => KernelBackend | Promise<KernelBackend> {
  const isRegistered = backendName in this.registryFactory;
  return isRegistered ? this.registryFactory[backendName].factory : null;
}
/**
 * Registers a backend factory under `backendName` with the given priority.
 *
 * @returns true when the factory was registered; false (after logging a
 *     warning) when a factory with that name already exists, in which case
 *     the existing registration is left untouched.
 */
registerBackend(
    backendName: string,
    factory: () => KernelBackend | Promise<KernelBackend>,
    priority = 1): boolean {
  const alreadyRegistered = backendName in this.registryFactory;
  if (!alreadyRegistered) {
    this.registryFactory[backendName] = {factory, priority};
    return true;
  }
  console.warn(
      `${backendName} backend was already registered. ` +
      `Reusing existing backend factory.`);
  return false;
}
/**
 * Makes `backendName` the active backend, initializing it first if needed.
 *
 * Nothing is awaited when the backend is already initialized or initializes
 * synchronously, so in those cases the engine is fully switched over by the
 * time this call returns its promise.
 *
 * @returns a promise resolving to true on success and to false when the
 *     backend failed to initialize.
 */
async setBackend(backendName: string): Promise<boolean> {
  if (this.registryFactory[backendName] == null) {
    throw new Error(`Backend name '${backendName}' not found in registry`);
  }
  this.backendName = backendName;

  const needsInit = this.registry[backendName] == null;
  if (needsInit) {
    this.backendInstance = null;
    const {success, asyncInit} = this.initializeBackend(backendName);
    const initialized = asyncInit ? await success : success;
    if (!initialized) {
      return false;
    }
  }

  this.backendInstance = this.registry[backendName];
  this.setupRegisteredKernels();
  // Start over with a profiler bound to the new backend.
  this.profiler = new Profiler(this.backendInstance);
  return true;
}
/**
 * Runs the optional `setupFunc` of every kernel registered for the current
 * backend name, passing the active backend instance.
 */
private setupRegisteredKernels(): void {
  for (const kernel of getKernelsForBackend(this.backendName)) {
    if (kernel.setupFunc == null) {
      continue;
    }
    kernel.setupFunc(this.backendInstance);
  }
}
/**
 * Runs the optional `disposeFunc` of every kernel registered for
 * `backendName`, passing that backend's instance from the registry.
 */
private disposeRegisteredKernels(backendName: string): void {
  for (const kernel of getKernelsForBackend(backendName)) {
    if (kernel.disposeFunc == null) {
      continue;
    }
    kernel.disposeFunc(this.registry[backendName]);
  }
}
/**
* Initializes a backend by looking up the backend name in the factory
* registry and calling the factory method.
*
* Returns `{success, asyncInit}`:
* - If the factory returns a backend synchronously, the instance is stored in
*   `this.registry`, `success` is `true` and `asyncInit` is `false`.
* - If the factory returns a promise, `asyncInit` is `true` and `success` is a
*   promise resolving to whether the initialization succeeded. That promise
*   is also stored in `this.pendingBackendInit`, and cleared again when it
*   settles unless it has been made outdated in the meantime.
* - If the factory throws, a warning is logged and both `success` and
*   `asyncInit` are `false`.
*
* Throws an error if there is no backend in the factory registry.
*/
private initializeBackend(backendName: string):
{success: boolean|Promise<boolean>, asyncInit: boolean} {
const registryFactoryEntry = this.registryFactory[backendName];
if (registryFactoryEntry == null) {
throw new Error(
`Cannot initialize backend ${backendName}, no registration found.`);
}
try {
const backend = registryFactoryEntry.factory();
// Test if the factory returns a promise.
if (Promise.resolve(backend) === backend) {
// Remember which initialization this is so a later one can supersede it.
const promiseId = ++this.pendingBackendInitId;
const success =
backend
.then(backendInstance => {
// Outdated promise. Another backend was set in the meantime.
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.registry[backendName] = backendInstance;
this.pendingBackendInit = null;
return true;
})
.catch(err => {
// Outdated promise. Another backend was set in the meantime.
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.pendingBackendInit = null;
console.warn(
`Initialization of backend ${backendName} failed`);
console.warn(err.stack || err.message);
return false;
});
// `success` never rejects: failures are turned into `false` above.
this.pendingBackendInit = success;
return {success, asyncInit: true};
} else {
this.registry[backendName] = backend as KernelBackend;
return {success: true, asyncInit: false};
}
} catch (err) {
// The factory threw synchronously.
console.warn(`Initialization of backend ${backendName} failed`);
console.warn(err.stack || err.message);
return {success: false, asyncInit: false};
}
}
/**
 * Removes a backend: disposes its instance (if initialized) together with
 * its kernels, deletes its factory registration and, when it is the active
 * backend, clears the active-backend fields.
 *
 * Throws if no factory is registered under `backendName`.
 */
removeBackend(backendName: string): void {
  const isRegistered = backendName in this.registryFactory;
  if (!isRegistered) {
    throw new Error(`${backendName} backend not found in registry`);
  }

  if (this.pendingBackendInit != null && this.backendName === backendName) {
    // The backend being removed still has a pending initialization promise.
    // Bump the id so that promise is treated as obsolete.
    this.pendingBackendInitId++;
  }

  const isInitialized = backendName in this.registry;
  if (isInitialized) {
    this.disposeRegisteredKernels(backendName);
    this.registry[backendName].dispose();
    delete this.registry[backendName];
  }
  delete this.registryFactory[backendName];

  // Unset the backend if it is active.
  if (this.backendName === backendName) {
    this.pendingBackendInit = null;
    this.backendName = null;
    this.backendInstance = null;
  }
}
/**
 * Returns the registered backend names ordered from highest to lowest
 * priority. Throws when no backend is registered.
 */
private getSortedBackends(): string[] {
  const names = Object.keys(this.registryFactory);
  if (names.length === 0) {
    throw new Error('No backend found in registry.');
  }
  const priorityOf = (name: string) => this.registryFactory[name].priority;
  // Highest priority comes first.
  return names.sort((a: string, b: string) => priorityOf(b) - priorityOf(a));
}
/**
 * Walks the backends from highest to lowest priority, initializing each, and
 * returns the first one that either initialized successfully or started an
 * asynchronous initialization.
 *
 * Throws when every backend failed to initialize.
 */
private initializeBackendsAndReturnBest():
    {name: string, asyncInit: boolean} {
  for (const candidate of this.getSortedBackends()) {
    const {success, asyncInit} = this.initializeBackend(candidate);
    if (asyncInit || success) {
      return {name: candidate, asyncInit};
    }
  }
  throw new Error(
      `Could not initialize any backends, all backend initializations ` +
      `failed.`);
}
/**
 * Moves the data behind `dataId` from the backend that currently owns it to
 * `destBackend`: the values are read synchronously, disposed on the source
 * backend and handed to the destination backend.
 */
moveData(destBackend: KernelBackend, dataId: DataId) {
  const info = this.state.tensorInfo.get(dataId);
  const sourceBackend = info.backend;
  const values = this.readSync(dataId);

  // Drop the data from the old backend, then hand it to the new one.
  sourceBackend.disposeData(dataId);
  info.backend = destBackend;
  destBackend.move(dataId, values, info.shape, info.dtype);

  if (this.shouldCheckForMemLeaks()) {
    // Count this move against the most recent kernel-execution entry so the
    // memory-leak check can cancel it out.
    const movesStack = this.state.numDataMovesStack;
    movesStack[movesStack.length - 1]++;
  }
}
/**
 * Runs `fn` inside a new scope and ends the scope afterwards, passing the
 * function's result to `endScope()`.
 *
 * Can be called as `tidy(fn)` or `tidy(name, fn)`. Throws when the arguments
 * do not match either form. Returning a Promise from `fn` only logs an error.
 */
tidy<T extends TensorContainer>(nameOrFn: string|ScopeFn<T>, fn?: ScopeFn<T>):
    T {
  let name: string = null;
  if (fn != null) {
    // Two-argument form: (name, fn).
    if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
      throw new Error(
          'When calling with two arguments, the first argument ' +
          'to tidy() must be a string');
    }
    if (typeof fn !== 'function') {
      throw new Error(
          'When calling with two arguments, the 2nd argument ' +
          'to tidy() must be a function');
    }
    name = nameOrFn as string;
    // TODO(nsthorat,smilkov): Do operation logging and performance
    // profiling.
  } else {
    // One-argument form: (fn).
    if (typeof nameOrFn !== 'function') {
      throw new Error('Please provide a function to tidy()');
    }
    fn = nameOrFn;
  }

  let result: T;
  const openScope = () => this.startScope(name);
  const closeScope = () => this.endScope(result);
  const runBody = () => {
    result = fn();
    if (result instanceof Promise) {
      console.error('Cannot return a Promise inside of tidy.');
    }
    return result;
  };
  return this.scopedRun(openScope, closeScope, runBody);
}
/**
 * Runs `f` bracketed by `start()` and `end()`.
 *
 * `end()` is invoked exactly once, whether `f` returns or throws. A
 * hand-written try/catch that calls `end()` in both branches would call it
 * a second time if `end()` itself threw after `f` had succeeded, which would
 * unbalance whatever `start`/`end` maintain (scope stack, depth counters).
 *
 * @param start Called once before `f`.
 * @param end Called once after `f`, also when `f` throws.
 * @param f The function to run.
 * @returns Whatever `f` returns. Errors thrown by `f` (or by `end`)
 *     propagate to the caller.
 */
private scopedRun<T>(start: () => void, end: () => void, f: () => T): T {
  start();
  try {
    return f();
  } finally {
    end();
  }
}
/** Next tensor id to hand out; shared by all `Engine` instances. */
private static nextTensorId = 0;
/** Returns a fresh tensor id and advances the shared counter. */
private nextTensorId(): number {
  const id = Engine.nextTensorId;
  Engine.nextTensorId = id + 1;
  return id;
}
/** Next variable id to hand out; shared by all `Engine` instances. */
private static nextVariableId = 0;
/** Returns a fresh variable id and advances the shared counter. */
private nextVariableId(): number {
  const id = Engine.nextVariableId;
  Engine.nextVariableId = id + 1;
  return id;
}
/**
 * Engine-internal replacement for the public tensor.clone(), used when a
 * tensor is saved for the backward pass. Unlike the public op, it always
 * records the clone on the tape, even when called from inside a kernel
 * execution.
 *
 * The returned tensor wraps the same data id as `x`.
 *
 * This method will go away once all kernels are modularized since we won't
 * need to turn off the tape inside runKernel().
 */
private clone(x: Tensor): Tensor {
  const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
  const grad = (dy: Tensor) => ({x: () => dy.toFloat()});
  const nothingSaved: Tensor[] = [];
  this.addTapeNode(
      this.state.activeScope.name, {x}, [y], grad, nothingSaved, {});
  return y;
}
/**
 * Executes the registered kernel `kernelName` and returns its output
 * tensor(s).
 *
 * This is a thin wrapper that delegates to `runKernelFunc` without a forward
 * or backward function; it is a stop-gap until all kernels are modularized,
 * after which `runKernelFunc` will be removed.
 *
 * @param kernelName The name of the kernel to execute.
 * @param inputs A map of input names to tensors.
 * @param attrs A map of attribute names to their values. An attribute is a
 *     primitive (non-tensor) input to the kernel.
 * @param inputsToSave A list of tensors, inputs to save for the backprop
 *     computation.
 * @param outputsToSave A list of booleans, specifying which output to save
 *     for the backprop computation. These are booleans since the output
 *     tensors are not visible to the user.
 */
runKernel(
    kernelName: string, inputs: NamedTensorMap, attrs: NamedAttrMap,
    inputsToSave?: Tensor[], outputsToSave?: boolean[]): Tensor|Tensor[] {
  const noForwardFunc: null = null;
  const noBackwardsFunc: null = null;
  return this.runKernelFunc(
      noForwardFunc, inputs, noBackwardsFunc, kernelName, attrs,
      inputsToSave, outputsToSave);
}
/** Memory-leak checks are enabled when the `IS_TEST` flag is set. */
private shouldCheckForMemLeaks(): boolean {
  const isTest = this.ENV.getBool('IS_TEST');
  return isTest;
}
/**
 * Throws when the backend holds more data ids after a kernel ran than can be
 * explained by the kernel's outputs and by data moves.
 *
 * @param kernelName Name used in the error message.
 * @param numDataIdsBefore Backend data-id count taken before the kernel ran.
 * @param outInfos The kernel's outputs.
 */
private checkKernelForMemLeak(
    kernelName: string, numDataIdsBefore: number,
    outInfos: TensorInfo[]): void {
  const numDataIdsAfter = this.backend.numDataIds();

  // Data ids legitimately owned by the kernel's outputs. Complex numbers
  // allocate 3 data ids: one for 'real', one for 'imaginary', and one for
  // the container that holds the former two.
  const numOutputDataIds = outInfos.reduce(
      (count, info) => count + (info.dtype === 'complex64' ? 3 : 1), 0);

  // A "data move" can happen in the middle of a kernel execution, placing a
  // new (key,value) pair in the data storage. Moves have net zero effect (the
  // data is always removed from the old backend), so they are cancelled out
  // here.
  const movesStack = this.state.numDataMovesStack;
  const numMoves = movesStack[movesStack.length - 1];

  const dataIdsLeaked =
      numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
  if (dataIdsLeaked > 0) {
    throw new Error(
        `Backend '${this.backendName}' has an internal memory leak ` +
        `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
  }
}
/**
* Runs a kernel and, when the tape is on, records it for backprop.
*
* If a kernel named `kernelName` is registered for the current backend it is
* executed; otherwise `forwardFunc` is run inside a `tidy()`. While the
* kernel runs, `kernelDepth` is incremented, which turns the tape off for
* any nested kernel calls.
*
* @param forwardFunc Fallback forward computation used when no registered
*     kernel is found.
* @param inputs Map of input names to tensors.
* @param backwardsFunc Gradient function recorded on the tape; a gradient
*     registered for `kernelName` takes precedence (see `addTapeNode`).
* @param kernelName Kernel name; defaults to the active scope's name, or ''
*     when there is no active scope.
* @param attrs Non-tensor attributes passed to the kernel.
* @param inputsToSave Inputs to save for backprop when no gradient is
*     registered for the kernel.
* @param outputsToSave Flags selecting which outputs to save for backprop
*     when no gradient is registered for the kernel.
* @returns The output tensor, or an array of tensors when the kernel /
*     forward function produced an array.
*
* @deprecated Use `runKernel` for newly added kernels. Keep using this method
* only for kernels that are not yet fully modularized.
*/
runKernelFunc<T extends Tensor|Tensor[], I extends NamedTensorMap>(
forwardFunc: ForwardFunc<T>, inputs: I,
backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
kernelName?: string, attrs?: NamedAttrMap, inputsToSave?: Tensor[],
outputsToSave?: boolean[]): T {
let outputs: Tensor[];
let saved: Tensor[] = [];
// Sampled once up front: running the kernel below raises kernelDepth, which
// would make isTapeOn() report false.
const isTapeOn = this.isTapeOn();
if (kernelName == null) {
kernelName =
this.state.activeScope != null ? this.state.activeScope.name : '';
}
// Counter snapshots used for the profiling record at the end.
const startingBytecount = this.state.numBytes;
const startingNumTensors = this.state.numTensors;
// NOTE(review): this entry is pushed for every kernel run in test mode, but
// nothing in this file pops numDataMovesStack again — confirm whether that
// is intended.
if (this.shouldCheckForMemLeaks()) {
this.state.numDataMovesStack.push(0);
}
let kernelFunc: () => Tensor[];
const kernel = getKernel(kernelName, this.backendName);
// Raw result of the kernel / forward function; also used at the very end to
// decide whether to return an array or a single tensor.
let out: TensorInfo|TensorInfo[];
if (kernel != null) {
// Path 1: a kernel is registered for this backend.
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = kernel.kernelFunc({inputs, attrs, backend: this.backend});
const outInfos = Array.isArray(out) ? out : [out];
if (this.shouldCheckForMemLeaks()) {
this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
}
// Wrap the returned data ids in Tensor objects.
const outTensors = outInfos.map(
({dataId, shape, dtype}) =>
this.makeTensorFromDataId(dataId, shape, dtype));
// Save the inputs and outputs.
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (isTapeOn) {
let tensorsToSave =
this.getTensorsForGradient(kernelName, inputs, outTensors);
if (tensorsToSave == null) {
// Fallback for ops that call runKernelFunc and pass in
// inputsToSave and outputsToSave. Currently this is the set of ops
// with kernel support in the WASM backend. Once those ops and
// respective gradients are modularised we can remove this path.
if (outputsToSave == null) {
outputsToSave = [];
}
const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
}
saved = this.saveTensorsForBackwardMode(tensorsToSave);
}
return outTensors;
};
} else {
// Path 2: no registered kernel; fall back to the supplied forwardFunc.
const saveFunc: GradSaveFunc = (tensors) => {
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (!isTapeOn) {
return;
}
saved = tensors.map(tensor => this.keep(this.clone(tensor)));
};
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
// tidy() disposes intermediates created by forwardFunc; the returned
// outputs survive the scope.
out = this.tidy(() => forwardFunc(this.backend, saveFunc));
const outs = (Array.isArray(out) ? out : [out]) as Tensor[];
if (this.shouldCheckForMemLeaks()) {
this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
}
return outs;
};
}
// Stop recording to a tape when running a kernel.
this.scopedRun(
() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
if (!this.ENV.getBool('DEBUG')) {
outputs = kernelFunc();
} else {
// In DEBUG mode the kernel is run through the profiler.
outputs = this.profiler.profileKernel(
kernelName, inputs, () => kernelFunc());
}
});
if (isTapeOn) {
this.addTapeNode(
kernelName, inputs, outputs, backwardsFunc, saved, attrs);
}
// Record per-kernel statistics while Engine.profile() is running.
if (this.state.profiling) {
this.state.activeProfile.kernels.push({
name: kernelName,
bytesAdded: this.state.numBytes - startingBytecount,
totalBytesSnapshot: this.state.numBytes,
tensorsAdded: this.state.numTensors - startingNumTensors,
totalTensorsSnapshot: this.state.numTensors,
inputShapes: Object.keys(inputs).map(key => inputs[key].shape),
outputShapes: outputs.map(item => item.shape)
});
}
// Mirror the shape of the raw result: array in, array out.
return (Array.isArray(out) ? outputs : outputs[0]) as T;
}
/**
 * Saves tensors from the forward pass for use in the backward pass: each
 * tensor is cloned via the engine's internal `clone()` and the clone is
 * marked as kept.
 *
 * @param tensors the list of tensors to save.
 * @returns the kept clones, in the same order.
 */
private saveTensorsForBackwardMode(tensors: Tensor[]): Tensor[] {
  return tensors.map(original => this.keep(this.clone(original)));
}
/**
 * Returns the list of tensors to save for a given gradient calculation:
 * the selected inputs followed by the selected outputs.
 *
 * Returns null if there is no registered gradient for this kernel in the
 * gradient registry.
 *
 * @param kernelName name of kernel to look up gradient for.
 * @param inputs a map of input tensors.
 * @param outputs an array of output tensors from forward mode of kernel.
 */
private getTensorsForGradient(
    kernelName: string, inputs: NamedTensorMap,
    outputs: Tensor[]): Tensor[]|null {
  const gradConfig = getGradient(kernelName);
  if (gradConfig == null) {
    // TODO(yassogba) throw exception here once all runkernelFunc calls with
    // inputsToSave/outputsToSave are removed
    return null;
  }

  const inputNamesToSave: string[] = gradConfig.inputsToSave || [];
  const outputFlags: boolean[] = gradConfig.outputsToSave || [];

  // With saveAllInputs every input is saved; otherwise only the inputs named
  // in inputsToSave.
  let savedInputs: Tensor[];
  if (gradConfig.saveAllInputs) {
    util.assert(
        Array.isArray(inputs),
        () => 'saveAllInputs is true, expected inputs to be an array.');
    savedInputs = Object.keys(inputs).map((key) => inputs[key]);
  } else {
    savedInputs = inputNamesToSave.map((inputName) => inputs[inputName]);
  }

  const savedOutputs: Tensor[] = outputs.filter((_, i) => outputFlags[i]);
  return savedInputs.concat(savedOutputs);
}
/**
 * Internal method used by public APIs for tensor creation. Builds a new
 * tensor from the given values, shape and dtype; a new data id is always
 * created by writing the values to the backend.
 *
 * `dtype` defaults to 'float32' and `backend` to the engine's active backend.
 * String values are encoded before being written. Throws when `values` is
 * null or undefined.
 */
makeTensor(
    values: DataValues, shape: number[], dtype: DataType,
    backend?: KernelBackend): Tensor {
  if (values == null) {
    throw new Error('Values passed to engine.makeTensor() are null');
  }
  dtype = dtype || 'float32';
  backend = backend || this.backend;

  const needsEncoding = dtype === 'string' && util.isString(values[0]);
  const backendVals: BackendValues = needsEncoding ?
      (values as string[]).map(text => util.encodeString(text)) :
      values as BackendValues;

  const dataId = backend.write(backendVals, shape, dtype);
  const tensor = new Tensor(shape, dtype, dataId, this.nextTensorId());
  this.incRef(tensor, backend);

  if (dtype === 'string') {
    // incRef does not count bytes for string tensors, so account for them
    // here from the encoded values.
    const info = this.state.tensorInfo.get(dataId);
    const stringBytes = bytesFromStringArray(backendVals as Uint8Array[]);
    this.state.numBytes += stringBytes - info.bytes;
    info.bytes = stringBytes;
  }
  return tensor;
}
/**
 * Internal method used by backends. Wraps an existing data id in a new
 * tensor. No new data id is created; only the ref count used in memory
 * tracking is incremented. `dtype` defaults to 'float32'.
 */
makeTensorFromDataId(
    dataId: DataId, shape: number[], dtype: DataType,
    backend?: KernelBackend): Tensor {
  const wrapper =
      new Tensor(shape, dtype || 'float32', dataId, this.nextTensorId());
  this.incRef(wrapper, backend);
  return wrapper;
}
/**
 * Creates a variable from `initialValue` and registers it under its name.
 *
 * @param initialValue Tensor providing the initial value; cast with
 *     `asType` when `dtype` is given and differs from its dtype.
 * @param trainable Passed through to the `Variable` constructor.
 * @param name Variable name; defaults to a freshly generated numeric id.
 * @param dtype Optional dtype for the variable.
 *
 * Throws when a variable with the same name is already registered.
 */
makeVariable(
    initialValue: Tensor, trainable = true, name?: string,
    dtype?: DataType): Variable {
  name = name || this.nextVariableId().toString();
  const needsCast = dtype != null && dtype !== initialValue.dtype;
  if (needsCast) {
    initialValue = initialValue.asType(dtype);
  }
  const variable =
      new Variable(initialValue, trainable, name, this.nextTensorId());
  const registered = this.state.registeredVariables;
  if (registered[variable.name] != null) {
    throw new Error(
        `Variable with name ${variable.name} was already registered`);
  }
  registered[variable.name] = variable;
  this.incRef(variable, this.backend);
  return variable;
}
/**
 * Registers one more tensor referencing `a.dataId`.
 *
 * Updates the tensor counters, creates the bookkeeping entry for the data id
 * when this is its first reference (counting its bytes), bumps the ref count
 * and, unless `a` is a `Variable`, tracks the tensor in the active scope.
 *
 * @param a The tensor being registered.
 * @param backend Backend owning the data; defaults to the active backend
 *     when a new bookkeeping entry has to be created.
 */
incRef(a: Tensor, backend: KernelBackend): void {
  const {state} = this;
  const refCount =
      state.tensorInfo.has(a.dataId) ? state.tensorInfo.get(a.dataId).refCount :
                                       0;
  state.numTensors++;
  if (a.dtype === 'string') {
    state.numStringTensors++;
  }

  if (refCount === 0) {
    state.numDataBuffers++;
    // Bytes for complex numbers are counted by their components. Bytes for
    // string tensors are counted when writing values.
    const isSizedHere = a.dtype !== 'complex64' && a.dtype !== 'string';
    const bytes = isSizedHere ? a.size * util.bytesPerElement(a.dtype) : 0;
    state.tensorInfo.set(a.dataId, {
      backend: backend || this.backend,
      dtype: a.dtype,
      shape: a.shape,
      bytes,
      refCount: 0
    });
    state.numBytes += bytes;
  }

  state.tensorInfo.get(a.dataId).refCount++;
  if (!(a instanceof Variable)) {
    this.track(a);
  }
}
/**
 * Unregisters one tensor referencing `a.dataId`.
 *
 * Does nothing when the data id is unknown to the engine. Otherwise the
 * tensor counters are decremented and either the ref count is lowered or,
 * for the last reference, the data is disposed on its backend and the
 * bookkeeping entry removed.
 */
disposeTensor(a: Tensor): void {
  const {state} = this;
  if (!state.tensorInfo.has(a.dataId)) {
    return;
  }
  state.numTensors--;
  if (a.dtype === 'string') {
    state.numStringTensors--;
  }

  const info = state.tensorInfo.get(a.dataId);
  if (info.refCount <= 1) {
    // Last reference. Bytes for complex numbers are not subtracted here as
    // they are counted by their components.
    if (a.dtype !== 'complex64') {
      state.numBytes -= info.bytes;
    }
    state.numDataBuffers--;
    info.backend.disposeData(a.dataId);
    state.tensorInfo.delete(a.dataId);
  } else {
    info.refCount--;
  }
  // TODO(nsthorat): Construct an error and save the stack trace for
  // debugging when in debug mode. Creating a stack trace is too expensive
  // to do unconditionally.
}
/** Disposes every registered variable via `disposeVariable()`. */
disposeVariables(): void {
  const registered = this.state.registeredVariables;
  for (const varName in registered) {
    this.disposeVariable(registered[varName]);
  }
}
/**
 * Disposes the variable's tensor and removes the variable from the registry
 * of named variables if it is present there.
 */
disposeVariable(v: Variable): void {
  this.disposeTensor(v);
  const registered = this.state.registeredVariables;
  if (registered[v.name] != null) {
    delete registered[v.name];
  }
}
/**
 * Returns memory statistics: the object reported by the backend, overlaid
 * with the engine's tensor, data-buffer and byte counters. When string
 * tensors exist the result is flagged as unreliable with an explanation.
 */
memory(): MemoryInfo {
  const info = this.backend.memory() as MemoryInfo;
  const {numTensors, numDataBuffers, numBytes, numStringTensors} = this.state;
  info.numTensors = numTensors;
  info.numDataBuffers = numDataBuffers;
  info.numBytes = numBytes;

  if (numStringTensors > 0) {
    info.unreliable = true;
    if (info.reasons == null) {
      info.reasons = [];
    }
    info.reasons.push(
        'Memory usage by string tensors is approximate ' +
        '(2 bytes per character)');
  }
  return info;
}
/**
 * Runs `query` while recording one `KernelProfile` per executed kernel and
 * returns the collected statistics.
 *
 * - `newBytes` / `newTensors` are the changes in the engine's counters
 *   across the query.
 * - `peakBytes` is the largest byte snapshot among the recorded kernels.
 *   When the query ran no kernel at all there are no snapshots; the current
 *   byte count is reported instead of the `-Infinity` that `Math.max()`
 *   yields for an empty argument list.
 *
 * Profiling is switched off again even when `query` throws, so a failing
 * query does not leave the engine recording kernels indefinitely.
 *
 * @param query Synchronous function to profile; its return value is stored
 *     in the result's `result` field.
 * @returns The engine's active profile object, which is reused and
 *     overwritten by the next `profile()` call.
 */
async profile(query: () => TensorContainer): Promise<ProfileInfo> {
  const profile = this.state.activeProfile;
  this.state.profiling = true;
  const startBytes = this.state.numBytes;
  const startNumTensors = this.state.numTensors;
  profile.kernels = [];
  try {
    profile.result = query();
  } finally {
    this.state.profiling = false;
  }

  // Fold instead of spreading into Math.max so that an empty kernel list is
  // handled explicitly and a very long one cannot exceed the argument limit.
  profile.peakBytes = profile.kernels.length === 0 ?
      this.state.numBytes :
      profile.kernels.reduce(
          (peak, kernel) => Math.max(peak, kernel.totalBytesSnapshot),
          -Infinity);
  profile.newBytes = this.state.numBytes - startBytes;
  profile.newTensors = this.state.numTensors - startNumTensors;
  return profile;
}
/**
 * The tape records only while gradients are being computed
 * (`gradientDepth > 0`) and no kernel is currently executing
 * (`kernelDepth === 0`).
 */
isTapeOn(): boolean {
  const {gradientDepth, kernelDepth} = this.state;
  return gradientDepth > 0 && kernelDepth === 0;
}
/**
 * Appends a node describing one kernel execution to the active tape.
 *
 * A gradient registered for `kernelName` takes precedence over the supplied
 * `gradientsFunc`. When a gradient function is available, the node gets a
 * `gradient` callback that replaces missing `dy`s with zero tensors shaped
 * like the corresponding output before delegating to it.
 */
private addTapeNode(
    kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
    gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void {
  const tapeNode: TapeNode =
      {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

  const gradConfig = getGradient(kernelName);
  const gradFunc = gradConfig != null ? gradConfig.gradFunc : gradientsFunc;

  if (gradFunc != null) {
    const zerosLikeOutput = (index: number): Tensor => {
      const output = outputs[index];
      const vals = util.makeZerosTypedArray(output.size, output.dtype);
      return this.makeTensor(vals, output.shape, output.dtype);
    };
    tapeNode.gradient = (dys: Tensor[]) => {
      // TODO(smilkov): To optimize back-prop, pass dys that are not used in
      // the backprop graph to the user as null instead of zeros
      const filledDys =
          dys.map((dy, i) => dy == null ? zerosLikeOutput(i) : dy);
      // Grad functions of ops with single outputs expect a dy, while ops
      // with multiple outputs expect dys (array of dy).
      const dyArg = filledDys.length > 1 ? filledDys : filledDys[0];
      return gradFunc(dyArg, saved, attrs);
    };
  }
  this.state.activeTape.push(tapeNode);
}
/**
* Marks a tensor as kept and returns it. `endScope()` neither disposes kept
* tensors nor re-tracks them in the parent scope.
*/
keep<T extends Tensor>(result: T): T {
result.kept = true;
return result;
}
/**
 * Enters one level of gradient computation. A fresh tape is created only
 * for the outermost level.
 */
private startTape() {
  const isOutermost = this.state.gradientDepth === 0;
  if (isOutermost) {
    this.state.activeTape = [];
  }
  this.state.gradientDepth += 1;
}
/** Leaves one level of gradient computation. */
private endTape() {
  this.state.gradientDepth -= 1;
}
/**
 * Opens a new scope and makes it the active one. Pair with endScope() to get
 * the behavior of scope() without wrapping the code in a closure.
 *
 * @param name Optional scope name; 'unnamed scope' is used when omitted or
 *     empty.
 */
startScope(name?: string) {
  const scopeInfo: ScopeState = {
    track: [],
    name: name ? name : 'unnamed scope',
    id: this.state.nextScopeId++
  };
  this.state.scopeStack.push(scopeInfo);
  this.state.activeScope = scopeInfo;
}
/**
 * Closes the active scope. Pair with startScope() to get the behavior of
 * scope() without wrapping the code in a closure.
 *
 * Tensors tracked by the scope are disposed unless they are kept or part of
 * `result`. Tensors in `result` that were allocated in the closing scope are
 * handed over to the parent scope.
 *
 * @param result The value produced inside the scope, if any.
 */
endScope(result?: TensorContainer) {
  const survivors = getTensorsInContainer(result);
  const survivorIds = new Set(survivors.map(t => t.id));

  // Dispose everything tracked in this scope that does not survive it.
  for (const tensor of this.state.activeScope.track) {
    if (tensor.kept || survivorIds.has(tensor.id)) {
      continue;
    }
    tensor.dispose();
  }

  const closedScope = this.state.scopeStack.pop();
  const remaining = this.state.scopeStack;
  this.state.activeScope =
      remaining.length === 0 ? null : remaining[remaining.length - 1];

  // Track the current result in the parent scope, but only tensors that
  // were allocated in the inner scope and are not globally kept.
  survivors
      .filter(tensor => !tensor.kept && tensor.scopeId === closedScope.id)
      .forEach(tensor => {
        this.track(tensor);
      });
}
/**
* Returns gradients of `f` with respect to each of the `xs`. The gradients
* returned are of the same length as `xs`, but some might be null if `f`
* was not a function of that `x`. It also takes optional dy to multiply the
* gradient, which defaults to `1`.
*
* @param f Function whose output is differentiated. It is run inside a
*     'forward' tidy scope while the tape is recording.
* @param xs Tensors to differentiate with respect to; must be non-empty.
* @param dy Optional upstream gradient for `y`; must have dtype 'float32'.
*     Defaults to a tensor of ones with the shape of `y`.
* @param allowNoGradients When false, an error is thrown if no tape node
*     connects any of the `xs` to `y`.
* @returns The forward value `y` and one gradient entry per element of `xs`.
*/
gradients<T extends Tensor>(
f: () => T, xs: Tensor[], dy?: T,
allowNoGradients = false): {value: T, grads: Tensor[]} {
util.assert(
xs.length > 0, () => 'gradients() received an empty list of xs.');
if (dy != null && dy.dtype !== 'float32') {
throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
}
// Forward pass: record to the tape while running f.
const y = this.scopedRun(
() => this.startTape(), () => this.endTape(),
() => this.tidy('forward', f));
util.assert(
y instanceof Tensor,
() => 'The result y returned by f() must be a tensor.');
// Filter out the nodes that don't connect x => y.
const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
throw new Error(
'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
'that the f you passed encloses all operations that lead from x ' +
'to y.');
}
// Backward pass, run in its own tidy scope.
return this.tidy('backward', () => {
// Maps tensor id -> gradient accumulated for that tensor; seeded with dy.
const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};
accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;
// Backprop gradients through the filtered nodes.
backpropagateGradients(
accumulatedGradientMap, filteredTape,
// Pass the tidy function to avoid circular dep with `tape.ts`.
f => this.tidy(f as ScopeFn<Tensor>));
const grads = xs.map(x => accumulatedGradientMap[x.id]);
if (this.state.gradientDepth === 0) {
// This means that we are not computing higher-order gradients
// and can clean up the tape.
this.state.activeTape.forEach(node => {
for (const tensor of node.saved) {
tensor.dispose();
}
});
this.state.activeTape = null;
}
return {value: y, grads};
});
}
/**
* Wraps `f` into a function that computes `f(...inputs).value` and uses the
* `gradFunc` returned by `f` as the gradient of that computation.
*
* The returned function asserts that all of its arguments are tensors, then
* runs `f` through `runKernelFunc`, passing the inputs followed by the save
* function. The backward function checks that `gradFunc` returns exactly one
* tensor per input.
*
* @param f Function returning `{value, gradFunc}`; must be a function.
* @returns A function of the input tensors returning `value`.
*/
customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
(...args: Array<Tensor|GradSaveFunc>) => T {
util.assert(
util.isFunction(f),
() => 'The f passed in customGrad(f) must be a function.');
return (...inputs: Tensor[]): T => {
util.assert(
inputs.every(t => t instanceof Tensor),
() => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
'tensors');
// Assigned by the forward function below and read by the backward
// function, which therefore only works after the forward pass has run.
let res: {
value: T,
gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[],
};
// Inputs are keyed by their positional index.
const inputMap: NamedTensorMap = {};
inputs.forEach((input, i) => {
inputMap[i] = input;
});
return this.runKernelFunc(
(_, save) => {
res = f(...[...inputs, save]);
util.assert(
res.value instanceof Tensor,
() => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.value` is a tensor');
util.assert(
util.isFunction(res.gradFunc),
() => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function.');
return res.value;
},
inputMap,
(dy: T, saved: Tensor[]) => {
const gradRes = res.gradFunc(dy, saved);
const grads: Tensor[] =
Array.isArray(gradRes) ? gradRes : [gradRes];
util.assert(
grads.length === inputs.length,
() => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function that returns ' +
'the same number of tensors as inputs passed to f(...).');
util.assert(
grads.every(t => t instanceof Tensor),
() => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function that returns ' +
'a list of only tensors.');
// Key the gradients by positional index, matching `inputMap`.
const gradMap: {[key: string]: () => Tensor} = {};
grads.forEach((grad, i) => {
gradMap[i] = () => grad;
});
return gradMap;
});
};
}
/**
 * Synchronously reads the values behind `dataId` from whichever backend
 * currently owns that data.
 */
readSync(dataId: DataId): BackendValues {
  const owningBackend = this.state.tensorInfo.get(dataId).backend;
  return owningBackend.readSync(dataId);
}
/**
 * Asynchronously reads the values behind `dataId` from whichever backend
 * currently owns that data.
 */
read(dataId: DataId): Promise<BackendValues> {
  const owningBackend = this.state.tensorInfo.get(dataId).backend;
  return owningBackend.read(dataId);
}
/**
 * Times `query` on the active backend and adds `wallMs`, the difference
 * between `now()` readings taken before and after the backend's timing call.
 */
async time(query: () => void): Promise<TimingInfo> {
  const startedAt = now();
  const info = (await this.backend.time(query)) as TimingInfo;
  const finishedAt = now();
  info.wallMs = finishedAt - startedAt;
  return info;
}
/**
 * Registers a Tensor with the active scope so that it is cleaned up
 * automatically when that scope ends. Without an active scope the tensor is
 * returned untouched.
 *
 * @param result The Tensor to track in the current scope.
 * @returns The same tensor.
 */
private track<T extends Tensor>(result: T): T {
  const scope = this.state.activeScope;
  if (scope == null) {
    return result;
  }
  result.scopeId = scope.id;
  scope.track.push(result);
  return result;
}
/** The map of registered variables held in the engine state. */
get registeredVariables(): NamedVariableMap {
  const {registeredVariables} = this.state;
  return registeredVariables;
}
/**
 * Resets the engine state. Every initialized backend is disposed and removed
 * from the registry, but the registered backend factories stay in place.
 */
reset(): void {
  // Bumping the id makes any pending initialization promise obsolete.
  this.pendingBackendInitId++;

  this.state.dispose();
  this.ENV.reset();
  this.state = new EngineState();

  for (const initializedName in this.registry) {
    this.disposeRegisteredKernels(initializedName);
    this.registry[initializedName].dispose();
    delete this.registry[initializedName];
  }

  this.backendName = null;
  this.backendInstance = null;
  this.pendingBackendInit = null;
}
}
/** Creates a float32 tensor of the given shape filled with ones. */
function ones(shape: number[]): Tensor {
  const size = sizeFromShape(shape);
  return ENGINE.makeTensor(
      makeOnesTypedArray(size, 'float32'), shape, 'float32');
}
/** Cached global object; resolved on first use by getGlobalNamespace(). */
let GLOBAL: {_tfengine: Engine};

/**
 * Returns the global object the engine singleton is stored on. It is looked
 * up once, trying `window`, `global`, `process` and `self` in that order,
 * and cached afterwards. Throws when none of them exists.
 */
function getGlobalNamespace(): {_tfengine: Engine} {
  if (GLOBAL != null) {
    return GLOBAL;
  }
  // tslint:disable-next-line:no-any
  let ns: any;
  if (typeof (window) !== 'undefined') {
    ns = window;
  } else if (typeof (global) !== 'undefined') {
    ns = global;
  } else if (typeof (process) !== 'undefined') {
    ns = process;
  } else if (typeof (self) !== 'undefined') {
    ns = self;
  } else {
    throw new Error('Could not find a global object');
  }
  GLOBAL = ns;
  return GLOBAL;
}
/**
 * Returns the engine stored on the global namespace, creating it (together
 * with its environment) when it does not exist yet. In either case the
 * engine's environment is installed as the global environment and the engine
 * is registered as the tensor tracker.
 */
function getOrMakeEngine(): Engine {
  const ns = getGlobalNamespace();
  if (ns._tfengine == null) {
    ns._tfengine = new Engine(new Environment(ns));
  }
  setEnvironmentGlobal(ns._tfengine.ENV);

  // Tell the current tensor interface that the global engine is responsible
  // for tracking.
  setTensorTracker(() => ns._tfengine);
  return ns._tfengine;
}
/**
* The engine singleton. It lives on the global namespace object as
* `_tfengine` and is created on first module load if it is not there yet.
*/
export const ENGINE = getOrMakeEngine();