@upscalerjs/esrgan-medium
Version:
ESRGAN Medium Model for UpscalerJS. Upscale images and increase image resolution with AI using Javascript
83 lines (82 loc) • 2.26 kB
JavaScript
const isTensorArray = (inputs) => {
return Array.isArray(inputs);
};
const getInput = (inputs) => {
if (isTensorArray(inputs)) {
return inputs[0];
}
return inputs;
};
export const getESRGANModelDefinition = ({ scale, name, version, meta: { architecture, ...meta }, path: modelPath, }) => {
const path = modelPath || `models/x${scale}/model.json`;
if (architecture === 'rdn') {
return {
scale,
modelType: 'layers',
_internals: {
path,
name,
version,
},
meta: {
architecture,
...meta,
},
inputRange: [0, 255,],
outputRange: [0, 255,],
};
}
const setup = (tf) => {
const Layer = tf.layers.Layer;
const BETA = 0.2;
class MultiplyBeta extends Layer {
beta;
constructor() {
super({});
this.beta = BETA;
}
call(inputs) {
return tf.mul(getInput(inputs), this.beta);
}
static className = 'MultiplyBeta';
}
const getPixelShuffle = (_scale) => {
class PixelShuffle extends Layer {
scale = _scale;
constructor() {
super({});
}
computeOutputShape(inputShape) {
return [inputShape[0], inputShape[1], inputShape[2], 3,];
}
call(inputs) {
return tf.depthToSpace(getInput(inputs), this.scale, 'NHWC');
}
static className = `PixelShuffle${scale}x`;
}
return PixelShuffle;
};
[
MultiplyBeta,
getPixelShuffle(scale),
].forEach((layer) => {
tf.serialization.registerClass(layer);
});
};
return {
setup,
scale,
modelType: 'layers',
_internals: {
path,
name,
version,
},
meta: {
architecture,
...meta,
},
inputRange: [0, 1,],
outputRange: [0, 1,],
};
};