fast-sobel-tfjs
Version:
GPU-accelerated Sobel edge detection for TensorFlow.js - 5-10x faster than CPU implementations
134 lines (133 loc) • 5.71 kB
JavaScript
import * as tf from '@tensorflow/tfjs';
import { KERNELS } from '../kernels';
/**
 * Creates a Sobel kernel tensor for the specified direction and number of channels.
 *
 * The 2D kernel is looked up in KERNELS[direction][kernelSize], reshaped to
 * [kernelSize, kernelSize, 1, 1] and tiled along the third axis, giving a
 * tensor of shape [kernelSize, kernelSize, channels, 1]. That layout matches
 * the filter shape of tf.depthwiseConv2d (presumably how it is consumed —
 * confirm at the call site).
 *
 * @param direction 'x' or 'y' for horizontal or vertical gradient
 * @param kernelSize Size of the kernel (3, 5, or 7 — whatever KERNELS provides)
 * @param channels Number of input channels to tile the kernel across
 * @returns Tensor4D of shape [kernelSize, kernelSize, channels, 1]
 * @throws {Error} If KERNELS has no kernel for the given direction/size
 */
export function createSobelKernel(direction, kernelSize, channels) {
  const kernelsForDirection = KERNELS[direction];
  const kernelArray = kernelsForDirection == null ? undefined : kernelsForDirection[kernelSize];
  if (kernelArray == null) {
    // Fail fast with a descriptive message instead of an opaque TypeError or
    // a tensor2d() error raised deep inside TensorFlow.js.
    throw new Error(`Unsupported Sobel kernel: direction=${direction}, kernelSize=${kernelSize}`);
  }
  // tf.tidy disposes the intermediate 2D/4D tensors; only the tiled result survives.
  return tf.tidy(() =>
    tf.tensor2d(kernelArray)
      .reshape([kernelSize, kernelSize, 1, 1])
      .tile([1, 1, channels, 1])
  );
}
/**
 * Linearly rescales a tensor so its smallest element maps to `min` and its
 * largest element maps to `max` (e.g. for display or 8-bit output).
 *
 * The global minimum/maximum over every element of `tensor` is used. When all
 * elements are equal, the range is clamped to 1e-6 to avoid division by zero,
 * so every output element equals `min`.
 *
 * @param tensor Input tensor to normalize (not disposed by this function)
 * @param min Value the tensor's minimum maps to (default 0)
 * @param max Value the tensor's maximum maps to (default 255)
 * @returns New tensor with the rescaled values; the caller owns it and must dispose it
 */
export function normalizeTensor(tensor, min = 0, max = 255) {
  // Every intermediate created inside tf.tidy is released automatically;
  // only the returned tensor survives, so no manual dispose() is needed.
  return tf.tidy(() => {
    const minVal = tensor.min();
    const maxVal = tensor.max();
    // Clamp the range so a constant tensor does not divide by zero.
    const range = tf.maximum(tf.sub(maxVal, minVal), tf.scalar(1e-6));
    // Map to [0, 1], then scale/shift into [min, max].
    const normalized = tf.div(tf.sub(tensor, minVal), range);
    return tf.add(tf.mul(normalized, tf.scalar(max - min)), tf.scalar(min));
  });
}
/**
 * Prepares an image tensor for edge detection: casts it to float32, drops an
 * alpha channel, and optionally converts RGB(A) to a single grayscale channel.
 *
 * Behaviour by channel count (third axis of the input):
 * - 4 (RGBA): alpha is sliced off; the RGB result is then converted to
 *   grayscale when `grayscale` is truthy.
 * - 3 (RGB): converted to grayscale when `grayscale` is truthy, otherwise
 *   passed through.
 * - 1: passed through.
 * - anything else: a warning is logged and the float32 input is passed through.
 *
 * Any float32 copy created internally is disposed once it has been consumed;
 * the caller's `input` tensor is never disposed.
 *
 * @param input Image tensor, assumed rank 3 ([height, width, channels])
 * @param grayscale Whether to convert RGB/RGBA input to one grayscale channel
 * @returns `{ tensor, newTensorCreated }` — `newTensorCreated` is true exactly
 *   when `tensor` is a newly allocated tensor (not `input` itself), so the
 *   caller is responsible for disposing it
 */
export function ensureGrayscaleIfNeeded(input, grayscale) {
  const channels = input.shape[2];

  // Downstream math runs in float32 (e.g. tf.browser.fromPixels returns int32).
  // Track whether a cast copy was allocated so it can be released or reported.
  const floatInput = input.dtype === 'float32' ? input : tf.cast(input, 'float32');
  const createdFloatTensor = floatInput !== input;

  const dropAlpha = channels === 4;
  const toGray = Boolean(grayscale) && (channels === 3 || channels === 4);

  if (!dropAlpha && !toGray) {
    // 1- and 3-channel input pass through unchanged. Other channel counts are
    // unsupported but are still passed through (with a warning), not rejected.
    if (channels !== 1 && channels !== 3) {
      console.warn(`Grayscale ${grayscale ? 'ON' : 'OFF'}: Unsupported channel count ${channels}. Returning original tensor.`);
    }
    return { tensor: floatInput, newTensorCreated: createdFloatTensor };
  }

  // tf.tidy frees the intermediate RGB slice when it is followed by a grayscale
  // conversion; the tensor returned from the callback survives.
  const processed = tf.tidy(() => {
    const rgb = dropAlpha ? floatInput.slice([0, 0, 0], [-1, -1, 3]) : floatInput;
    return toGray ? tf.image.rgbToGrayscale(rgb) : rgb;
  });

  // Dispose the float32 copy only after it has been consumed above; disposing
  // it earlier makes rgbToGrayscale read a disposed tensor.
  if (createdFloatTensor) {
    floatInput.dispose();
  }
  return { tensor: processed, newTensorCreated: true };
}