ml5-save
Version:
174 lines (154 loc) • 6.28 kB
JavaScript
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT
/* eslint max-len: "off" */
/* eslint no-trailing-spaces: "off" */
/*
Fast Style Transfer
This implementation is heavily based on github.com/reiinakano/fast-style-transfer-deeplearnjs by Reiichiro Nakano.
The original TensorFlow implementation was developed by Logan Engstrom: github.com/lengstrom/fast-style-transfer
*/
import * as tf from '@tensorflow/tfjs';
import Video from './../utils/Video';
import CheckpointLoader from '../utils/checkpointLoader';
import { array3DToImage } from '../utils/imageUtilities';
import callCallback from '../utils/callcallback';
const IMAGE_SIZE = 200;
class StyleTransfer extends Video {
/**
* Create a new Style Transfer Instance。
* @param {model} model - The path to Style Transfer model.
* @param {HTMLVideoElement || p5.Video} video - Optional. A HTML video element or a p5 video element.
* @param {funciton} callback - Optional. A function to be called once the model is loaded. If no callback is provided, it will return a promise that will be resolved once the model has loaded.
*/
constructor(model, video, callback) {
super(video, IMAGE_SIZE);
/**
* Boolean value that specifies if the model has loaded.
* @type {boolean}
* @public
*/
this.ready = false;
this.variableDictionary = {};
this.timesScalar = tf.scalar(150);
this.plusScalar = tf.scalar(255.0 / 2);
this.epsilonScalar = tf.scalar(1e-3);
this.video = null;
this.ready = callCallback(this.load(model), callback);
// this.then = this.ready.then;
}
async load(model) {
if (this.videoElt) {
await this.loadVideo();
this.videoReady = true;
}
await this.loadCheckpoints(model);
return this;
}
async loadCheckpoints(path) {
const checkpointLoader = new CheckpointLoader(path);
this.variables = await checkpointLoader.getAllVariables();
}
instanceNorm(input, id) {
return tf.tidy( () => {
const [height, width, inDepth] = input.shape;
const moments = tf.moments(input, [0, 1]);
const mu = moments.mean;
const sigmaSq = moments.variance;
const shift = this.variables[StyleTransfer.getVariableName(id)];
const scale = this.variables[StyleTransfer.getVariableName(id + 1)];
const epsilon = this.epsilonScalar;
const normalized = tf.div(tf.sub(input.asType('float32'), mu), tf.sqrt(tf.add(sigmaSq, epsilon)));
const shifted = tf.add(tf.mul(scale, normalized), shift);
return shifted.as3D(height, width, inDepth);
});
}
convLayer(input, strides, relu, id) {
const y = tf.conv2d(input, this.variables[StyleTransfer.getVariableName(id)], [strides, strides], 'same');
const y2 = this.instanceNorm(y, id + 1);
if (relu) {
return tf.relu(y2);
}
return y2;
}
residualBlock(input, id) {
const conv1 = this.convLayer(input, 1, true, id);
const conv2 = this.convLayer(conv1, 1, false, id + 3);
return tf.add(conv2, input);
}
convTransposeLayer(input, numFilters, strides, id) {
const [height, width] = input.shape;
const newRows = height * strides;
const newCols = width * strides;
const newShape = [newRows, newCols, numFilters];
const y = tf.conv2dTranspose(input, this.variables[StyleTransfer.getVariableName(id)], newShape, [strides, strides], 'same');
const y2 = this.instanceNorm(y, id + 1);
const y3 = tf.relu(y2);
return y3;
}
/**
*
* @param {Image || p5.Image || HTMLVideoElement || p5.Video} input - A HTML video or image element or a p5 image or video element. If no input is provided, the default is to use the video element given in the constructor.
* @param {funciton} callback - Optional. A function to run once the model has made the transfer. If no callback is provided, it will return a promise that will be resolved once the model has made the transfer.
*/
async transfer(inputOrCallback, cb) {
let input;
let callback = cb;
if (inputOrCallback instanceof HTMLVideoElement ||
inputOrCallback instanceof HTMLImageElement ||
inputOrCallback instanceof ImageData) {
input = inputOrCallback;
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLVideoElement
|| inputOrCallback.elt instanceof HTMLImageElement
|| inputOrCallback.elt instanceof ImageData)) {
input = inputOrCallback.elt;
} else if (typeof inputOrCallback === 'function') {
input = this.video;
callback = inputOrCallback;
}
return callCallback(this.transferInternal(input), callback);
}
async transferInternal(input) {
const image = tf.browser.fromPixels(input);
const result = array3DToImage(tf.tidy(() => {
const conv1 = this.convLayer(image, 1, true, 0);
const conv2 = this.convLayer(conv1, 2, true, 3);
const conv3 = this.convLayer(conv2, 2, true, 6);
const res1 = this.residualBlock(conv3, 9);
const res2 = this.residualBlock(res1, 15);
const res3 = this.residualBlock(res2, 21);
const res4 = this.residualBlock(res3, 27);
const res5 = this.residualBlock(res4, 33);
const convT1 = this.convTransposeLayer(res5, 64, 2, 39);
const convT2 = this.convTransposeLayer(convT1, 32, 2, 42);
const convT3 = this.convLayer(convT2, 1, false, 45);
const outTanh = tf.tanh(convT3);
const scaled = tf.mul(this.timesScalar, outTanh);
const shifted = tf.add(this.plusScalar, scaled);
const clamped = tf.clipByValue(shifted, 0, 255);
const normalized = tf.div(clamped, tf.scalar(255.0));
return normalized;
}));
image.dispose();
await tf.nextFrame();
return result;
}
// Static Methods
static getVariableName(id) {
if (id === 0) {
return 'Variable';
}
return `Variable_${id}`;
}
}
const styleTransfer = (model, videoOrCallback, cb) => {
const video = videoOrCallback;
let callback = cb;
if (typeof videoOrCallback === 'function') {
callback = videoOrCallback;
}
const instance = new StyleTransfer(model, video, callback);
return callback ? instance : instance.ready;
};
export default styleTransfer;