UNPKG

fast-sobel-tfjs

Version:

GPU-accelerated Sobel edge detection for TensorFlow.js - 5-10x faster than CPU implementations

224 lines (223 loc) 10.5 kB
// TypeScript-emitted helper that drives a generator function as an async
// function: each yielded value is awaited and its settlement is fed back into
// the generator via next()/throw().
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
    return new (P || (P = Promise))(function (resolve, reject) {
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
};
import * as tf from '@tensorflow/tfjs';
import { normalizeTensor } from './tensor-utils';

/**
 * Converts a tensor to an ImageData object.
 *
 * Accepted input ranks:
 *  - 3D `[height, width, channels]` is used as-is;
 *  - 2D `[height, width]` gets a trailing channel dimension;
 *  - 4D `[1, height, width, channels]` has its batch dimension removed.
 * Any other shape is rejected.
 *
 * Tensors with 1, 3 or 4 channels are expanded to RGBA (alpha is forced to 255
 * for 1- and 3-channel input). Any other channel count is averaged over the
 * channel axis into a single grayscale channel. Values are clipped to 0-255
 * before being written.
 *
 * The caller's tensor is never disposed. Every intermediate tensor created by
 * this function is released before the returned promise settles, on success
 * and on failure.
 *
 * @param tensor Input tensor (2D, 3D, or 4D with a batch size of 1)
 * @param normalize Whether to rescale the values to the 0-255 range via
 *   `normalizeTensor` before conversion (default `true`)
 * @returns Promise resolving to the RGBA ImageData
 * @throws {Error} If the shape cannot be reduced to 3D or it has no channels
 */
export function tensorToImageData(tensor_1) {
    return __awaiter(this, arguments, void 0, function* (tensor, normalize = true) {
        const shape = tensor.shape;
        console.log("Input tensor shape in tensorToImageData:", shape);
        // Intermediate tensors created below are tracked here so the `finally`
        // block can release them on every exit path. The caller's tensor is
        // deliberately never added, even if a helper hands it back unchanged.
        const ownedTensors = new Set();
        const own = (created) => {
            if (created !== tensor) {
                ownedTensors.add(created);
            }
            return created;
        };
        try {
            // Bring the input to rank 3: [height, width, channels].
            let processedTensor = tensor;
            if (shape.length !== 3) {
                console.warn(`Expected 3D tensor but got ${shape.length}D tensor. Attempting to reshape.`);
                if (shape.length === 2) {
                    // [height, width] -> [height, width, 1]
                    processedTensor = own(tf.tidy(() => tensor.expandDims(-1)));
                    console.log("Expanded 2D tensor to 3D:", processedTensor.shape);
                }
                else if (shape.length === 4 && shape[0] === 1) {
                    // [1, height, width, channels] -> [height, width, channels]
                    processedTensor = own(tf.tidy(() => tensor.squeeze([0])));
                    console.log("Squeezed 4D tensor to 3D:", processedTensor.shape);
                }
                else {
                    throw new Error(`Cannot convert tensor of shape [${shape}] to ImageData`);
                }
            }
            const [height, width, channels] = processedTensor.shape;
            // Validate before doing any further tensor work on a bad input.
            if (channels < 1) {
                throw new Error(`Tensor must have at least 1 channel but has ${channels}`);
            }
            const rank3Tensor = processedTensor;
            let imageTensor = rank3Tensor;
            if (normalize) {
                console.log("Normalizing tensor to 0-255 range");
                imageTensor = own(tf.tidy(() => normalizeTensor(rank3Tensor, 0, 255)));
            }
            // ImageData can only be built from 1, 3 or 4 channels; anything
            // else is collapsed to a single grayscale channel.
            let displayTensor = imageTensor;
            if (![1, 3, 4].includes(channels)) {
                console.warn(`Unusual number of channels: ${channels}. Converting to grayscale.`);
                const colorTensor = imageTensor;
                displayTensor = own(tf.tidy(() => tf.mean(colorTensor, -1, true)));
                console.log("Converted to grayscale, shape:", displayTensor.shape);
            }
            const finalTensor = displayTensor;
            // Print some tensor stats for debugging; tidy releases the scalar
            // reduction results.
            tf.tidy(() => {
                const minVal = tf.min(finalTensor).dataSync()[0];
                const maxVal = tf.max(finalTensor).dataSync()[0];
                const meanVal = tf.mean(finalTensor).dataSync()[0];
                console.log(`Tensor stats - Min: ${minVal}, Max: ${maxVal}, Mean: ${meanVal}`);
            });
            console.log("Casting to int32 and clipping to 0-255");
            const intTensor = own(tf.tidy(() => finalTensor.clipByValue(0, 255).cast('int32')));
            console.log("Converting tensor to typed array");
            const data = yield intTensor.data();
            const finalChannels = finalTensor.shape[2];
            const pixelCount = height * width;
            console.log(`Data array length: ${data.length}, expected: ${width * height * finalChannels}`);
            console.log("Data sample:", data.slice(0, 20));
            console.log(`Creating Uint8ClampedArray for ${width}x${height} image with ${finalChannels} channels`);
            const pixelArray = new Uint8ClampedArray(width * height * 4);
            if (finalChannels === 1) {
                // Grayscale: replicate the single value into R, G and B.
                console.log("Converting grayscale to RGBA");
                for (let i = 0; i < pixelCount; i++) {
                    const value = data[i];
                    const offset = i * 4;
                    pixelArray[offset] = value;
                    pixelArray[offset + 1] = value;
                    pixelArray[offset + 2] = value;
                    pixelArray[offset + 3] = 255; // fully opaque
                }
            }
            else if (finalChannels === 3) {
                // RGB: copy the colour channels and append an opaque alpha.
                console.log("Converting RGB to RGBA");
                for (let i = 0; i < pixelCount; i++) {
                    const src = i * 3;
                    const dst = i * 4;
                    pixelArray[dst] = data[src];
                    pixelArray[dst + 1] = data[src + 1];
                    pixelArray[dst + 2] = data[src + 2];
                    pixelArray[dst + 3] = 255; // fully opaque
                }
            }
            else if (finalChannels === 4) {
                // RGBA: element-wise copy; Uint8ClampedArray clamps to 0-255.
                console.log("Converting RGBA Int32Array to Uint8");
                for (let i = 0; i < data.length; i++) {
                    pixelArray[i] = data[i];
                }
                // Report the mean alpha so a fully transparent result is easy to spot.
                let alphaSum = 0;
                for (let i = 3; i < pixelArray.length; i += 4) {
                    alphaSum += pixelArray[i];
                }
                console.log(`Alpha channel average: ${alphaSum / (width * height)}`);
            }
            // Count non-zero entries in place instead of copying the whole
            // buffer into a regular array just for a diagnostic.
            let nonZeroPixels = 0;
            for (let i = 0; i < pixelArray.length; i++) {
                if (pixelArray[i] > 0) {
                    nonZeroPixels++;
                }
            }
            const total = pixelArray.length;
            console.log(`Non-zero pixels: ${nonZeroPixels} out of ${total} (${(nonZeroPixels / total * 100).toFixed(2)}%)`);
            console.log(`Creating ImageData object with dimensions ${width}x${height}`);
            return new ImageData(pixelArray, width, height);
        }
        catch (error) {
            console.error("Error in tensorToImageData:", error);
            throw error;
        }
        finally {
            // Release every intermediate tensor exactly once, whether the
            // conversion succeeded or threw.
            for (const owned of ownedTensors) {
                owned.dispose();
            }
        }
    });
}

/**
 * Creates a canvas element from an ImageData object.
 *
 * The canvas is created via `document.createElement` and sized to match the
 * ImageData before the pixels are written at the origin.
 *
 * @param imageData ImageData to put on canvas
 * @returns Canvas element containing the image
 * @throws {Error} If a 2D rendering context cannot be obtained
 */
export function imageDataToCanvas(imageData) {
    const canvas = document.createElement('canvas');
    const { width, height } = imageData;
    canvas.width = width;
    canvas.height = height;
    const ctx = canvas.getContext('2d');
    if (!ctx) {
        throw new Error('Could not get canvas context');
    }
    ctx.putImageData(imageData, 0, 0);
    return canvas;
}

/**
 * Converts a pixel array to a tensor.
 *
 * @param pixels Pixel data, laid out row by row with interleaved channels
 * @param width Image width
 * @param height Image height
 * @param channels Number of channels (1, 3, or 4); defaults to 4
 * @returns 3D int32 tensor of shape [height, width, channels]
 * @throws {Error} If `channels` is not 1, 3 or 4, or if `pixels.length` does
 *   not equal `width * height * channels`
 */
export function pixelArrayToTensor(pixels, width, height, channels = 4) {
    if (![1, 3, 4].includes(channels)) {
        throw new Error('Channels must be 1, 3, or 4');
    }
    const expectedLength = width * height * channels;
    if (pixels.length !== expectedLength) {
        throw new Error(`Expected array of length ${expectedLength} but got ${pixels.length}`);
    }
    return tf.tensor3d(Array.from(pixels), [height, width, channels], 'int32');
}

/**
 * Processes an HTML Image element and returns a canvas with the filtered result.
 *
 * The image is drawn onto a fresh canvas, its pixels are handed to
 * `processFunction`, and the returned ImageData is written back onto the same
 * canvas at the origin.
 *
 * The canvas is sized from the image's intrinsic dimensions
 * (`naturalWidth`/`naturalHeight`) when available, so an `<img>` that is
 * scaled by attributes or CSS is processed in full rather than cropped.
 * Sources without those properties fall back to `width`/`height`.
 *
 * @param image HTML Image element (should be fully loaded)
 * @param processFunction Function to process the ImageData; may return the
 *   result directly or as a promise
 * @returns Promise resolving to the canvas element with the processed image
 * @throws {Error} If a 2D rendering context cannot be obtained
 */
export function processHTMLImage(image, processFunction) {
    return __awaiter(this, void 0, void 0, function* () {
        // `width`/`height` on an <img> report the rendered size, which can
        // differ from the pixel data that drawImage() paints.
        const sourceWidth = image.naturalWidth || image.width;
        const sourceHeight = image.naturalHeight || image.height;
        const canvas = document.createElement('canvas');
        canvas.width = sourceWidth;
        canvas.height = sourceHeight;
        const ctx = canvas.getContext('2d');
        if (!ctx) {
            throw new Error('Could not get canvas context');
        }
        ctx.drawImage(image, 0, 0, sourceWidth, sourceHeight);
        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const resultData = yield processFunction(imageData);
        ctx.putImageData(resultData, 0, 0);
        return canvas;
    });
}