@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
76 lines • 9.24 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as broadcast_util from './broadcast_util';
import { elu } from './elu';
import { leakyRelu } from './leaky_relu';
import { mul } from './mul';
import { prelu } from './prelu';
import { relu } from './relu';
import { relu6 } from './relu6';
import { reshape } from './reshape';
import { sigmoid } from './sigmoid';
import { step } from './step';
import { sum } from './sum';
/**
 * Returns the gradient for a fused activation.
 *
 * @param dy Incoming gradient; returned untouched when no activation
 *     gradient needs to be applied.
 * @param y Value handed to `step` when the activation is 'relu'.
 * @param activation Fused activation name. `null`/`undefined` behave the
 *     same as 'linear'.
 * @returns `dy` for a linear (or missing) activation, otherwise
 *     `mul(dy, step(y))` for 'relu'.
 * @throws {Error} When the activation is anything other than linear/relu.
 */
export function getFusedDyActivation(dy, y, activation) {
    const isIdentity = activation == null || activation === 'linear';
    if (isIdentity) {
        return dy;
    }
    // Only 'relu' has a gradient implemented here; reject everything else.
    if (activation !== 'relu') {
        throw new Error(`Cannot compute gradient for fused activation ${activation}.`);
    }
    return mul(dy, step(y));
}
/**
 * Returns the gradient for a fused bias.
 *
 * Asks `broadcast_util.getReductionAxes` which axes to reduce given
 * `bias.shape` and `dyActivation.shape`, sums over them when there are any,
 * and finally reshapes the result to `bias.shape`.
 *
 * @param bias Object whose `shape` is the target shape of the result.
 * @param dyActivation Gradient to be reduced; its `shape` is read too.
 * @returns The result of `reshape(..., bias.shape)`.
 */
export function getFusedBiasGradient(bias, dyActivation) {
    const axesToReduce = broadcast_util.getReductionAxes(bias.shape, dyActivation.shape);
    // Skip the sum entirely when no axes were reported.
    const reduced = axesToReduce.length > 0 ?
        sum(dyActivation, axesToReduce) :
        dyActivation;
    return reshape(reduced, bias.shape);
}
/**
 * Applies the named fused activation to `x`.
 *
 * @param x Input forwarded to the selected activation op.
 * @param activation One of 'linear', 'relu', 'elu', 'relu6', 'prelu',
 *     'leakyrelu' or 'sigmoid'.
 * @param preluActivationWeights Second argument for `prelu`; only used when
 *     `activation` is 'prelu'.
 * @param leakyreluAlpha Second argument for `leakyRelu`; only used when
 *     `activation` is 'leakyrelu'.
 * @returns `x` itself for 'linear', otherwise the result of the matching op.
 * @throws {Error} When `activation` is not one of the names listed above.
 */
export function applyActivation(x, activation, preluActivationWeights, leakyreluAlpha) {
    switch (activation) {
        case 'linear':
            return x;
        case 'relu':
            return relu(x);
        case 'elu':
            return elu(x);
        case 'relu6':
            return relu6(x);
        case 'prelu':
            return prelu(x, preluActivationWeights);
        case 'leakyrelu':
            return leakyRelu(x, leakyreluAlpha);
        case 'sigmoid':
            return sigmoid(x);
        default:
            throw new Error(`Unknown fused activation ${activation}.`);
    }
}
/**
 * Whether we should call fused ops.
 *
 * Fusing is declined only when `gradientDepth` is greater than zero and the
 * activation is something other than 'linear'.
 *
 * @param gradientDepth Compared against 0 to decide whether gradient mode
 *     is active.
 * @param activation Fused activation name.
 * @returns `true` when the fused op should be used, `false` otherwise.
 */
export const shouldFuse = (gradientDepth, activation) =>
    !(gradientDepth > 0 && activation !== 'linear');
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZnVzZWRfdXRpbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL2Z1c2VkX3V0aWwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBSUgsT0FBTyxLQUFLLGNBQWMsTUFBTSxrQkFBa0IsQ0FBQztBQUNuRCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBRTFCLE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSxjQUFjLENBQUM7QUFDdkMsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLE9BQU8sQ0FBQztBQUMxQixPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFDNUIsT0FBTyxFQUFDLEtBQUssRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUM5QixPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2xDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLFFBQVEsQ0FBQztBQUM1QixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBRTFCLHlDQUF5QztBQUN6QyxNQUFNLFVBQVUsb0JBQW9CLENBQ2hDLEVBQVUsRUFBRSxDQUFTLEVBQUUsVUFBc0I7SUFDL0MsSUFBSSxVQUFVLElBQUksSUFBSSxJQUFJLFVBQVUsS0FBSyxRQUFRLEVBQUU7UUFDakQsT0FBTyxFQUFFLENBQUM7S0FDWDtJQUNELElBQUksVUFBVSxLQUFLLE1BQU0sRUFBRTtRQUN6QixPQUFPLEdBQUcsQ0FBQyxFQUFFLEVBQUUsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7S0FDekI7SUFDRCxNQUFNLElBQUksS0FBSyxDQUNYLGdEQUFnRCxVQUFVLEdBQUcsQ0FBQyxDQUFDO0FBQ3JFLENBQUM7QUFFRCxtQ0FBbUM7QUFDbkMsTUFBTSxVQUFVLG9CQUFvQixDQUNoQyxJQUFZLEVBQUUsWUFBb0I7SUFDcEMsSUFBSSxHQUFHLEdBQUcsWUFBWSxDQUFDO0lBQ3ZCLE1BQU0sVUFBVSxHQUNaLGNBQWMsQ0FBQyxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsS0FBSyxFQUFFLFlBQVksQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUNwRSxJQUFJLFVBQVUsQ0FBQyxNQUFNLEdBQUcsQ0FBQyxFQUFFO1FBQ3pCLEdBQUcsR0FBRyxHQUFHLENBQUMsR0FBRyxFQUFFLFVBQVUsQ0FBQyxDQUFDO0tBQzVCO0lBQ0QsT0FBTyxPQUFPLENBQUMsR0FBRyxFQUFFLElBQUksQ0FBQyxLQUFLLENBQUMsQ0FBQztBQUNsQyxDQUFDO0FBRUQsTUFBTSxVQUFVLGVBQWUsQ0FDM0IsQ0FBUyxFQUFFLFVBQXNCLEVBQUUsc0JBQStCLEVBQ2xFLGNBQXVCO0lBQ3pCLElBQUksVUFBVSxLQUFLLFFBQVEsRUFBRTtRQUMzQixPQUFPLENBQUMsQ0FBQztLQUNWO1NBQU0sSUFBSSxVQUFVLEtBQUssTUFBTSxFQUFFO1FBQ2hDLE9BQU8sSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ2hCO1NBQU0sSUFBSSxVQU
FVLEtBQUssS0FBSyxFQUFFO1FBQy9CLE9BQU8sR0FBRyxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ2Y7U0FBTSxJQUFJLFVBQVUsS0FBSyxPQUFPLEVBQUU7UUFDakMsT0FBTyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7S0FDakI7U0FBTSxJQUFJLFVBQVUsS0FBSyxPQUFPLEVBQUU7UUFDakMsT0FBTyxLQUFLLENBQUMsQ0FBQyxFQUFFLHNCQUFzQixDQUFDLENBQUM7S0FDekM7U0FBTSxJQUFJLFVBQVUsS0FBSyxXQUFXLEVBQUU7UUFDckMsT0FBTyxTQUFTLENBQUMsQ0FBQyxFQUFFLGNBQWMsQ0FBQyxDQUFDO0tBQ3JDO1NBQU0sSUFBSSxVQUFVLEtBQUssU0FBUyxFQUFFO1FBQ25DLE9BQU8sT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQ25CO0lBQ0QsTUFBTSxJQUFJLEtBQUssQ0FBQyw0QkFBNEIsVUFBVSxHQUFHLENBQUMsQ0FBQztBQUM3RCxDQUFDO0FBRUQsb0NBQW9DO0FBQ3BDLE1BQU0sQ0FBQyxNQUFNLFVBQVUsR0FBRyxDQUFDLGFBQXFCLEVBQUUsVUFBc0IsRUFBRSxFQUFFO0lBQzFFLE1BQU0sWUFBWSxHQUFHLGFBQWEsR0FBRyxDQUFDLENBQUM7SUFDdkMsT0FBTyxDQUFDLFlBQVksSUFBSSxVQUFVLEtBQUssUUFBUSxDQUFDO0FBQ2xELENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE5IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtUZW5zb3J9IGZyb20gJy4uL3RlbnNvcic7XG5cbmltcG9ydCAqIGFzIGJyb2FkY2FzdF91dGlsIGZyb20gJy4vYnJvYWRjYXN0X3V0aWwnO1xuaW1wb3J0IHtlbHV9IGZyb20gJy4vZWx1JztcbmltcG9ydCB7QWN0aXZhdGlvbn0gZnJvbSAnLi
9mdXNlZF90eXBlcyc7XG5pbXBvcnQge2xlYWt5UmVsdX0gZnJvbSAnLi9sZWFreV9yZWx1JztcbmltcG9ydCB7bXVsfSBmcm9tICcuL211bCc7XG5pbXBvcnQge3ByZWx1fSBmcm9tICcuL3ByZWx1JztcbmltcG9ydCB7cmVsdX0gZnJvbSAnLi9yZWx1JztcbmltcG9ydCB7cmVsdTZ9IGZyb20gJy4vcmVsdTYnO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL3Jlc2hhcGUnO1xuaW1wb3J0IHtzaWdtb2lkfSBmcm9tICcuL3NpZ21vaWQnO1xuaW1wb3J0IHtzdGVwfSBmcm9tICcuL3N0ZXAnO1xuaW1wb3J0IHtzdW19IGZyb20gJy4vc3VtJztcblxuLy8gUmV0dXJucyBncmFkaWVudCBmb3IgZnVzZWQgYWN0aXZhdGlvbi5cbmV4cG9ydCBmdW5jdGlvbiBnZXRGdXNlZER5QWN0aXZhdGlvbihcbiAgICBkeTogVGVuc29yLCB5OiBUZW5zb3IsIGFjdGl2YXRpb246IEFjdGl2YXRpb24pOiBUZW5zb3Ige1xuICBpZiAoYWN0aXZhdGlvbiA9PSBudWxsIHx8IGFjdGl2YXRpb24gPT09ICdsaW5lYXInKSB7XG4gICAgcmV0dXJuIGR5O1xuICB9XG4gIGlmIChhY3RpdmF0aW9uID09PSAncmVsdScpIHtcbiAgICByZXR1cm4gbXVsKGR5LCBzdGVwKHkpKTtcbiAgfVxuICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICBgQ2Fubm90IGNvbXB1dGUgZ3JhZGllbnQgZm9yIGZ1c2VkIGFjdGl2YXRpb24gJHthY3RpdmF0aW9ufS5gKTtcbn1cblxuLy8gUmV0dXJucyBncmFkaWVudCBmb3IgZnVzZWQgYmlhcy5cbmV4cG9ydCBmdW5jdGlvbiBnZXRGdXNlZEJpYXNHcmFkaWVudChcbiAgICBiaWFzOiBUZW5zb3IsIGR5QWN0aXZhdGlvbjogVGVuc29yKTogVGVuc29yIHtcbiAgbGV0IHJlcyA9IGR5QWN0aXZhdGlvbjtcbiAgY29uc3QgcmVkdWNlQXhlcyA9XG4gICAgICBicm9hZGNhc3RfdXRpbC5nZXRSZWR1Y3Rpb25BeGVzKGJpYXMuc2hhcGUsIGR5QWN0aXZhdGlvbi5zaGFwZSk7XG4gIGlmIChyZWR1Y2VBeGVzLmxlbmd0aCA+IDApIHtcbiAgICByZXMgPSBzdW0ocmVzLCByZWR1Y2VBeGVzKTtcbiAgfVxuICByZXR1cm4gcmVzaGFwZShyZXMsIGJpYXMuc2hhcGUpO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gYXBwbHlBY3RpdmF0aW9uKFxuICAgIHg6IFRlbnNvciwgYWN0aXZhdGlvbjogQWN0aXZhdGlvbiwgcHJlbHVBY3RpdmF0aW9uV2VpZ2h0cz86IFRlbnNvcixcbiAgICBsZWFreXJlbHVBbHBoYT86IG51bWJlcik6IFRlbnNvciB7XG4gIGlmIChhY3RpdmF0aW9uID09PSAnbGluZWFyJykge1xuICAgIHJldHVybiB4O1xuICB9IGVsc2UgaWYgKGFjdGl2YXRpb24gPT09ICdyZWx1Jykge1xuICAgIHJldHVybiByZWx1KHgpO1xuICB9IGVsc2UgaWYgKGFjdGl2YXRpb24gPT09ICdlbHUnKSB7XG4gICAgcmV0dXJuIGVsdSh4KTtcbiAgfSBlbHNlIGlmIChhY3RpdmF0aW9uID09PSAncmVsdTYnKSB7XG4gICAgcmV0dXJuIHJlbHU2KHgpO1xuICB9IGVsc2UgaWYgKGFjdGl2YXRpb24gPT09ICdwcmVsdScpIHtcbiAgICByZXR1cm4gcHJlbHUoeCwgcHJlbHVBY3Rpdm
F0aW9uV2VpZ2h0cyk7XG4gIH0gZWxzZSBpZiAoYWN0aXZhdGlvbiA9PT0gJ2xlYWt5cmVsdScpIHtcbiAgICByZXR1cm4gbGVha3lSZWx1KHgsIGxlYWt5cmVsdUFscGhhKTtcbiAgfSBlbHNlIGlmIChhY3RpdmF0aW9uID09PSAnc2lnbW9pZCcpIHtcbiAgICByZXR1cm4gc2lnbW9pZCh4KTtcbiAgfVxuICB0aHJvdyBuZXcgRXJyb3IoYFVua25vd24gZnVzZWQgYWN0aXZhdGlvbiAke2FjdGl2YXRpb259LmApO1xufVxuXG4vLyBXaGV0aGVyIHdlIHNob3VsZCBjYWxsIGZ1c2VkIG9wcy5cbmV4cG9ydCBjb25zdCBzaG91bGRGdXNlID0gKGdyYWRpZW50RGVwdGg6IG51bWJlciwgYWN0aXZhdGlvbjogQWN0aXZhdGlvbikgPT4ge1xuICBjb25zdCBncmFkaWVudE1vZGUgPSBncmFkaWVudERlcHRoID4gMDtcbiAgcmV0dXJuICFncmFkaWVudE1vZGUgfHwgYWN0aXZhdGlvbiA9PT0gJ2xpbmVhcic7XG59O1xuIl19