UNPKG

@upscalerjs/esrgan-medium

Version: 1.0.0-beta.13

ESRGAN Medium Model for UpscalerJS. Upscale images and increase image resolution with AI using JavaScript.

119 lines (112 loc) 3.73 kB
/**
 * UMD bundle exporting the `@upscalerjs/esrgan-medium` 8x model definition.
 *
 * Export target, in order of preference:
 *   1. CommonJS: `module.exports` when `exports`/`module` exist;
 *   2. AMD: `define(factory)` when an AMD loader is present;
 *   3. otherwise `globalThis.ESRGANMedium8x` (falling back to `global` / `self`).
 */
(function (global, factory) {
  if (typeof exports === 'object' && typeof module !== 'undefined') {
    module.exports = factory();
  } else if (typeof define === 'function' && define.amd) {
    define(factory);
  } else {
    global = typeof globalThis !== 'undefined' ? globalThis : global || self;
    global.ESRGANMedium8x = factory();
  }
})(this, function () {
  'use strict';

  /** True when a layer was invoked with a list of tensors instead of a single tensor. */
  const isTensorArray = (inputs) => Array.isArray(inputs);

  /** Returns the first tensor of a tensor list, or `inputs` itself when it is not a list. */
  const getInput = (inputs) => (isTensorArray(inputs) ? inputs[0] : inputs);

  /**
   * Multiplies a spatial dimension by `factor`, leaving unknown dimensions
   * (`null`/`undefined`, as used for dynamic shapes) unknown.
   */
  const scaleDim = (dim, factor) => (dim == null ? dim : dim * factor);

  /**
   * Builds an UpscalerJS model definition.
   *
   * @param {object} options
   * @param {number} options.scale - Upscale factor; also used in the default model path.
   * @param {string} options.name - Package name recorded in `_internals`.
   * @param {string} options.version - Package version recorded in `_internals`.
   * @param {object} options.meta - Model metadata; `meta.architecture` selects the branch below.
   * @param {string} [options.path] - Model JSON path; defaults to `models/x${scale}/model.json`.
   * @returns {object} For `architecture === 'rdn'`: a `layers` definition with a [0, 255]
   *   pixel range and no `setup`. Otherwise: a `layers` definition with a [0, 1] pixel
   *   range plus a `setup(tf)` hook that registers the custom layers the model needs.
   */
  const getESRGANModelDefinition = ({
    scale,
    name,
    version,
    meta: { architecture, ...meta },
    path: modelPath,
  }) => {
    const path = modelPath || `models/x${scale}/model.json`;

    if (architecture === 'rdn') {
      return {
        scale,
        modelType: 'layers',
        _internals: { path, name, version },
        meta: { architecture, ...meta },
        inputRange: [0, 255],
        outputRange: [0, 255],
      };
    }

    /** Registers the `MultiplyBeta` and `PixelShuffle${scale}x` layer classes with `tf`. */
    const setup = (tf) => {
      const Layer = tf.layers.Layer;
      const BETA = 0.2;

      /** Scales its input by the constant BETA (0.2). */
      class MultiplyBeta extends Layer {
        beta;

        constructor() {
          super({});
          this.beta = BETA;
        }

        call(inputs) {
          return tf.mul(getInput(inputs), this.beta);
        }

        static className = 'MultiplyBeta';
      }

      /** Creates a PixelShuffle layer class bound to a fixed upscale factor. */
      const getPixelShuffle = (_scale) => {
        /** Rearranges depth into space via `tf.depthToSpace` (NHWC). */
        class PixelShuffle extends Layer {
          scale = _scale;

          constructor() {
            super({});
          }

          /**
           * Mirrors `tf.depthToSpace`: NHWC [N, H, W, C] -> [N, H*s, W*s, C/(s*s)].
           * Unknown dimensions stay unknown. When the input depth is unknown this
           * keeps the previous assumption of 3 (RGB) output channels.
           */
          computeOutputShape(inputShape) {
            const [batch, height, width, depth] = inputShape;
            const factor = this.scale;
            const channels = depth == null ? 3 : depth / (factor * factor);
            return [batch, scaleDim(height, factor), scaleDim(width, factor), channels];
          }

          call(inputs) {
            return tf.depthToSpace(getInput(inputs), this.scale, 'NHWC');
          }

          // Use this factory's own argument so the name always matches `this.scale`.
          static className = `PixelShuffle${_scale}x`;
        }
        return PixelShuffle;
      };

      for (const layer of [MultiplyBeta, getPixelShuffle(scale)]) {
        tf.serialization.registerClass(layer);
      }
    };

    return {
      setup,
      scale,
      modelType: 'layers',
      _internals: { path, name, version },
      meta: { architecture, ...meta },
      inputRange: [0, 1],
      outputRange: [0, 1],
    };
  };

  const NAME = "@upscalerjs/esrgan-medium";
  const VERSION = "1.0.0-beta.13";

  /**
   * Builds the RDN "medium" definition for one upscale factor.
   * The C/D/G/G0/T metadata values mirror those encoded in `modelFileName`.
   *
   * @param {number} scale - Upscale factor (3 selects a 129 patch size, otherwise 128).
   * @param {string} modelFileName - Original trained-model file name, stored in `meta`.
   */
  const getModelDefinition = (scale, modelFileName) => getESRGANModelDefinition({
    scale,
    path: `models/x${scale}/model.json`,
    name: NAME,
    version: VERSION,
    meta: {
      C: 1,
      D: 10,
      G: 64,
      G0: 64,
      T: 10,
      architecture: "rdn",
      patchSize: scale === 3 ? 129 : 128,
      size: 'medium',
      artifactReducing: false,
      sharpening: false,
      dataset: 'div2k',
      modelFileName,
    },
  });

  return getModelDefinition(8, 'rdn-C1-D10-G64-G064-T10-x8-patchsize128-compress100-sharpen0-datadiv2k-vary_cFalse_epoch483');
});