/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { KernelBackend } from './backends/backend';
import { Environment, setEnvironmentGlobal } from './environment';
import { getGlobalNamespace } from './global_util';
import { Add, Cast, Identity } from './kernel_names';
import { getGradient, getKernel, getKernelsForBackend } from './kernel_registry';
import * as log from './log';
import { Profiler } from './profiler';
import { backpropagateGradients, getFilteredNodesXToY } from './tape';
import { setTensorTracker, Tensor, Variable } from './tensor';
import { getTensorsInContainer } from './tensor_util';
import * as util from './util';
import { bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape } from './util';
function isRegisteredKernelInvocation(kernelInvocation) {
return kernelInvocation.kernelName != null;
}
class EngineState {
constructor() {
// Public since optimizers will use it.
this.registeredVariables = {};
this.nextTapeNodeId = 0;
this.numBytes = 0;
this.numTensors = 0;
this.numStringTensors = 0;
this.numDataBuffers = 0;
// Number of nested tf.grad() statements when computing higher-order
// gradients. E.g. `1` for first-order gradients and `2` for second-order
// gradients. Used to track if the tape should be removed after a backprop.
this.gradientDepth = 0;
// Number of nested kernel calls. When kernel depth is greater than 1, we turn
// off the tape.
this.kernelDepth = 0;
this.scopeStack = [];
/**
* Keeps track of the number of data moves during a kernel execution. We
* maintain a stack since kernels can call other kernels, recursively.
*/
this.numDataMovesStack = [];
this.nextScopeId = 0;
this.tensorInfo = new WeakMap();
this.profiling = false;
this.activeProfile = {
newBytes: 0,
newTensors: 0,
peakBytes: 0,
kernels: [],
result: null,
get kernelNames() {
return Array.from(new Set(this.kernels.map(k => k.name)));
}
};
}
dispose() {
for (const variableName in this.registeredVariables) {
this.registeredVariables[variableName].dispose();
}
}
}
class Engine {
constructor(ENV) {
this.ENV = ENV;
this.registry = {};
this.registryFactory = {};
this.pendingBackendInitId = 0;
this.state = new EngineState();
}
async ready() {
if (this.pendingBackendInit != null) {
return this.pendingBackendInit.then(() => { });
}
if (this.backendInstance != null) {
return;
}
const sortedBackends = this.getSortedBackends();
for (let i = 0; i < sortedBackends.length; i++) {
const backendName = sortedBackends[i];
const success = await this.initializeBackend(backendName).success;
if (success) {
await this.setBackend(backendName);
return;
}
}
throw new Error(`Could not initialize any backends, all backend initializations ` +
`failed.`);
}
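// Usage sketch (illustrative, not part of this file): consumers await
// readiness through the public wrapper before running ops:
//
//   import * as tf from '@tensorflow/tfjs-core';
//   await tf.ready();                // resolves once a backend initializes
//   const y = tf.scalar(2).square(); // safe to run ops from here on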
get backend() {
if (this.pendingBackendInit != null) {
throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
`sure to await tf.ready() or await tf.setBackend() before calling ` +
`other methods`);
}
if (this.backendInstance == null) {
const { name, asyncInit } = this.initializeBackendsAndReturnBest();
if (asyncInit) {
throw new Error(`The highest priority backend '${name}' has not yet been ` +
`initialized. Make sure to await tf.ready() or ` +
`await tf.setBackend() before calling other methods`);
}
this.setBackend(name);
}
return this.backendInstance;
}
backendNames() {
return Object.keys(this.registryFactory);
}
findBackend(backendName) {
if (!(backendName in this.registry)) {
// If the backend hasn't been initialized but we have a registry entry for
// it, initialize it and return it.
if (backendName in this.registryFactory) {
const { asyncInit } = this.initializeBackend(backendName);
if (asyncInit) {
// Backend is not ready yet.
return null;
}
}
else {
return null;
}
}
return this.registry[backendName];
}
findBackendFactory(backendName) {
if (!(backendName in this.registryFactory)) {
return null;
}
return this.registryFactory[backendName].factory;
}
registerBackend(backendName, factory, priority = 1) {
if (backendName in this.registryFactory) {
log.warn(`${backendName} backend was already registered. ` +
`Reusing existing backend factory.`);
return false;
}
this.registryFactory[backendName] = { factory, priority };
return true;
}
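// Example (hedged sketch; 'my-backend' and MyBackend are hypothetical names):
// a backend registered with a higher priority than the defaults (in the
// official packages, cpu is 1 and webgl is 2) becomes the preferred backend:
//
//   tf.registerBackend('my-backend', () => new MyBackend(), 3);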
async setBackend(backendName) {
if (this.registryFactory[backendName] == null) {
throw new Error(`Backend name '${backendName}' not found in registry`);
}
this.backendName = backendName;
if (this.registry[backendName] == null) {
this.backendInstance = null;
const { success, asyncInit } = this.initializeBackend(backendName);
const result = asyncInit ? await success : success;
if (!result) {
return false;
}
}
this.backendInstance = this.registry[backendName];
this.setupRegisteredKernels();
// Reset the profiler.
this.profiler = new Profiler(this.backendInstance);
return true;
}
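// Usage sketch via the public wrapper, assuming the 'webgl' and 'cpu'
// backends are registered. Note that setBackend resolves to false (rather
// than throwing) when initialization fails:
//
//   const ok = await tf.setBackend('webgl');
//   if (!ok) {
//     await tf.setBackend('cpu'); // fall back to a simpler backend
//   }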
setupRegisteredKernels() {
const kernels = getKernelsForBackend(this.backendName);
kernels.forEach(kernel => {
if (kernel.setupFunc != null) {
kernel.setupFunc(this.backendInstance);
}
});
}
disposeRegisteredKernels(backendName) {
const kernels = getKernelsForBackend(backendName);
kernels.forEach(kernel => {
if (kernel.disposeFunc != null) {
kernel.disposeFunc(this.registry[backendName]);
}
});
}
/**
 * Initializes a backend by looking up the backend name in the factory
 * registry and calling the factory method. Returns `{success, asyncInit}`,
 * where `success` indicates whether initialization succeeded (a boolean, or
 * a promise of one when the factory is async). Throws an error if there is
 * no matching backend in the factory registry.
 */
initializeBackend(backendName) {
const registryFactoryEntry = this.registryFactory[backendName];
if (registryFactoryEntry == null) {
throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
}
try {
const backend = registryFactoryEntry.factory();
/* Test whether the factory returns a promise. This is done more
liberally than the previous `Promise.resolve(backend) === backend`
check, to account for custom Promise implementations (e.g. Angular). */
if (backend && !(backend instanceof KernelBackend) &&
typeof backend.then === 'function') {
const promiseId = ++this.pendingBackendInitId;
const success = backend
.then(backendInstance => {
// Outdated promise. Another backend was set in the meantime.
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.registry[backendName] = backendInstance;
this.pendingBackendInit = null;
return true;
})
.catch(err => {
// Outdated promise. Another backend was set in the meantime.
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.pendingBackendInit = null;
log.warn(`Initialization of backend ${backendName} failed`);
log.warn(err.stack || err.message);
return false;
});
this.pendingBackendInit = success;
return { success, asyncInit: true };
}
else {
this.registry[backendName] = backend;
return { success: true, asyncInit: false };
}
}
catch (err) {
log.warn(`Initialization of backend ${backendName} failed`);
log.warn(err.stack || err.message);
return { success: false, asyncInit: false };
}
}
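// Both factory shapes handled above, as a sketch (MySyncBackend,
// MyAsyncBackend, and loadResources are hypothetical):
//
//   // Sync factory: initializeBackend returns {success: true, asyncInit: false}.
//   tf.registerBackend('sync-backend', () => new MySyncBackend());
//   // Async factory: returns {success: Promise<boolean>, asyncInit: true}.
//   tf.registerBackend('async-backend', async () => {
//     await loadResources();        // hypothetical async setup
//     return new MyAsyncBackend();  // must resolve to a KernelBackend
//   });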
removeBackend(backendName) {
if (!(backendName in this.registryFactory)) {
throw new Error(`${backendName} backend not found in registry`);
}
if (this.backendName === backendName && this.pendingBackendInit != null) {
// There is a pending promise of the backend we want to remove. Make it
// obsolete.
this.pendingBackendInitId++;
}
if (backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
delete this.registryFactory[backendName];
// Unset the backend if it is active.
if (this.backendName === backendName) {
this.pendingBackendInit = null;
this.backendName = null;
this.backendInstance = null;
}
}
getSortedBackends() {
if (Object.keys(this.registryFactory).length === 0) {
throw new Error('No backend found in registry.');
}
return Object.keys(this.registryFactory).sort((a, b) => {
// Highest priority comes first.
return this.registryFactory[b].priority -
this.registryFactory[a].priority;
});
}
initializeBackendsAndReturnBest() {
const sortedBackends = this.getSortedBackends();
for (let i = 0; i < sortedBackends.length; i++) {
const backendName = sortedBackends[i];
const { success, asyncInit } = this.initializeBackend(backendName);
if (asyncInit || success) {
return { name: backendName, asyncInit };
}
}
throw new Error(`Could not initialize any backends, all backend initializations ` +
`failed.`);
}
moveData(backend, dataId) {
const info = this.state.tensorInfo.get(dataId);
const srcBackend = info.backend;
const values = this.readSync(dataId);
const refCount = srcBackend.refCount(dataId);
// Delete the tensor from the old backend and move it to the new
// backend.
srcBackend.disposeData(dataId, true);
info.backend = backend;
backend.move(dataId, values, info.shape, info.dtype, refCount);
if (this.shouldCheckForMemLeaks()) {
// Track the number of moves during a kernel execution to correctly
// detect memory leaks.
this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
}
}
tidy(nameOrFn, fn) {
let name = null;
if (fn == null) {
// Called with only 1 argument.
if (typeof nameOrFn !== 'function') {
throw new Error('Please provide a function to tidy()');
}
fn = nameOrFn;
}
else {
// Called with 2 arguments.
if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
throw new Error('When calling with two arguments, the first argument ' +
'to tidy() must be a string');
}
if (typeof fn !== 'function') {
throw new Error('When calling with two arguments, the second argument ' +
'to tidy() must be a function');
}
name = nameOrFn;
// TODO(nsthorat,smilkov): Do operation logging and performance
// profiling.
}
let result;
return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
result = fn();
if (result instanceof Promise) {
console.error('Cannot return a Promise inside of tidy.');
}
return result;
});
}
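// Usage sketch via the public wrapper: tensors created inside the callback
// are disposed when the scope ends; only the returned tensor survives:
//
//   const x = tf.tensor1d([3, 4]);
//   const norm = tf.tidy('norm', () => {
//     const sq = x.square();   // intermediate, disposed when the scope ends
//     return sq.sum().sqrt();  // returned, so tracked in the parent scope
//   });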
scopedRun(start, end, f) {
start();
try {
const res = f();
end();
return res;
}
catch (ex) {
end();
throw ex;
}
}
nextTensorId() {
return Engine.nextTensorId++;
}
nextVariableId() {
return Engine.nextVariableId++;
}
/**
 * This method is called instead of the public-facing tensor.clone() when
 * saving a tensor for the backwards pass. It makes sure to add the clone
 * operation to the tape even when called inside a kernel execution.
 */
clone(x) {
const y = ENGINE.runKernel(Identity, { x });
const inputs = { x };
const grad = (dy) => ({
x: () => {
const dtype = 'float32';
const gradInputs = { x: dy };
const attrs = { dtype };
return ENGINE.runKernel(Cast, gradInputs,
// tslint:disable-next-line: no-unnecessary-type-assertion
attrs);
}
});
const saved = [];
this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
return y;
}
/**
 * Execute a kernel with the given name and return the output tensor.
 *
 * @param kernelName The name of the kernel to execute.
 * @param inputs A map of input names to tensors.
 * @param attrs A map of attribute names to their values. An attribute is a
 * primitive (non-tensor) input to the kernel.
 */
runKernel(kernelName, inputs, attrs) {
if (this.backendName == null) {
// The backend has not been initialized yet (backend initialization is
// lazy and can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the
// backend and set properties like this.backendName
// tslint:disable-next-line: no-unused-expression
this.backend;
}
const hasKernel = getKernel(kernelName, this.backendName) != null;
if (!hasKernel) {
throw new Error(`Kernel '${kernelName}' not registered for backend '${this.backendName}'`);
}
return this.runKernelFunc({ kernelName, inputs, attrs });
}
shouldCheckForMemLeaks() {
return this.ENV.getBool('IS_TEST');
}
checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
const numDataIdsAfter = this.backend.numDataIds();
// Count the number of data ids associated with the result of the kernel.
let numOutputDataIds = 0;
outInfos.forEach(info => {
// Complex numbers allocate 3 data ids, one for 'real', one for
// 'imaginary', and one for the container that holds the former two.
numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
});
// Account for the number of moves during kernel execution. A "data move"
// can happen in the middle of a kernel execution, placing a new (key,value)
// pair in the data storage. Since data moves have net zero effect (we
// always remove the data from the old backend), we have to cancel them out
// when detecting memory leaks.
const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
if (dataIdsLeaked > 0) {
throw new Error(`Backend '${this.backendName}' has an internal memory leak ` +
`(${dataIdsLeaked} data ids) after running '${kernelName}'`);
}
}
/**
* Internal helper method to execute a kernel Func
*
* Use `runKernel` to execute kernels from outside of engine.
*/
runKernelFunc(kernelParams) {
let outputs;
let saved = [];
const isTapeOn = this.isTapeOn();
const startingBytecount = this.state.numBytes;
const startingNumTensors = this.state.numTensors;
if (this.shouldCheckForMemLeaks()) {
this.state.numDataMovesStack.push(0);
}
let kernelFunc;
if (this.backendName == null) {
// The backend has not been initialized yet (backend initialization is
// lazy and can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the
// backend and set properties like this.backendName
// tslint:disable-next-line: no-unused-expression
this.backend;
}
let out;
const kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ?
kernelParams.kernelName :
this.state.activeScope != null ? this.state.activeScope.name : '';
// Create the kernelFunc from either a registered kernel OR passed in
// forward/backward functions (used by custom grad). In this context a
// kernelFunc wraps a kernel implementation with some bookkeeping.
if (isRegisteredKernelInvocation(kernelParams)) {
const { kernelName, inputs, attrs } = kernelParams;
if (this.backendName == null) {
// The backend has not been initialized yet (backend initialization is
// lazy and can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the
// backend and set properties like this.backendName
// tslint:disable-next-line: no-unused-expression
this.backend;
}
const kernel = getKernel(kernelName, this.backendName);
util.assert(kernel != null, () => `Cannot find registered kernel '${kernelName}' for backend '${this.backendName}'`);
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
const outInfos = Array.isArray(out) ? out : [out];
if (this.shouldCheckForMemLeaks()) {
this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
}
const outTensors = outInfos.map((outInfo) => {
// TODO(yassogba): Remove this option (Tensor) when node backend
// methods have been modularized and they all return tensorInfo.
// TensorInfos do not have a rank attribute.
if (outInfo.rank != null) {
return outInfo;
}
return this.makeTensorFromTensorInfo(outInfo);
});
// Save any required inputs and outputs.
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since there would be no backprop for these tensors
// (which would otherwise dispose them).
if (isTapeOn) {
const tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
saved = this.saveTensorsForBackwardMode(tensorsToSave);
}
return outTensors;
};
}
else {
const { forwardFunc } = kernelParams;
// Running a customGrad op.
const saveFunc = (tensors) => {
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (!isTapeOn) {
return;
}
saved = tensors.map(tensor => this.keep(this.clone(tensor)));
};
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = this.tidy(() => forwardFunc(this.backend, saveFunc));
const outs = (Array.isArray(out) ? out : [out]);
if (this.shouldCheckForMemLeaks()) {
// Scope name is used to print a more helpful error message if needed.
this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
}
return outs;
};
}
//
// Run the kernelFunc, optionally profiling it.
//
const { inputs, attrs } = kernelParams;
const backwardsFunc = isRegisteredKernelInvocation(kernelParams) ?
null :
kernelParams.backwardsFunc;
let kernelProfile;
this.scopedRun(
// Stop recording to a tape when running a kernel.
() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
outputs = kernelFunc();
}
else {
kernelProfile = this.profiler.profileKernel(kernelOrScopeName, inputs, () => kernelFunc());
if (this.ENV.getBool('DEBUG')) {
this.profiler.logKernelProfile(kernelProfile);
}
outputs = kernelProfile.outputs;
}
});
if (isTapeOn) {
this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
}
if (this.state.profiling) {
this.state.activeProfile.kernels.push({
name: kernelOrScopeName,
bytesAdded: this.state.numBytes - startingBytecount,
totalBytesSnapshot: this.state.numBytes,
tensorsAdded: this.state.numTensors - startingNumTensors,
totalTensorsSnapshot: this.state.numTensors,
inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? inputs[key].shape : null),
outputShapes: outputs.map(item => item.shape),
kernelTimeMs: kernelProfile.timeMs,
extraInfo: kernelProfile.extraInfo
});
}
return (Array.isArray(out) ? outputs : outputs[0]);
}
/**
* Saves tensors used in forward mode for use in backward mode.
*
* @param tensors the list of tensors to save.
*/
saveTensorsForBackwardMode(tensors) {
const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
return saved;
}
/**
* Returns a list of tensors to save for a given gradient calculation.
*
* @param kernelName name of kernel to look up gradient for.
* @param inputs a map of input tensors.
* @param outputs an array of output tensors from forward mode of kernel.
*/
getTensorsForGradient(kernelName, inputs, outputs) {
const gradConfig = getGradient(kernelName);
if (gradConfig != null) {
const inputsToSave = gradConfig.inputsToSave || [];
const outputsToSave = gradConfig.outputsToSave || [];
// If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
// specified in inputsToSave will be saved.
let inputTensorsToSave;
if (gradConfig.saveAllInputs) {
util.assert(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
}
else {
inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
}
const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
return inputTensorsToSave.concat(outputTensorsToSave);
}
// We return an empty list rather than throw an error because the kernel we
// are looking up may not actually be relevant to backpropagating through the
// overall function.
//
// See 'does not error if irrelevant (pruned) ops are missing grads' test
// in gradients_test.ts for an example.
return [];
}
/**
* Internal method used by public APIs for tensor creation. Makes a new
* tensor with the provided shape, dtype and values. It always
* creates a new data id and writes the values to the underlying backend.
*/
makeTensor(values, shape, dtype, backend) {
if (values == null) {
throw new Error('Values passed to engine.makeTensor() are null');
}
dtype = dtype || 'float32';
backend = backend || this.backend;
let backendVals = values;
if (dtype === 'string' && util.isString(values[0])) {
backendVals = values.map(d => util.encodeString(d));
}
const dataId = backend.write(backendVals, shape, dtype);
const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.trackTensor(t, backend);
// Count bytes for string tensors.
if (dtype === 'string') {
const info = this.state.tensorInfo.get(dataId);
const newBytes = bytesFromStringArray(backendVals);
this.state.numBytes += newBytes - info.bytes;
info.bytes = newBytes;
}
return t;
}
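// Public creation ops ultimately funnel into makeTensor. As a rough,
// simplified sketch (the real ops also infer shape and dtype):
//
//   // tf.tensor1d([1, 2, 3], 'float32') roughly amounts to:
//   // ENGINE.makeTensor(new Float32Array([1, 2, 3]), [3], 'float32');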
/**
* Internal method used by backends. Makes a new tensor
* that is a wrapper around an existing data id. It doesn't create
* a new data id, only increments the ref count used in memory tracking.
* @deprecated
*/
makeTensorFromDataId(dataId, shape, dtype, backend) {
dtype = dtype || 'float32';
const tensorInfo = { dataId, shape, dtype };
return this.makeTensorFromTensorInfo(tensorInfo, backend);
}
/**
* Internal method used by backends. Makes a new tensor that is a wrapper
* around an existing data id in TensorInfo. It doesn't create a new data id,
* only increments the ref count used in memory tracking.
*/
makeTensorFromTensorInfo(tensorInfo, backend) {
const { dataId, shape, dtype } = tensorInfo;
const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.trackTensor(t, backend);
return t;
}
makeVariable(initialValue, trainable = true, name, dtype) {
name = name || this.nextVariableId().toString();
if (dtype != null && dtype !== initialValue.dtype) {
initialValue = initialValue.cast(dtype);
}
const v = new Variable(initialValue, trainable, name, this.nextTensorId());
if (this.state.registeredVariables[v.name] != null) {
throw new Error(`Variable with name ${v.name} was already registered`);
}
this.state.registeredVariables[v.name] = v;
this.incRef(v, this.backend);
return v;
}
trackTensor(a, backend) {
this.state.numTensors++;
if (a.dtype === 'string') {
this.state.numStringTensors++;
}
// Bytes for complex numbers are counted by their components. Bytes for
// string tensors are counted when writing values.
let bytes = 0;
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
bytes = a.size * util.bytesPerElement(a.dtype);
}
this.state.numBytes += bytes;
if (!this.state.tensorInfo.has(a.dataId)) {
this.state.numDataBuffers++;
this.state.tensorInfo.set(a.dataId, {
backend: backend || this.backend,
dtype: a.dtype,
shape: a.shape,
bytes
});
}
if (!(a instanceof Variable)) {
this.track(a);
}
}
// Track the tensor by dataId and increase the refCount for the dataId in the
// backend.
// TODO(pyu10055): This is currently used by the makeVariable method, to
// increase the refCount on the backend for the dataId. It can potentially be
// replaced with the Identity op instead of calling the backend directly.
incRef(a, backend) {
this.trackTensor(a, backend);
this.backend.incRef(a.dataId);
}
removeDataId(dataId, backend) {
if (this.state.tensorInfo.has(dataId) &&
this.state.tensorInfo.get(dataId).backend === backend) {
this.state.tensorInfo.delete(dataId);
this.state.numDataBuffers--;
}
}
disposeTensor(a) {
if (!this.state.tensorInfo.has(a.dataId)) {
return;
}
const info = this.state.tensorInfo.get(a.dataId);
this.state.numTensors--;
if (a.dtype === 'string') {
this.state.numStringTensors--;
this.state.numBytes -= info.bytes;
}
// Don't count bytes for complex numbers as they are counted by their
// components.
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
const bytes = a.size * util.bytesPerElement(a.dtype);
this.state.numBytes -= bytes;
}
// Remove the reference to the dataId if the backend disposed the data
// successfully.
if (info.backend.disposeData(a.dataId)) {
this.removeDataId(a.dataId, info.backend);
}
// TODO(nsthorat): Construct an error and save the stack trace for
// debugging when in debug mode. Creating a stack trace is too expensive
// to do unconditionally.
}
disposeVariables() {
for (const varName in this.state.registeredVariables) {
const v = this.state.registeredVariables[varName];
this.disposeVariable(v);
}
}
disposeVariable(v) {
this.disposeTensor(v);
if (this.state.registeredVariables[v.name] != null) {
delete this.state.registeredVariables[v.name];
}
}
memory() {
const info = this.backend.memory();
info.numTensors = this.state.numTensors;
info.numDataBuffers = this.state.numDataBuffers;
info.numBytes = this.state.numBytes;
if (this.state.numStringTensors > 0) {
info.unreliable = true;
if (info.reasons == null) {
info.reasons = [];
}
info.reasons.push('Memory usage by string tensors is approximate ' +
'(2 bytes per character)');
}
return info;
}
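// Usage sketch: tf.memory() surfaces these counters; string tensors mark the
// report unreliable since their byte size is only approximate:
//
//   const before = tf.memory().numTensors;
//   const t = tf.zeros([100, 100]);
//   console.log(tf.memory().numTensors - before); // 1
//   t.dispose();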
async profile(query) {
this.state.profiling = true;
const startBytes = this.state.numBytes;
const startNumTensors = this.state.numTensors;
this.state.activeProfile.kernels = [];
this.state.activeProfile.result = await query();
this.state.profiling = false;
this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
this.state.activeProfile.newTensors =
this.state.numTensors - startNumTensors;
for (const kernel of this.state.activeProfile.kernels) {
kernel.kernelTimeMs = await kernel.kernelTimeMs;
kernel.extraInfo = await kernel.extraInfo;
}
return this.state.activeProfile;
}
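// Usage sketch: profile() accepts a query (sync or async) and reports the
// per-kernel stats collected above:
//
//   const info = await tf.profile(() => tf.scalar(2).square());
//   console.log(info.newTensors, info.peakBytes, info.kernelNames);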
isTapeOn() {
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
}
addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved };
const gradConfig = getGradient(kernelName);
if (gradConfig != null) {
gradientsFunc = gradConfig.gradFunc;
}
if (gradientsFunc != null) {
tapeNode.gradient = (dys) => {
// TODO(smilkov): To optimize back-prop, pass dys that are not used in
// the backprop graph to the user as null instead of zeros
dys = dys.map((dy, i) => {
if (dy == null) {
const output = outputs[i];
const vals = util.makeZerosTypedArray(output.size, output.dtype);
return this.makeTensor(vals, output.shape, output.dtype);
}
return dy;
});
// Grad functions of ops with single outputs expect a dy, while ops
// with multiple outputs expect dys (array of dy).
return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
};
}
this.state.activeTape.push(tapeNode);
}
keep(result) {
result.kept = true;
return result;
}
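// Usage sketch: keep() exempts a tensor from tidy's end-of-scope disposal:
//
//   let saved;
//   tf.tidy(() => {
//     saved = tf.keep(tf.scalar(1)); // survives the scope
//     tf.scalar(2);                  // disposed when the scope ends
//   });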
startTape() {
if (this.state.gradientDepth === 0) {
this.state.activeTape = [];
}
this.state.gradientDepth++;
}
endTape() {
this.state.gradientDepth--;
}
/**
* Start a scope. Use this with endScope() to achieve the same functionality
* as scope() without the need for a function closure.
*/
startScope(name) {
const scopeInfo = {
track: [],
name: 'unnamed scope',
id: this.state.nextScopeId++
};
if (name) {
scopeInfo.name = name;
}
this.state.scopeStack.push(scopeInfo);
this.state.activeScope = scopeInfo;
}
/**
* End a scope. Use this with startScope() to achieve the same functionality
* as scope() without the need for a function closure.
*/
endScope(result) {
const tensorsToTrackInParent = getTensorsInContainer(result);
const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id));
// Dispose the arrays tracked in this scope.
for (let i = 0; i < this.state.activeScope.track.length; i++) {
const tensor = this.state.activeScope.track[i];
if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
tensor.dispose();
}
}
const oldScope = this.state.scopeStack.pop();
this.state.activeScope = this.state.scopeStack.length === 0 ?
null :
this.state.scopeStack[this.state.scopeStack.length - 1];
// Track the current result in the parent scope.
tensorsToTrackInParent.forEach(tensor => {
// Only track the tensor if it was allocated in the inner scope and is not
// globally kept.
if (!tensor.kept && tensor.scopeId === oldScope.id) {
this.track(tensor);
}
});
}
/**
 * Returns gradients of `f` with respect to each of the `xs`. The gradients
 * returned are of the same length as `xs`, but some might be null if `f`
 * was not a function of that `x`. It also takes an optional `dy` to multiply
 * the gradient by, which defaults to `1`.
 */
gradients(f, xs, dy, allowNoGradients = false) {
util.assert(xs.length > 0, () => 'gradients() received an empty list of xs.');
if (dy != null && dy.dtype !== 'float32') {
throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
}
const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f));
util.assert(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.');
// Filter out the nodes that don't connect x => y.
const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
'that the f you passed encloses all operations that lead from x ' +
'to y.');
}
return this.tidy('backward', () => {
const accumulatedGradientMap = {};
accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;
// Backprop gradients through the filtered nodes.
backpropagateGradients(accumulatedGradientMap, filteredTape,
// Pass the tidy function to avoid circular dep with `tape.ts`.
f => this.tidy(f),
// Pass an add function to avoid a circular dep with `tape.ts`.
add);
const grads = xs.map(x => accumulatedGradientMap[x.id]);
if (this.state.gradientDepth === 0) {
// This means that we are not computing higher-order gradients
// and can clean up the tape.
this.state.activeTape.forEach(node => {
for (const tensor of node.saved) {
tensor.dispose();
}
});
this.state.activeTape = null;
}
return { value: y, grads };
});
}
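// Usage sketch via the public wrappers, which call into gradients():
//
//   const f = x => x.square();
//   const g = tf.grad(f);            // dy/dx = 2x
//   g(tf.tensor1d([2, 3])).print();  // [4, 6]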
customGrad(f) {
util.assert(util.isFunction(f), () => 'The f passed in customGrad(f) must be a function.');
return (...inputs) => {
util.assert(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
'tensors');
let res;
const inputMap = {};
inputs.forEach((input, i) => {
inputMap[i] = input;
});
const forwardFunc = (_, save) => {
res = f(...[...inputs, save]);
util.assert(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.value` is a tensor');
util.assert(util.isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function.');
return res.value;
};
const backwardsFunc = (dy, saved) => {
const gradRes = res.gradFunc(dy, saved);
const grads = Array.isArray(gradRes) ? gradRes : [gradRes];
util.assert(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function that returns ' +
'the same number of tensors as inputs passed to f(...).');
util.assert(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function that returns ' +
'a list of only tensors.');
const gradMap = {};
grads.forEach((grad, i) => {
gradMap[i] = () => grad;
});
return gradMap;
};
return this.runKernelFunc({
forwardFunc,
backwardsFunc,
inputs: inputMap,
});
};
}
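// Usage sketch: a custom gradient for f(x) = x^2 whose backward pass is
// supplied by the caller via gradFunc:
//
//   const square = tf.customGrad((x, save) => {
//     save([x]);
//     return {
//       value: x.square(),
//       gradFunc: (dy, [xSaved]) => [dy.mul(xSaved.mul(2))],
//     };
//   });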
readSync(dataId) {
// Route the read to the correct backend.
const info = this.state.tensorInfo.get(dataId);
return info.backend.readSync(dataId);
}
read(dataId) {
// Route the read to the correct backend.
const info = this.state.tensorInfo.get(dataId);
return info.backend.read(dataId);
}
readToGPU(dataId, options) {
// Route the read to the correct backend.
const info = this.state.tensorInfo.get(dataId);
return info.backend.readToGPU(dataId, options);
}
async time(query) {
const start = now();
const timingInfo = await this.backend.time(query);
timingInfo.wallMs = now() - start;
return timingInfo;
}
/**
* Tracks a Tensor in the current scope to be automatically cleaned up
* when the current scope ends, and returns the value.
*
* @param result The Tensor to track in the current scope.
*/
track(result) {
if (this.state.activeScope != null) {
result.scopeId = this.state.activeScope.id;
this.state.activeScope.track.push(result);
}
return result;
}
get registeredVariables() {
return this.state.registeredVariables;
}
/**
* Resets the engine state. Removes all backends but does not remove
* registered backend factories.
*/
reset() {
// Make any pending promise obsolete.
this.pendingBackendInitId++;
this.state.dispose();
this.ENV.reset();
this.state = new EngineState();
for (const backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
this.backendName = null;
this.backendInstance = null;
this.pendingBackendInit = null;
}
}
Engine.nextTensorId = 0;
Engine.nextVariableId = 0;
export { Engine };
function ones(shape) {
const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
return ENGINE.makeTensor(values, shape, 'float32');
}
export function getOrMakeEngine() {
const ns = getGlobalNamespace();
if (ns._tfengine == null) {
const environment = new Environment(ns);
ns._tfengine = new Engine(environment);
}
setEnvironmentGlobal(ns._tfengine.ENV);
// Tell the current tensor interface that the global engine is responsible
// for tracking.
setTensorTracker(() => ns._tfengine);
return ns._tfengine;
}
export const ENGINE = getOrMakeEngine();
/**
 * An implementation of the add op for use within the engine and tape.
 *
 * This allows us to avoid a circular dependency between add.ts and engine.ts.
 * It is exported to be available in tape tests.
 */
export function add(a, b) {
// We duplicate Add here to avoid a circular dependency with add.ts.
const inputs = { a, b };
return ENGINE.runKernel(Add, inputs);
}