UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

1,178 lines (1,063 loc) 40.4 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
import {getGradient, getKernel, getKernelsForBackend, GradFunc, NamedAttrMap, TensorInfo} from './kernel_registry';
import {Profiler} from './profiler';
import {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {BackendValues, DataType, DataValues} from './types';
import * as util from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';

/**
 * A function that computes an output. The save function is for saving tensors
 * computed in the forward pass, that we need in the backward pass.
 */
export type ForwardFunc<T> = (backend: KernelBackend, save?: GradSaveFunc) => T;

/**
 * @docalias (a: Tensor, b: Tensor,..., save?: Function) => {
 *   value: Tensor,
 *   gradFunc: (dy: Tensor, saved?: NamedTensorMap) => Tensor | Tensor[]
 * }
 */
export type CustomGradientFunc<T extends Tensor> =
    (...inputs: Array<Tensor|GradSaveFunc>) => {
      value: T;
      gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[];
    };

/** Snapshot of engine-wide memory usage, as returned by `memory()`. */
export type MemoryInfo = {
  numTensors: number;
  numDataBuffers: number;
  numBytes: number;
  unreliable?: boolean;
  reasons: string[];
};

/** Per-kernel record captured while `profile()` is running. */
type KernelProfile = {
  name: string;
  bytesAdded: number;
  totalBytesSnapshot: number;
  tensorsAdded: number;
  totalTensorsSnapshot: number;
  inputShapes: number[][];
  outputShapes: number[][];
};

export type ProfileInfo = {
  newBytes: number;
  newTensors: number;
  peakBytes: number;
  kernels: KernelProfile[];
  result: TensorContainer;
};

export interface TimingInfo extends BackendTimingInfo {
  wallMs: number;
}

/** @docalias Function */
export type ScopeFn<T extends TensorContainer> = () => T;

/** Bookkeeping for one `tidy()` scope on the scope stack. */
interface ScopeState {
  track: Tensor[];
  name: string;
  id: number;
}

/**
 * All mutable engine state, kept in one object so that `Engine.reset()` can
 * discard and re-create it wholesale.
 */
class EngineState {
  // Public since optimizers will use it.
  registeredVariables: NamedVariableMap = {};

  nextTapeNodeId = 0;
  numBytes = 0;
  numTensors = 0;
  numStringTensors = 0;
  numDataBuffers = 0;

  activeTape: TapeNode[];
  // Number of nested tf.grad() statements when computing higher-order
  // gradients. E.g. `1` for first-order gradients and `2` for second-order
  // gradients. Used to track if the tape should be removed after a backprop.
  gradientDepth = 0;
  // Number of nested kernel calls. When kernel depth is greater than 1, we turn
  // off the tape.
  kernelDepth = 0;

  // Keep Tensors that parallel the tapes.
  activeScope: ScopeState;
  scopeStack: ScopeState[] = [];
  /**
   * Keeps track of the number of data moves during a kernel execution. We
   * maintain a stack since kernels can call other kernels, recursively.
   */
  numDataMovesStack: number[] = [];
  nextScopeId = 0;

  // Per-dataId metadata used for memory accounting and backend routing.
  tensorInfo = new WeakMap<DataId, {
    backend: KernelBackend,
    bytes: number,
    dtype: DataType,
    shape: number[],
    refCount: number
  }>();

  profiling = false;
  activeProfile: ProfileInfo =
      {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};

  /** Disposes every registered variable. */
  dispose() {
    for (const variableName in this.registeredVariables) {
      this.registeredVariables[variableName].dispose();
    }
  }
}

/**
 * Central runtime object: selects and initializes backends, executes kernels,
 * records the gradient tape, tracks tensors/variables and their memory, and
 * routes reads/writes of tensor data to the backend that owns it.
 */
export class Engine implements TensorTracker, DataMover {
  state: EngineState;
  backendName: string;
  registry: {[id: string]: KernelBackend} = {};
  registryFactory: {
    [id: string]: {
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority: number
    }
  } = {};

  private profiler: Profiler;
  private backendInstance: KernelBackend;
  private pendingBackendInit: Promise<boolean>;
  // Monotonic id used to invalidate stale async backend-init promises.
  private pendingBackendInitId = 0;

  constructor(public ENV: Environment) {
    this.state = new EngineState();
  }

  /**
   * Resolves once a backend is set: waits on any pending async init,
   * otherwise initializes registered backends in priority order and sets the
   * first one that succeeds. Throws if every registered backend fails.
   */
  async ready(): Promise<void> {
    if (this.pendingBackendInit != null) {
      return this.pendingBackendInit.then(() => {});
    }
    if (this.backendInstance != null) {
      return;
    }
    const sortedBackends = this.getSortedBackends();

    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const success = await this.initializeBackend(backendName).success;
      if (success) {
        await this.setBackend(backendName);
        return;
      }
    }

    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }

  /**
   * The active backend instance. Lazily initializes the best registered
   * backend on first access; throws if initialization is still pending
   * (callers must await tf.ready() / tf.setBackend() first) or if the best
   * backend requires async init.
   */
  get backend(): KernelBackend {
    if (this.pendingBackendInit != null) {
      throw new Error(
          `Backend '${this.backendName}' has not yet been initialized. Make ` +
          `sure to await tf.ready() or await tf.setBackend() before calling ` +
          `other methods`);
    }
    if (this.backendInstance == null) {
      const {name, asyncInit} = this.initializeBackendsAndReturnBest();
      if (asyncInit) {
        throw new Error(
            `The highest priority backend '${name}' has not yet been ` +
            `initialized. Make sure to await tf.ready() or ` +
            `await tf.setBackend() before calling other methods`);
      }
      this.setBackend(name);
    }
    return this.backendInstance;
  }

  /** Names of all registered backend factories. */
  backendNames(): string[] {
    return Object.keys(this.registryFactory);
  }

  /**
   * Returns the initialized backend with the given name, initializing it
   * synchronously if possible. Returns null if the backend is unknown or
   * requires async initialization (i.e. is not ready yet).
   */
  findBackend(backendName: string): KernelBackend {
    if (!(backendName in this.registry)) {
      // If the backend hasn't been initialized but we have a registry entry for
      // it, initialize it and return it.
      if (backendName in this.registryFactory) {
        const {asyncInit} = this.initializeBackend(backendName);
        if (asyncInit) {
          // Backend is not ready yet.
          return null;
        }
      } else {
        return null;
      }
    }
    return this.registry[backendName];
  }

  /** Returns the registered factory for a backend name, or null if unknown. */
  findBackendFactory(backendName: string):
      () => KernelBackend | Promise<KernelBackend> {
    if (!(backendName in this.registryFactory)) {
      return null;
    }
    return this.registryFactory[backendName].factory;
  }

  /**
   * Registers a backend factory under a name. Returns false (and keeps the
   * existing factory) if the name is already registered.
   */
  registerBackend(
      backendName: string,
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority = 1): boolean {
    if (backendName in this.registryFactory) {
      console.warn(
          `${backendName} backend was already registered. ` +
          `Reusing existing backend factory.`);
      return false;
    }
    this.registryFactory[backendName] = {factory, priority};
    return true;
  }

  /**
   * Makes the named backend active, initializing it first if needed.
   * Resolves to false if initialization fails; throws if the name was never
   * registered.
   */
  async setBackend(backendName: string): Promise<boolean> {
    if (this.registryFactory[backendName] == null) {
      throw new Error(`Backend name '${backendName}' not found in registry`);
    }
    this.backendName = backendName;
    if (this.registry[backendName] == null) {
      this.backendInstance = null;
      const {success, asyncInit} = this.initializeBackend(backendName);
      const result = asyncInit ? await success : success;
      if (!result) {
        return false;
      }
    }
    this.backendInstance = this.registry[backendName];
    this.setupRegisteredKernels();
    // Reset the profiler.
    this.profiler = new Profiler(this.backendInstance);

    return true;
  }

  /** Runs each modular kernel's setupFunc against the new active backend. */
  private setupRegisteredKernels(): void {
    const kernels = getKernelsForBackend(this.backendName);
    kernels.forEach(kernel => {
      if (kernel.setupFunc != null) {
        kernel.setupFunc(this.backendInstance);
      }
    });
  }

  /** Runs each modular kernel's disposeFunc for the given backend. */
  private disposeRegisteredKernels(backendName: string): void {
    const kernels = getKernelsForBackend(backendName);
    kernels.forEach(kernel => {
      if (kernel.disposeFunc != null) {
        kernel.disposeFunc(this.registry[backendName]);
      }
    });
  }

  /**
   * Initializes a backend by looking up the backend name in the factory
   * registry and calling the factory method. Returns a boolean representing
   * whether the initialization of the backend succeeded. Throws an error if
   * there is no backend in the factory registry.
   */
  private initializeBackend(backendName: string):
      {success: boolean|Promise<boolean>, asyncInit: boolean} {
    const registryFactoryEntry = this.registryFactory[backendName];
    if (registryFactoryEntry == null) {
      throw new Error(
          `Cannot initialize backend ${backendName}, no registration found.`);
    }

    try {
      const backend = registryFactoryEntry.factory();
      // Test if the factory returns a promise.
      if (Promise.resolve(backend) === backend) {
        const promiseId = ++this.pendingBackendInitId;
        const success =
            backend
                .then(backendInstance => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.registry[backendName] = backendInstance;
                  this.pendingBackendInit = null;
                  return true;
                })
                .catch(err => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.pendingBackendInit = null;
                  console.warn(
                      `Initialization of backend ${backendName} failed`);
                  console.warn(err.stack || err.message);
                  return false;
                });
        this.pendingBackendInit = success;
        return {success, asyncInit: true};
      } else {
        this.registry[backendName] = backend as KernelBackend;
        return {success: true, asyncInit: false};
      }
    } catch (err) {
      console.warn(`Initialization of backend ${backendName} failed`);
      console.warn(err.stack || err.message);
      return {success: false, asyncInit: false};
    }
  }

  /**
   * Removes a backend and its factory from the registries, disposing its
   * kernels and instance if initialized, and unsets it if it was active.
   */
  removeBackend(backendName: string): void {
    if (!(backendName in this.registryFactory)) {
      throw new Error(`${backendName} backend not found in registry`);
    }
    if (this.backendName === backendName && this.pendingBackendInit != null) {
      // There is a pending promise of the backend we want to remove. Make it
      // obsolete.
      this.pendingBackendInitId++;
    }

    if (backendName in this.registry) {
      this.disposeRegisteredKernels(backendName);
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }

    delete this.registryFactory[backendName];

    // Unset the backend if it is active.
    if (this.backendName === backendName) {
      this.pendingBackendInit = null;
      this.backendName = null;
      this.backendInstance = null;
    }
  }

  /** Registered backend names sorted by descending priority. */
  private getSortedBackends(): string[] {
    if (Object.keys(this.registryFactory).length === 0) {
      throw new Error('No backend found in registry.');
    }
    return Object.keys(this.registryFactory).sort((a: string, b: string) => {
      // Highest priority comes first.
      return this.registryFactory[b].priority -
          this.registryFactory[a].priority;
    });
  }

  /**
   * Walks backends in priority order and returns the first whose sync init
   * succeeded or whose init is async (still pending). Throws if all fail.
   */
  private initializeBackendsAndReturnBest():
      {name: string, asyncInit: boolean} {
    const sortedBackends = this.getSortedBackends();

    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const {success, asyncInit} = this.initializeBackend(backendName);
      if (asyncInit || success) {
        return {name: backendName, asyncInit};
      }
    }
    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }

  /**
   * Moves one dataId's values from its current backend to `destBackend`,
   * reading synchronously and updating the routing info.
   */
  moveData(destBackend: KernelBackend, dataId: DataId) {
    const info = this.state.tensorInfo.get(dataId);
    const srcBackend = info.backend;
    const values = this.readSync(dataId);
    // Delete the tensor from the old backend and move it to the new
    // backend.
    srcBackend.disposeData(dataId);
    info.backend = destBackend;
    destBackend.move(dataId, values, info.shape, info.dtype);
    if (this.shouldCheckForMemLeaks()) {
      // Track the number of moves during a kernel execution to correctly
      // detect memory leaks.
      this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
    }
  }

  /**
   * Runs `fn` inside a scope; tensors allocated in the scope are disposed at
   * scope end unless kept or part of the returned container. Accepts an
   * optional leading name argument.
   */
  tidy<T extends TensorContainer>(nameOrFn: string|ScopeFn<T>, fn?: ScopeFn<T>):
      T {
    let name: string = null;
    if (fn == null) {
      // Called with only 1 argument.
      if (typeof nameOrFn !== 'function') {
        throw new Error('Please provide a function to tidy()');
      }
      fn = nameOrFn;
    } else {
      // Called with 2 arguments.
      if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
        throw new Error(
            'When calling with two arguments, the first argument ' +
            'to tidy() must be a string');
      }
      if (typeof fn !== 'function') {
        throw new Error(
            'When calling with two arguments, the 2nd argument ' +
            'to tidy() must be a function');
      }
      name = nameOrFn as string;
      // TODO(nsthorat,smilkov): Do operation logging and performance
      // profiling.
    }
    let result: T;
    return this.scopedRun(
        () => this.startScope(name), () => this.endScope(result), () => {
          result = fn();
          if (result instanceof Promise) {
            console.error('Cannot return a Promise inside of tidy.');
          }
          return result;
        });
  }

  /** Runs f() between start() and end(), guaranteeing end() runs on throw. */
  private scopedRun<T>(start: () => void, end: () => void, f: () => T): T {
    start();
    try {
      const res = f();
      end();
      return res;
    } catch (ex) {
      end();
      throw ex;
    }
  }

  // Tensor ids are global across engine instances (static counter).
  private static nextTensorId = 0;
  private nextTensorId(): number {
    return Engine.nextTensorId++;
  }

  private static nextVariableId = 0;
  private nextVariableId(): number {
    return Engine.nextVariableId++;
  }

  /**
   * This method is called instead of the public-facing tensor.clone() when
   * saving a tensor for backwards pass. It makes sure to add the clone
   * operation to the tape regardless of being called inside a kernel
   * execution.
   *
   * This method will go away once all kernels are modularized since we won't
   * need to turn off the tape inside runKernel().
   */
  private clone(x: Tensor): Tensor {
    const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
    const inputs = {x};
    const grad = (dy: Tensor) => ({x: () => dy.toFloat()});
    const saved: Tensor[] = [];
    this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
    return y;
  }

  /**
   * Execute a kernel with the given name and return the output tensor.
   *
   * @param kernelName The name of the kernel to execute.
   * @param inputs A map of input names to tensors.
   * @param attrs A map of attribute names to their values. An attribute is a
   *     primitive (non-tensor) input to the kernel.
   * @param inputsToSave A list of tensors, inputs to save for the backprop
   *     computation.
   * @param outputsToSave A list of booleans, specifying which output to save
   *     for the backprop computation. These are booleans since the output
   *     tensors are not visible to the user.
   */
  runKernel(
      kernelName: string, inputs: NamedTensorMap, attrs: NamedAttrMap,
      inputsToSave?: Tensor[], outputsToSave?: boolean[]): Tensor|Tensor[] {
    const forwardFunc: null = null;
    const backwardsFunc: null = null;
    // Call runKernel as a stop-gap until we modularize all kernels.
    // Once we modularize all kernels, we will remove the existing
    // `runKernelFunc`.
    return this.runKernelFunc(
        forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave,
        outputsToSave);
  }

  /** Leak checking is only enabled in tests (IS_TEST flag). */
  private shouldCheckForMemLeaks(): boolean {
    return this.ENV.getBool('IS_TEST');
  }

  /**
   * Compares backend data-id counts before/after a kernel run, net of the
   * kernel's outputs and any data moves, and throws if ids leaked.
   */
  private checkKernelForMemLeak(
      kernelName: string, numDataIdsBefore: number,
      outInfos: TensorInfo[]): void {
    const numDataIdsAfter = this.backend.numDataIds();

    // Count the number of data ids associated with the result of the kernel.
    let numOutputDataIds = 0;
    outInfos.forEach(info => {
      // Complex numbers allocate 3 data ids, one for 'real', one for
      // 'imaginary', and one for the container that holds the former two.
      numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
    });

    // Account for the number of moves during kernel execution. A "data move"
    // can happen in the middle of a kernel execution, placing a new (key,value)
    // pair in the data storage. Since data moves have net zero effect (we
    // always remove the data from the old backend), we have to cancel them out
    // when detecting memory leaks.
    const numMoves =
        this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
    const dataIdsLeaked =
        numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
    if (dataIdsLeaked > 0) {
      throw new Error(
          `Backend '${this.backendName}' has an internal memory leak ` +
          `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
    }
  }

  /**
   * @deprecated Use `runKernel` for newly added kernels. Keep using this method
   * only for kernels that are not yet fully modularized.
   */
  runKernelFunc<T extends Tensor|Tensor[], I extends NamedTensorMap>(
      forwardFunc: ForwardFunc<T>, inputs: I,
      backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
      kernelName?: string, attrs?: NamedAttrMap, inputsToSave?: Tensor[],
      outputsToSave?: boolean[]): T {
    let outputs: Tensor[];
    let saved: Tensor[] = [];
    const isTapeOn = this.isTapeOn();
    if (kernelName == null) {
      kernelName =
          this.state.activeScope != null ? this.state.activeScope.name : '';
    }

    const startingBytecount = this.state.numBytes;
    const startingNumTensors = this.state.numTensors;

    if (this.shouldCheckForMemLeaks()) {
      this.state.numDataMovesStack.push(0);
    }

    let kernelFunc: () => Tensor[];
    const kernel = getKernel(kernelName, this.backendName);
    let out: TensorInfo|TensorInfo[];
    if (kernel != null) {
      // Modular path: a registered kernel exists for this backend.
      kernelFunc = () => {
        const numDataIdsBefore = this.backend.numDataIds();
        out = kernel.kernelFunc({inputs, attrs, backend: this.backend});
        const outInfos = Array.isArray(out) ? out : [out];
        if (this.shouldCheckForMemLeaks()) {
          this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
        }
        const outTensors = outInfos.map(
            ({dataId, shape, dtype}) =>
                this.makeTensorFromDataId(dataId, shape, dtype));

        // Save the inputs and outputs.
        // Do not save unless we are recording to the tape. Otherwise it would
        // cause a mem leak since we would never run backprop, which disposes
        // the kept tensors.
        if (isTapeOn) {
          let tensorsToSave =
              this.getTensorsForGradient(kernelName, inputs, outTensors);
          if (tensorsToSave == null) {
            // Fallback for ops that call runKernelFunc and pass in
            // inputsToSave and outputsToSave. Currently this is the set of ops
            // with kernel support in the WASM backend. Once those ops and
            // respective gradients are modularised we can remove this path.
            if (outputsToSave == null) {
              outputsToSave = [];
            }
            const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
            tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
          }
          saved = this.saveTensorsForBackwardMode(tensorsToSave);
        }
        return outTensors;
      };
    } else {
      // Legacy path: run the provided forwardFunc inside a tidy scope.
      const saveFunc: GradSaveFunc = (tensors) => {
        // Do not save unless we are recording to the tape. Otherwise it would
        // cause a mem leak since we would never run backprop, which disposes
        // the kept tensors.
        if (!isTapeOn) {
          return;
        }
        saved = tensors.map(tensor => this.keep(this.clone(tensor)));
      };

      kernelFunc = () => {
        const numDataIdsBefore = this.backend.numDataIds();
        out = this.tidy(() => forwardFunc(this.backend, saveFunc));
        const outs = (Array.isArray(out) ? out : [out]) as Tensor[];
        if (this.shouldCheckForMemLeaks()) {
          this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
        }
        return outs;
      };
    }

    // Stop recording to a tape when running a kernel.
    this.scopedRun(
        () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
          if (!this.ENV.getBool('DEBUG')) {
            outputs = kernelFunc();
          } else {
            outputs = this.profiler.profileKernel(
                kernelName, inputs, () => kernelFunc());
          }
        });

    if (isTapeOn) {
      this.addTapeNode(
          kernelName, inputs, outputs, backwardsFunc, saved, attrs);
    }

    if (this.state.profiling) {
      this.state.activeProfile.kernels.push({
        name: kernelName,
        bytesAdded: this.state.numBytes - startingBytecount,
        totalBytesSnapshot: this.state.numBytes,
        tensorsAdded: this.state.numTensors - startingNumTensors,
        totalTensorsSnapshot: this.state.numTensors,
        inputShapes: Object.keys(inputs).map(key => inputs[key].shape),
        outputShapes: outputs.map(item => item.shape)
      });
    }

    return (Array.isArray(out) ? outputs : outputs[0]) as T;
  }

  /**
   * Saves tensors used in forward mode for use in backward mode.
   *
   * @param tensors the list of tensors to save.
   */
  private saveTensorsForBackwardMode(tensors: Tensor[]): Tensor[] {
    const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
    return saved;
  }

  /**
   * Returns a list of tensors to save for a given gradient calculation.
   *
   * Returns undefined if their is no registered gradient for this kernel in the
   * gradient registry.
   *
   * @param kernelName name of kernel to look up gradient for.
   * @param inputs a map of input tensors.
   * @param outputs an array of output tensors from forward mode of kernel.
   */
  private getTensorsForGradient(
      kernelName: string, inputs: NamedTensorMap,
      outputs: Tensor[]): Tensor[]|null {
    const gradConfig = getGradient(kernelName);
    if (gradConfig != null) {
      const inputsToSave: string[] = gradConfig.inputsToSave || [];
      const outputsToSave: boolean[] = gradConfig.outputsToSave || [];

      // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
      // specified in inputsToSave will be saved.
      let inputTensorsToSave: Tensor[];
      if (gradConfig.saveAllInputs) {
        util.assert(
            Array.isArray(inputs),
            () => 'saveAllInputs is true, expected inputs to be an array.');

        inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
      } else {
        inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
      }

      const outputTensorsToSave: Tensor[] =
          outputs.filter((_, i) => outputsToSave[i]);

      return inputTensorsToSave.concat(outputTensorsToSave);
    }
    // TODO(yassogba) throw exception here once all runkernelFunc calls with
    // inputsToSave/outputsToSave are removed
    return null;
  }

  /**
   * Internal method used by public APIs for tensor creation. Makes a new
   * tensor with the provided shape, dtype and values. It always
   * creates a new data id and writes the values to the underlying backend.
   */
  makeTensor(
      values: DataValues, shape: number[], dtype: DataType,
      backend?: KernelBackend): Tensor {
    if (values == null) {
      throw new Error('Values passed to engine.makeTensor() are null');
    }
    dtype = dtype || 'float32';
    backend = backend || this.backend;
    let backendVals = values as BackendValues;
    if (dtype === 'string' && util.isString(values[0])) {
      backendVals = (values as string[]).map(d => util.encodeString(d));
    }
    const dataId = backend.write(backendVals, shape, dtype);
    const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
    this.incRef(t, backend);

    // Count bytes for string tensors.
    if (dtype === 'string') {
      const info = this.state.tensorInfo.get(dataId);
      const newBytes = bytesFromStringArray(backendVals as Uint8Array[]);
      this.state.numBytes += newBytes - info.bytes;
      info.bytes = newBytes;
    }
    return t;
  }

  /**
   * Internal method used by backends. Makes a new tensor
   * that is a wrapper around an existing data id. It doesn't create
   * a new data id, only increments the ref count used in memory tracking.
   */
  makeTensorFromDataId(
      dataId: DataId, shape: number[], dtype: DataType,
      backend?: KernelBackend): Tensor {
    dtype = dtype || 'float32';
    const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
    this.incRef(t, backend);
    return t;
  }

  /**
   * Creates and registers a Variable wrapping `initialValue` (cast to `dtype`
   * if given). Throws if a variable with the same name already exists.
   */
  makeVariable(
      initialValue: Tensor, trainable = true, name?: string,
      dtype?: DataType): Variable {
    name = name || this.nextVariableId().toString();
    if (dtype != null && dtype !== initialValue.dtype) {
      initialValue = initialValue.asType(dtype);
    }
    const v = new Variable(initialValue, trainable, name, this.nextTensorId());
    if (this.state.registeredVariables[v.name] != null) {
      throw new Error(`Variable with name ${v.name} was already registered`);
    }
    this.state.registeredVariables[v.name] = v;
    this.incRef(v, this.backend);
    return v;
  }

  /**
   * Increments the ref count for a tensor's dataId, creating its tracking
   * entry (and counting its bytes) on first reference, and tracks the tensor
   * in the active scope unless it is a Variable.
   */
  incRef(a: Tensor, backend: KernelBackend): void {
    const refCount = this.state.tensorInfo.has(a.dataId) ?
        this.state.tensorInfo.get(a.dataId).refCount :
        0;
    this.state.numTensors++;
    if (a.dtype === 'string') {
      this.state.numStringTensors++;
    }
    if (refCount === 0) {
      this.state.numDataBuffers++;

      // Bytes for complex numbers are counted by their components. Bytes for
      // string tensors are counted when writing values.
      let bytes = 0;
      if (a.dtype !== 'complex64' && a.dtype !== 'string') {
        bytes = a.size * util.bytesPerElement(a.dtype);
      }
      this.state.tensorInfo.set(a.dataId, {
        backend: backend || this.backend,
        dtype: a.dtype,
        shape: a.shape,
        bytes,
        refCount: 0
      });
      this.state.numBytes += bytes;
    }
    this.state.tensorInfo.get(a.dataId).refCount++;
    if (!(a instanceof Variable)) {
      this.track(a);
    }
  }

  /**
   * Decrements the ref count for a tensor's dataId; when the last reference
   * is dropped, frees the backend data and the tracking entry.
   */
  disposeTensor(a: Tensor): void {
    if (!this.state.tensorInfo.has(a.dataId)) {
      return;
    }

    this.state.numTensors--;
    if (a.dtype === 'string') {
      this.state.numStringTensors--;
    }
    const info = this.state.tensorInfo.get(a.dataId);
    const refCount = info.refCount;
    if (refCount <= 1) {
      // Don't count bytes for complex numbers as they are counted by their
      // components.
      if (a.dtype !== 'complex64') {
        this.state.numBytes -= info.bytes;
      }
      this.state.numDataBuffers--;
      info.backend.disposeData(a.dataId);
      this.state.tensorInfo.delete(a.dataId);
    } else {
      this.state.tensorInfo.get(a.dataId).refCount--;
    }
    // TODO(nsthorat): Construct an error and save the stack trace for
    // debugging when in debug mode. Creating a stack trace is too expensive
    // to do unconditionally.
  }

  /** Disposes and unregisters every registered variable. */
  disposeVariables(): void {
    for (const varName in this.state.registeredVariables) {
      const v = this.state.registeredVariables[varName];
      this.disposeVariable(v);
    }
  }

  /** Disposes a single variable and removes it from the registry. */
  disposeVariable(v: Variable): void {
    this.disposeTensor(v);
    if (this.state.registeredVariables[v.name] != null) {
      delete this.state.registeredVariables[v.name];
    }
  }

  /**
   * Returns memory info from the backend overlaid with the engine's own
   * tensor/buffer/byte counts. Marked unreliable when string tensors exist,
   * since their byte counts are approximate.
   */
  memory(): MemoryInfo {
    const info = this.backend.memory() as MemoryInfo;
    info.numTensors = this.state.numTensors;
    info.numDataBuffers = this.state.numDataBuffers;
    info.numBytes = this.state.numBytes;
    if (this.state.numStringTensors > 0) {
      info.unreliable = true;
      if (info.reasons == null) {
        info.reasons = [];
      }
      info.reasons.push(
          'Memory usage by string tensors is approximate ' +
          '(2 bytes per character)');
    }
    return info;
  }

  /**
   * Runs `query` with per-kernel profiling enabled and returns the collected
   * kernel records plus aggregate byte/tensor deltas.
   */
  async profile(query: () => TensorContainer): Promise<ProfileInfo> {
    this.state.profiling = true;

    const startBytes = this.state.numBytes;
    const startNumTensors = this.state.numTensors;

    this.state.activeProfile.kernels = [];
    this.state.activeProfile.result = query();

    this.state.profiling = false;

    this.state.activeProfile.peakBytes = Math.max(
        ...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
    this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
    this.state.activeProfile.newTensors =
        this.state.numTensors - startNumTensors;
    return this.state.activeProfile;
  }

  /** The tape records only inside a gradient computation, outside kernels. */
  isTapeOn(): boolean {
    return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
  }

  /**
   * Appends a node to the active tape. A gradient function registered in the
   * gradient registry for `kernelName` takes precedence over the passed-in
   * `gradientsFunc`.
   */
  private addTapeNode(
      kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
      gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void {
    const tapeNode: TapeNode =
        {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

    const gradConfig = getGradient(kernelName);
    if (gradConfig != null) {
      gradientsFunc = gradConfig.gradFunc;
    }
    if (gradientsFunc != null) {
      tapeNode.gradient = (dys: Tensor[]) => {
        // TODO(smilkov): To optimize back-prop, pass dys that are not used in
        // the backprop graph to the user as null instead of zeros
        dys = dys.map((dy, i) => {
          if (dy == null) {
            const output = outputs[i];
            const vals = util.makeZerosTypedArray(output.size, output.dtype);
            return this.makeTensor(vals, output.shape, output.dtype);
          }
          return dy;
        });
        // Grad functions of ops with single outputs expect a dy, while ops
        // with multiple outputs expect dys (array of dy).
        return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
      };
    }
    this.state.activeTape.push(tapeNode);
  }

  /** Marks a tensor as kept so scope cleanup will not dispose it. */
  keep<T extends Tensor>(result: T): T {
    result.kept = true;
    return result;
  }

  private startTape() {
    if (this.state.gradientDepth === 0) {
      this.state.activeTape = [];
    }
    this.state.gradientDepth++;
  }

  private endTape() {
    this.state.gradientDepth--;
  }

  /**
   * Start a scope. Use this with endScope() to achieve the same functionality
   * as scope() without the need for a function closure.
   */
  startScope(name?: string) {
    const scopeInfo: ScopeState = {
      track: [],
      name: 'unnamed scope',
      id: this.state.nextScopeId++
    };
    if (name) {
      scopeInfo.name = name;
    }
    this.state.scopeStack.push(scopeInfo);
    this.state.activeScope = scopeInfo;
  }

  /**
   * End a scope. Use this with startScope() to achieve the same functionality
   * as scope() without the need for a function closure.
   */
  endScope(result?: TensorContainer) {
    const tensorsToTrackInParent = getTensorsInContainer(result);
    const tensorsToTrackInParentSet =
        new Set(tensorsToTrackInParent.map(t => t.id));

    // Dispose the arrays tracked in this scope.
    for (let i = 0; i < this.state.activeScope.track.length; i++) {
      const tensor = this.state.activeScope.track[i];
      if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
        tensor.dispose();
      }
    }

    const oldScope = this.state.scopeStack.pop();
    this.state.activeScope = this.state.scopeStack.length === 0 ?
        null :
        this.state.scopeStack[this.state.scopeStack.length - 1];

    // Track the current result in the parent scope.
    tensorsToTrackInParent.forEach(tensor => {
      // Only track the tensor if was allocated in the inner scope and is not
      // globally kept.
      if (!tensor.kept && tensor.scopeId === oldScope.id) {
        this.track(tensor);
      }
    });
  }

  /**
   * Returns gradients of `f` with respect to each of the `xs`. The gradients
   * returned are of the same length as `xs`, but some might be null if `f`
   * was not a function of that `x`. It also takes optional dy to multiply the
   * gradient, which defaults to `1`.
   */
  gradients<T extends Tensor>(
      f: () => T, xs: Tensor[], dy?: T,
      allowNoGradients = false): {value: T, grads: Tensor[]} {
    util.assert(
        xs.length > 0, () => 'gradients() received an empty list of xs.');
    if (dy != null && dy.dtype !== 'float32') {
      throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
    }

    const y = this.scopedRun(
        () => this.startTape(), () => this.endTape(),
        () => this.tidy('forward', f));

    util.assert(
        y instanceof Tensor,
        () => 'The result y returned by f() must be a tensor.');
    // Filter out the nodes that don't connect x => y.
    const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
    if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
      throw new Error(
          'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
          'that the f you passed encloses all operations that lead from x ' +
          'to y.');
    }

    return this.tidy('backward', () => {
      const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};
      accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;

      // Backprop gradients through the filtered nodes.
      backpropagateGradients(
          accumulatedGradientMap, filteredTape,
          // Pass the tidy function to avoid circular dep with `tape.ts`.
          f => this.tidy(f as ScopeFn<Tensor>));
      const grads = xs.map(x => accumulatedGradientMap[x.id]);

      if (this.state.gradientDepth === 0) {
        // This means that we are not computing higher-order gradients
        // and can clean up the tape.
        this.state.activeTape.forEach(node => {
          for (const tensor of node.saved) {
            tensor.dispose();
          }
        });
        this.state.activeTape = null;
      }
      return {value: y, grads};
    });
  }

  /**
   * Wraps `f` so that its custom gradFunc (rather than the op-by-op tape) is
   * used during backprop. Validates f's inputs, return value and gradients.
   */
  customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
      (...args: Array<Tensor|GradSaveFunc>) => T {
    util.assert(
        util.isFunction(f),
        () => 'The f passed in customGrad(f) must be a function.');
    return (...inputs: Tensor[]): T => {
      util.assert(
          inputs.every(t => t instanceof Tensor),
          () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
              'tensors');

      let res: {
        value: T,
        gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[],
      };
      const inputMap: NamedTensorMap = {};
      inputs.forEach((input, i) => {
        inputMap[i] = input;
      });
      return this.runKernelFunc(
          (_, save) => {
            res = f(...[...inputs, save]);
            util.assert(
                res.value instanceof Tensor,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.value` is a tensor');
            util.assert(
                util.isFunction(res.gradFunc),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function.');
            return res.value;
          },
          inputMap,
          (dy: T, saved: Tensor[]) => {
            const gradRes = res.gradFunc(dy, saved);
            const grads: Tensor[] = Array.isArray(gradRes) ? gradRes : [gradRes];
            util.assert(
                grads.length === inputs.length,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'the same number of tensors as inputs passed to f(...).');
            util.assert(
                grads.every(t => t instanceof Tensor),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'a list of only tensors.');
            const gradMap: {[key: string]: () => Tensor} = {};
            grads.forEach((grad, i) => {
              gradMap[i] = () => grad;
            });
            return gradMap;
          });
    };
  }

  readSync(dataId: DataId): BackendValues {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.readSync(dataId);
  }
  read(dataId: DataId): Promise<BackendValues> {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.read(dataId);
  }

  /** Times `query` on the backend and adds wall-clock milliseconds. */
  async time(query: () => void): Promise<TimingInfo> {
    const start = now();
    const timingInfo = await this.backend.time(query) as TimingInfo;
    timingInfo.wallMs = now() - start;
    return timingInfo;
  }

  /**
   * Tracks a Tensor in the current scope to be automatically cleaned up
   * when the current scope ends, and returns the value.
   *
   * @param result The Tensor to track in the current scope.
   */
  private track<T extends Tensor>(result: T): T {
    if (this.state.activeScope != null) {
      result.scopeId = this.state.activeScope.id;
      this.state.activeScope.track.push(result);
    }
    return result;
  }

  get registeredVariables(): NamedVariableMap {
    return this.state.registeredVariables;
  }

  /**
   * Resets the engine state. Removes all backends but does not remove
   * registered backend factories.
   */
  reset(): void {
    // Make any pending promise obsolete.
    this.pendingBackendInitId++;

    this.state.dispose();
    this.ENV.reset();
    this.state = new EngineState();

    for (const backendName in this.registry) {
      this.disposeRegisteredKernels(backendName);
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }
    this.backendName = null;
    this.backendInstance = null;
    this.pendingBackendInit = null;
  }
}

/** Makes a float32 tensor of ones with the given shape (used as default dy). */
function ones(shape: number[]): Tensor {
  const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
  return ENGINE.makeTensor(values, shape, 'float32');
}

let GLOBAL: {_tfengine: Engine};
/**
 * Finds the host global object (window, global, process or self) on which the
 * singleton engine is stashed, so one engine is shared per JS realm.
 */
function getGlobalNamespace(): {_tfengine: Engine} {
  if (GLOBAL == null) {
    // tslint:disable-next-line:no-any
    let ns: any;
    if (typeof (window) !== 'undefined') {
      ns = window;
    } else if (typeof (global) !== 'undefined') {
      ns = global;
    } else if (typeof (process) !== 'undefined') {
      ns = process;
    } else if (typeof (self) !== 'undefined') {
      ns = self;
    } else {
      throw new Error('Could not find a global object');
    }
    GLOBAL = ns;
  }
  return GLOBAL;
}

/** Returns (creating on first call) the per-realm singleton engine. */
function getOrMakeEngine(): Engine {
  const ns = getGlobalNamespace();
  if (ns._tfengine == null) {
    const environment = new Environment(ns);
    ns._tfengine = new Engine(environment);
  }
  setEnvironmentGlobal(ns._tfengine.ENV);

  // Tell the current tensor interface that the global engine is responsible
  // for tracking.
  setTensorTracker(() => ns._tfengine);
  return ns._tfengine;
}

export const ENGINE = getOrMakeEngine();