@upscalerjs/esrgan-medium
Version: 1.0.0-beta.13
ESRGAN Medium Model for UpscalerJS. Upscale images and increase image resolution with AI using JavaScript.
119 lines (112 loc) • 3.75 kB
JavaScript
(function (global, factory) {
typeof exports === 'object' && typeof module !== 'undefined' ? module.exports = factory() :
typeof define === 'function' && define.amd ? define(factory) :
(global = typeof globalThis !== 'undefined' ? globalThis : global || self, global.ESRGANMedium2x = factory());
})(this, (function () { 'use strict';
/**
 * Distinguishes an array of layer inputs from a single input value.
 *
 * @param {*} inputs - A single value or an array of values.
 * @returns {boolean} `true` exactly when `inputs` is an array.
 */
const isTensorArray = (inputs) => Array.isArray(inputs);
/**
 * Normalizes the `inputs` argument of a layer's `call` to a single value.
 *
 * @param {*} inputs - Either one value or an array of values.
 * @returns {*} The array's first element, or `inputs` itself when it is not an array.
 */
const getInput = (inputs) => (isTensorArray(inputs) ? inputs[0] : inputs);
/**
 * Builds an UpscalerJS model definition for an ESRGAN-family model.
 *
 * RDN models (`meta.architecture === 'rdn'`) are returned as plain `layers`
 * definitions with input/output ranges of [0, 255]. Any other architecture
 * additionally gets a `setup(tf)` hook that registers the custom
 * `MultiplyBeta` and `PixelShuffle<scale>x` layers, and uses ranges of [0, 1].
 *
 * @param {object} options
 * @param {number} options.scale - Upscale factor; also used to build the default model path.
 * @param {string} options.name - Package name, stored in `_internals`.
 * @param {string} options.version - Package version, stored in `_internals`.
 * @param {object} options.meta - Model metadata; `architecture` selects the branch, all keys are copied to `meta`.
 * @param {string} [options.path] - Model path; falls back to `models/x<scale>/model.json` when falsy.
 * @returns {object} The model definition.
 */
const getESRGANModelDefinition = ({ scale, name, version, meta: { architecture, ...meta }, path: modelPath, }) => {
  // `||` (not `??`) so an empty-string path also falls back to the default.
  const path = modelPath || `models/x${scale}/model.json`;
  // Fields common to both architectures; rebuilt on every call so callers never share nested objects.
  const baseDefinition = {
    scale,
    modelType: 'layers',
    _internals: {
      path,
      name,
      version,
    },
    meta: {
      architecture,
      ...meta,
    },
  };
  if (architecture === 'rdn') {
    return {
      ...baseDefinition,
      inputRange: [0, 255,],
      outputRange: [0, 255,],
    };
  }
  /**
   * Registers the custom layers used by non-RDN models.
   * @param {object} tf - TensorFlow.js namespace (uses `layers.Layer`, `mul`, `depthToSpace`, `serialization.registerClass`).
   */
  const setup = (tf) => {
    const Layer = tf.layers.Layer;
    const BETA = 0.2;
    // Multiplies its (first) input by the constant BETA.
    class MultiplyBeta extends Layer {
      beta;
      constructor() {
        super({});
        this.beta = BETA;
      }
      call(inputs) {
        return tf.mul(getInput(inputs), this.beta);
      }
      static className = 'MultiplyBeta';
    }
    // Creates an NHWC depth-to-space layer class with block size `_scale`.
    const getPixelShuffle = (_scale) => {
      class PixelShuffle extends Layer {
        scale = _scale;
        constructor() {
          super({});
        }
        computeOutputShape(inputShape) {
          const [batch, height, width] = inputShape;
          // depthToSpace multiplies each spatial dim by the block size. Unknown (null) dims
          // must stay unknown, hence the guard: `null * n` would evaluate to 0.
          const upscaleDim = (dim) => (dim == null ? dim : dim * this.scale);
          // Output channel count is hard-coded to 3, as in the original implementation.
          return [batch, upscaleDim(height), upscaleDim(width), 3,];
        }
        call(inputs) {
          return tf.depthToSpace(getInput(inputs), this.scale, 'NHWC');
        }
        // Derived from this factory's own argument so the registered name always
        // matches the block size used in `call`.
        static className = `PixelShuffle${_scale}x`;
      }
      return PixelShuffle;
    };
    [
      MultiplyBeta,
      getPixelShuffle(scale),
    ].forEach((layer) => {
      tf.serialization.registerClass(layer);
    });
  };
  return {
    setup,
    ...baseDefinition,
    inputRange: [0, 1,],
    outputRange: [0, 1,],
  };
};
// Package identity; passed as `name`/`version` into every model definition's `_internals`.
const NAME = '@upscalerjs/esrgan-medium';
const VERSION = '1.0.0-beta.13';
/**
 * Creates the "medium"-size RDN model definition for one upscale factor.
 *
 * @param {number} scale - Upscale factor; selects `models/x<scale>/model.json`.
 * @param {string} modelFileName - Model file identifier, stored as `meta.modelFileName`.
 * @returns {object} Whatever `getESRGANModelDefinition` builds from these options.
 */
const getModelDefinition = (scale, modelFileName) => {
  // Scale 3 uses a 129-pixel patch; every other scale uses 128.
  const patchSize = scale === 3 ? 129 : 128;
  const meta = {
    C: 1,
    D: 10,
    G: 64,
    G0: 64,
    T: 10,
    architecture: 'rdn',
    patchSize,
    size: 'medium',
    artifactReducing: false,
    sharpening: false,
    dataset: 'div2k',
    modelFileName,
  };
  return getESRGANModelDefinition({
    scale,
    path: `models/x${scale}/model.json`,
    name: NAME,
    version: VERSION,
    meta,
  });
};
var index = getModelDefinition(2, 'rdn-C1-D10-G64-G064-T10-x2-patchsize128-compress100-sharpen0-datadiv2k-vary_cFalse_best-val_generator_PSNR_Y_epoch478');
return index;
}));