UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

49 lines 2.05 kB
import { OnnxNode } from '../node';

// Standard TypeScript-emitted helper: drives a generator function as a
// promise-returning async function (used because the build target predates
// native async/await).
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    function adopt(value) {
        return value instanceof P ? value : new P(function (resolve) { resolve(value); });
    }
    return new (P || (P = Promise))(function (resolve, reject) {
        function fulfilled(value) {
            try { step(generator.next(value)); } catch (e) { reject(e); }
        }
        function rejected(value) {
            try { step(generator["throw"](value)); } catch (e) { reject(e); }
        }
        function step(result) {
            result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected);
        }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};

/**
 * Numerically stable softmax of `tensor` along a single axis.
 *
 * Subtracts the per-slice maximum before exponentiating, then divides by the
 * per-slice sum. All intermediate tensors are released before returning, even
 * if one of the tensor operations throws.
 *
 * NOTE(review): assumes `max`/`sum` accept `(axis, keepDims)` for any axis and
 * that `subtract`/`divide` broadcast the reduced tensor against its input for
 * any rank; the original code only exercised this on 2-D tensors with axis 1.
 *
 * @param tensor Input tensor; it is not deleted.
 * @param {number} axis Non-negative axis to normalize over.
 * @returns A new tensor with the same shape as `tensor`.
 */
function softmaxAlongAxis(tensor, axis) {
    const temporaries = [];
    try {
        const max = tensor.max(axis, true);
        temporaries.push(max);
        const shifted = tensor.subtract(max);
        temporaries.push(shifted);
        const exp = shifted.exp();
        temporaries.push(exp);
        const sum = exp.sum(axis, true);
        temporaries.push(sum);
        return exp.divide(sum);
    } finally {
        // Same release order as before: max, shifted, exp, sum.
        for (const temporary of temporaries) {
            temporary.delete();
        }
    }
}

/**
 * ONNX `Softmax` operator node.
 *
 * Reads the optional integer attribute `axis`. When it is absent the default
 * depends on the opset: 1 for opset < 13, the last axis otherwise.
 */
export class SoftmaxNode extends OnnxNode {
    constructor(attributes, inputs, outputs, constants, onnxVersion, mode) {
        super(attributes, inputs, outputs, constants, onnxVersion, mode);
        //@ts-ignore
        this.axis = this.getAttributeInt('axis');
    }

    /**
     * Computes softmax of `inputs[0]`.
     *
     * - Opset < 13 (and opset >= 13 on the last axis): the input is reshaped
     *   to 2-D `[prod(shape[0:axis]), -1]` and normalized over the second
     *   dimension, then reshaped back to the input shape.
     * - Opset >= 13 on any other axis: normalization runs along that single
     *   axis only, as the opset-13 operator definition requires. Previously
     *   this case was also flattened, which normalized over all trailing
     *   dimensions.
     *
     * @param inputs Array whose first element is the input tensor.
     * @returns Promise resolving to a one-element array holding the result.
     */
    forward(inputs) {
        return __awaiter(this, void 0, void 0, function* () {
            const x = inputs[0];
            const shapeX = x.getShape();
            const rank = shapeX.length;

            let ax = this.axis;
            if (ax === undefined) {
                ax = this.onnxVersion < 13 ? 1 : rank - 1;
            }

            // Negative axes count from the end; only the single-axis path
            // needs the normalized value (slice() below handles negatives).
            const normalizedAxis = ax < 0 ? ax + rank : ax;

            if (this.onnxVersion >= 13 && normalizedAxis !== rank - 1) {
                return [softmaxAlongAxis(x, normalizedAxis)];
            }

            const rows = shapeX.slice(0, ax).reduce((acc, dim) => acc * dim, 1);
            // reshape(..., false): presumably "no copy", i.e. the reshaped
            // tensor shares storage with x, so it is not deleted here.
            // TODO confirm against the Tensor API.
            const flattened = x.reshape([rows, -1], false);
            const result = softmaxAlongAxis(flattened, 1);
            return [result.reshape(shapeX, false)];
        });
    }

    /** @returns {string} The ONNX operator name handled by this node. */
    getType() {
        return 'Softmax';
    }

    /** No-op: this node does not release anything here. */
    delete() { }
}
//# sourceMappingURL=softmax.js.map