@huggingface/transformers

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
/**
 * @file Helper module for image processing.
 *
 * These functions and classes are only used internally,
 * meaning an end-user shouldn't need to access anything here.
 *
 * @module utils/image
 */

import { isNullishDimension, saveBlob } from './core.js';
import { getFile } from './hub.js';
import { apis } from '../env.js';
import { Tensor } from './tensor.js';

// Will be empty (or not used) if running in browser or web-worker
import sharp from 'sharp';

let createCanvasFunction;
let ImageDataClass;
let loadImageFunction;
const IS_BROWSER_OR_WEBWORKER = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV;
if (IS_BROWSER_OR_WEBWORKER) {
    // Running in browser or web-worker
    createCanvasFunction = (/** @type {number} */ width, /** @type {number} */ height) => {
        if (!self.OffscreenCanvas) {
            throw new Error('OffscreenCanvas not supported by this browser.');
        }
        return new self.OffscreenCanvas(width, height)
    };
    loadImageFunction = self.createImageBitmap;
    ImageDataClass = self.ImageData;

} else if (sharp) {
    // Running in Node.js, electron, or other non-browser environment
    loadImageFunction = async (/**@type {sharp.Sharp}*/img) => {
        const metadata = await img.metadata();
        const rawChannels = metadata.channels;

        const { data, info } = await img.rotate().raw().toBuffer({ resolveWithObject: true });

        const newImage = new RawImage(new Uint8ClampedArray(data), info.width, info.height, info.channels);
        if (rawChannels !== undefined && rawChannels !== info.channels) {
            // Make sure the new image has the same number of channels as the input image.
            // This is necessary for grayscale images.
            newImage.convert(rawChannels);
        }
        return newImage;
    }

} else {
    throw new Error('Unable to load image processing library.');
}

// Defined here: https://github.com/python-pillow/Pillow/blob/a405e8406b83f8bfb8916e93971edc7407b8b1ff/src/libImaging/Imaging.h#L262-L268
const RESAMPLING_MAPPING = {
    0: 'nearest',
    1: 'lanczos',
    2: 'bilinear',
    3: 'bicubic',
    4: 'box',
    5: 'hamming',
}

/**
 * Mapping from file extensions to MIME types.
 */
const CONTENT_TYPE_MAP = new Map([
    ['png', 'image/png'],
    ['jpg', 'image/jpeg'],
    ['jpeg', 'image/jpeg'],
    ['gif', 'image/gif'],
]);
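// A minimal usage sketch of the `RawImage` class defined below (the output
// path is a placeholder; saving to a file assumes a Node.js-like environment
// with filesystem access):
//
//   const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');
//   const resized = await image.resize(224, 224); // bilinear by default (resample = 2)
//   await resized.save('football-match-224.png');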
export class RawImage {

    /**
     * Create a new `RawImage` object.
     * @param {Uint8ClampedArray|Uint8Array} data The pixel data.
     * @param {number} width The width of the image.
     * @param {number} height The height of the image.
     * @param {1|2|3|4} channels The number of channels.
     */
    constructor(data, width, height, channels) {
        this.data = data;
        this.width = width;
        this.height = height;
        this.channels = channels;
    }

    /**
     * Returns the size of the image (width, height).
     * @returns {[number, number]} The size of the image (width, height).
     */
    get size() {
        return [this.width, this.height];
    }

    /**
     * Helper method for reading an image from a variety of input types.
     * @param {RawImage|string|URL|Blob|HTMLCanvasElement|OffscreenCanvas} input
     * @returns The image object.
     *
     * **Example:** Read image from a URL.
     * ```javascript
     * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');
     * // RawImage {
     * //   "data": Uint8ClampedArray [ 25, 25, 25, 19, 19, 19, ... ],
     * //   "width": 800,
     * //   "height": 533,
     * //   "channels": 3
     * // }
     * ```
     */
    static async read(input) {
        if (input instanceof RawImage) {
            return input;
        } else if (typeof input === 'string' || input instanceof URL) {
            return await this.fromURL(input);
        } else if (input instanceof Blob) {
            return await this.fromBlob(input);
        } else if (
            (typeof HTMLCanvasElement !== "undefined" && input instanceof HTMLCanvasElement)
            || (typeof OffscreenCanvas !== "undefined" && input instanceof OffscreenCanvas)
        ) {
            return this.fromCanvas(input);
        } else {
            throw new Error(`Unsupported input type: ${typeof input}`);
        }
    }

    /**
     * Read an image from a canvas.
     * @param {HTMLCanvasElement|OffscreenCanvas} canvas The canvas to read the image from.
     * @returns {RawImage} The image object.
     */
    static fromCanvas(canvas) {
        if (!IS_BROWSER_OR_WEBWORKER) {
            throw new Error('fromCanvas() is only supported in browser environments.')
        }

        const ctx = canvas.getContext('2d');
        const data = ctx.getImageData(0, 0, canvas.width, canvas.height).data;
        return new RawImage(data, canvas.width, canvas.height, 4);
    }

    /**
     * Read an image from a URL or file path.
     * @param {string|URL} url The URL or file path to read the image from.
     * @returns {Promise<RawImage>} The image object.
     */
    static async fromURL(url) {
        const response = await getFile(url);
        if (response.status !== 200) {
            throw new Error(`Unable to read image from "${url}" (${response.status} ${response.statusText})`);
        }
        const blob = await response.blob();
        return this.fromBlob(blob);
    }

    /**
     * Helper method to create a new Image from a blob.
     * @param {Blob} blob The blob to read the image from.
     * @returns {Promise<RawImage>} The image object.
     */
    static async fromBlob(blob) {
        if (IS_BROWSER_OR_WEBWORKER) {
            // Running in environment with canvas
            const img = await loadImageFunction(blob);

            const ctx = createCanvasFunction(img.width, img.height).getContext('2d');

            // Draw image to context
            ctx.drawImage(img, 0, 0);

            return new this(ctx.getImageData(0, 0, img.width, img.height).data, img.width, img.height, 4);

        } else {
            // Use sharp.js to read (and possibly resize) the image.
            const img = sharp(await blob.arrayBuffer());
            return await loadImageFunction(img);
        }
    }

    /**
     * Helper method to create a new Image from a tensor
     * @param {Tensor} tensor
     */
    static fromTensor(tensor, channel_format = 'CHW') {
        if (tensor.dims.length !== 3) {
            throw new Error(`Tensor should have 3 dimensions, but has ${tensor.dims.length} dimensions.`);
        }

        if (channel_format === 'CHW') {
            tensor = tensor.transpose(1, 2, 0);
        } else if (channel_format === 'HWC') {
            // Do nothing
        } else {
            throw new Error(`Unsupported channel format: ${channel_format}`);
        }

        if (!(tensor.data instanceof Uint8ClampedArray || tensor.data instanceof Uint8Array)) {
            throw new Error(`Unsupported tensor type: ${tensor.type}`);
        }

        switch (tensor.dims[2]) {
            case 1:
            case 2:
            case 3:
            case 4:
                return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
            default:
                throw new Error(`Unsupported number of channels: ${tensor.dims[2]}`);
        }
    }
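    // Sketch of `fromTensor` usage: given a `uint8` tensor in CHW order (for
    // example, the output of an image-to-image model), convert it back into an
    // image. `tensor` here is a stand-in, not something this module provides:
    //
    //   const output = RawImage.fromTensor(tensor, 'CHW');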
    /**
     * Convert the image to grayscale format.
     * @returns {RawImage} `this` to support chaining.
     */
    grayscale() {
        if (this.channels === 1) {
            return this;
        }

        const newData = new Uint8ClampedArray(this.width * this.height * 1);
        switch (this.channels) {
            case 3: // rgb to grayscale
            case 4: // rgba to grayscale
                for (let i = 0, offset = 0; i < this.data.length; i += this.channels) {
                    const red = this.data[i];
                    const green = this.data[i + 1];
                    const blue = this.data[i + 2];

                    newData[offset++] = Math.round(0.2989 * red + 0.5870 * green + 0.1140 * blue);
                }
                break;
            default:
                throw new Error(`Conversion failed due to unsupported number of channels: ${this.channels}`);
        }
        return this._update(newData, this.width, this.height, 1);
    }

    /**
     * Convert the image to RGB format.
     * @returns {RawImage} `this` to support chaining.
     */
    rgb() {
        if (this.channels === 3) {
            return this;
        }

        const newData = new Uint8ClampedArray(this.width * this.height * 3);

        switch (this.channels) {
            case 1: // grayscale to rgb
                for (let i = 0, offset = 0; i < this.data.length; ++i) {
                    newData[offset++] = this.data[i];
                    newData[offset++] = this.data[i];
                    newData[offset++] = this.data[i];
                }
                break;
            case 4: // rgba to rgb
                for (let i = 0, offset = 0; i < this.data.length; i += 4) {
                    newData[offset++] = this.data[i];
                    newData[offset++] = this.data[i + 1];
                    newData[offset++] = this.data[i + 2];
                }
                break;
            default:
                throw new Error(`Conversion failed due to unsupported number of channels: ${this.channels}`);
        }
        return this._update(newData, this.width, this.height, 3);
    }

    /**
     * Convert the image to RGBA format.
     * @returns {RawImage} `this` to support chaining.
     */
    rgba() {
        if (this.channels === 4) {
            return this;
        }

        const newData = new Uint8ClampedArray(this.width * this.height * 4);

        switch (this.channels) {
            case 1: // grayscale to rgba
                for (let i = 0, offset = 0; i < this.data.length; ++i) {
                    newData[offset++] = this.data[i];
                    newData[offset++] = this.data[i];
                    newData[offset++] = this.data[i];
                    newData[offset++] = 255;
                }
                break;
            case 3: // rgb to rgba
                for (let i = 0, offset = 0; i < this.data.length; i += 3) {
                    newData[offset++] = this.data[i];
                    newData[offset++] = this.data[i + 1];
                    newData[offset++] = this.data[i + 2];
                    newData[offset++] = 255;
                }
                break;
            default:
                throw new Error(`Conversion failed due to unsupported number of channels: ${this.channels}`);
        }
        return this._update(newData, this.width, this.height, 4);
    }
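    // Sketch of the channel conversions above. Each operates in place and
    // returns `this`, so calls can be chained; the grayscale weights are the
    // ITU-R BT.601 luma coefficients (as in PIL's "L" mode):
    //
    //   image.rgba();      // e.g. RGB -> RGBA, alpha filled with 255
    //   image.grayscale(); // RGBA -> single channel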
    /**
     * Apply an alpha mask to the image. Operates in place.
     * @param {RawImage} mask The mask to apply. It should have a single channel.
     * @returns {RawImage} The masked image.
     * @throws {Error} If the mask is not the same size as the image.
     * @throws {Error} If the image does not have 4 channels.
     * @throws {Error} If the mask is not a single channel.
     */
    putAlpha(mask) {
        if (mask.width !== this.width || mask.height !== this.height) {
            throw new Error(`Expected mask size to be ${this.width}x${this.height}, but got ${mask.width}x${mask.height}`);
        }
        if (mask.channels !== 1) {
            throw new Error(`Expected mask to have 1 channel, but got ${mask.channels}`);
        }

        const this_data = this.data;
        const mask_data = mask.data;
        const num_pixels = this.width * this.height;
        if (this.channels === 3) {
            // Convert to RGBA and simultaneously apply mask to alpha channel
            const newData = new Uint8ClampedArray(num_pixels * 4);
            for (let i = 0, in_offset = 0, out_offset = 0; i < num_pixels; ++i) {
                newData[out_offset++] = this_data[in_offset++];
                newData[out_offset++] = this_data[in_offset++];
                newData[out_offset++] = this_data[in_offset++];
                newData[out_offset++] = mask_data[i];
            }
            return this._update(newData, this.width, this.height, 4);

        } else if (this.channels === 4) {
            // Apply mask to alpha channel in place
            for (let i = 0; i < num_pixels; ++i) {
                this_data[4 * i + 3] = mask_data[i];
            }
            return this;
        }
        throw new Error(`Expected image to have 3 or 4 channels, but got ${this.channels}`);
    }

    /**
     * Resize the image to the given dimensions. This method uses the canvas API to perform the resizing.
     * @param {number} width The width of the new image. `null` or `-1` will preserve the aspect ratio.
     * @param {number} height The height of the new image. `null` or `-1` will preserve the aspect ratio.
     * @param {Object} options Additional options for resizing.
     * @param {0|1|2|3|4|5|string} [options.resample] The resampling method to use.
     * @returns {Promise<RawImage>} `this` to support chaining.
     */
    async resize(width, height, {
        resample = 2,
    } = {}) {
        // Do nothing if the image already has the desired size
        if (this.width === width && this.height === height) {
            return this;
        }

        // Ensure resample method is a string
        let resampleMethod = RESAMPLING_MAPPING[resample] ?? resample;

        // Calculate width / height to maintain aspect ratio, in the event that
        // the user passed a null value in.
        // This allows users to pass in something like `resize(320, null)` to
        // resize to 320 width, but maintain aspect ratio.
        const nullish_width = isNullishDimension(width);
        const nullish_height = isNullishDimension(height);
        if (nullish_width && nullish_height) {
            return this;
        } else if (nullish_width) {
            width = (height / this.height) * this.width;
        } else if (nullish_height) {
            height = (width / this.width) * this.height;
        }

        if (IS_BROWSER_OR_WEBWORKER) {
            // TODO use `resample` in browser environment

            // Store number of channels before resizing
            const numChannels = this.channels;

            // Create canvas object for this image
            const canvas = this.toCanvas();

            // Actually perform resizing using the canvas API
            const ctx = createCanvasFunction(width, height).getContext('2d');

            // Draw image to context, resizing in the process
            ctx.drawImage(canvas, 0, 0, width, height);

            // Create image from the resized data
            const resizedImage = new RawImage(ctx.getImageData(0, 0, width, height).data, width, height, 4);

            // Convert back so that image has the same number of channels as before
            return resizedImage.convert(numChannels);

        } else {
            // Create sharp image from raw data, and resize
            let img = this.toSharp();

            switch (resampleMethod) {
                case 'box':
                case 'hamming':
                    if (resampleMethod === 'box' || resampleMethod === 'hamming') {
                        console.warn(`Resampling method ${resampleMethod} is not yet supported. Using bilinear instead.`);
                        resampleMethod = 'bilinear';
                    }

                case 'nearest':
                case 'bilinear':
                case 'bicubic':
                    // Perform resizing using affine transform.
                    // This matches how the python Pillow library does it.
                    img = img.affine([width / this.width, 0, 0, height / this.height], {
                        interpolator: resampleMethod
                    });
                    break;

                case 'lanczos':
                    // https://github.com/python-pillow/Pillow/discussions/5519
                    // https://github.com/lovell/sharp/blob/main/docs/api-resize.md
                    img = img.resize({
                        width, height,
                        fit: 'fill',
                        kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3
                    });
                    break;

                default:
                    throw new Error(`Resampling method ${resampleMethod} is not supported.`);
            }

            return await loadImageFunction(img);
        }
    }
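    // Sketch of `resize` with a nullish dimension, as described in the comments
    // above: the missing dimension is computed to preserve the aspect ratio:
    //
    //   const thumb = await image.resize(320, null); // 320px wide, height scaled to match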
    async pad([left, right, top, bottom]) {
        left = Math.max(left, 0);
        right = Math.max(right, 0);
        top = Math.max(top, 0);
        bottom = Math.max(bottom, 0);

        if (left === 0 && right === 0 && top === 0 && bottom === 0) {
            // No padding needed
            return this;
        }

        if (IS_BROWSER_OR_WEBWORKER) {
            // Store number of channels before padding
            const numChannels = this.channels;

            // Create canvas object for this image
            const canvas = this.toCanvas();

            const newWidth = this.width + left + right;
            const newHeight = this.height + top + bottom;

            // Create a new canvas of the desired size.
            const ctx = createCanvasFunction(newWidth, newHeight).getContext('2d');

            // Draw image to context, padding in the process
            ctx.drawImage(canvas,
                0, 0, this.width, this.height,
                left, top, this.width, this.height
            );

            // Create image from the padded data
            const paddedImage = new RawImage(
                ctx.getImageData(0, 0, newWidth, newHeight).data,
                newWidth, newHeight, 4
            );

            // Convert back so that image has the same number of channels as before
            return paddedImage.convert(numChannels);

        } else {
            const img = this.toSharp().extend({ left, right, top, bottom });
            return await loadImageFunction(img);
        }
    }

    async crop([x_min, y_min, x_max, y_max]) {
        // Ensure crop bounds are within the image
        x_min = Math.max(x_min, 0);
        y_min = Math.max(y_min, 0);
        x_max = Math.min(x_max, this.width - 1);
        y_max = Math.min(y_max, this.height - 1);

        // Do nothing if the crop is the entire image
        if (x_min === 0 && y_min === 0 && x_max === this.width - 1 && y_max === this.height - 1) {
            return this;
        }

        const crop_width = x_max - x_min + 1;
        const crop_height = y_max - y_min + 1;

        if (IS_BROWSER_OR_WEBWORKER) {
            // Store number of channels before resizing
            const numChannels = this.channels;

            // Create canvas object for this image
            const canvas = this.toCanvas();

            // Create a new canvas of the desired size. This is needed since if the
            // image is too small, we need to pad it with black pixels.
            const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d');

            // Draw image to context, cropping in the process
            ctx.drawImage(canvas,
                x_min, y_min, crop_width, crop_height,
                0, 0, crop_width, crop_height
            );

            // Create image from the resized data
            const resizedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4);

            // Convert back so that image has the same number of channels as before
            return resizedImage.convert(numChannels);

        } else {
            // Create sharp image from raw data
            const img = this.toSharp().extract({
                left: x_min,
                top: y_min,
                width: crop_width,
                height: crop_height,
            });

            return await loadImageFunction(img);
        }
    }
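    // Sketch of the two methods above. `pad` takes [left, right, top, bottom]
    // amounts, while `crop` takes inclusive [x_min, y_min, x_max, y_max] bounds:
    //
    //   const bordered = await image.pad([8, 8, 8, 8]);       // 8px border on all sides
    //   const region = await image.crop([100, 50, 299, 249]); // 200x200 region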
    async center_crop(crop_width, crop_height) {
        // If the image is already the desired size, return it
        if (this.width === crop_width && this.height === crop_height) {
            return this;
        }

        // Determine bounds of the image in the new canvas
        const width_offset = (this.width - crop_width) / 2;
        const height_offset = (this.height - crop_height) / 2;

        if (IS_BROWSER_OR_WEBWORKER) {
            // Store number of channels before resizing
            const numChannels = this.channels;

            // Create canvas object for this image
            const canvas = this.toCanvas();

            // Create a new canvas of the desired size. This is needed since if the
            // image is too small, we need to pad it with black pixels.
            const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d');

            let sourceX = 0;
            let sourceY = 0;
            let destX = 0;
            let destY = 0;

            if (width_offset >= 0) {
                sourceX = width_offset;
            } else {
                destX = -width_offset;
            }

            if (height_offset >= 0) {
                sourceY = height_offset;
            } else {
                destY = -height_offset;
            }

            // Draw image to context, cropping in the process
            ctx.drawImage(canvas,
                sourceX, sourceY, crop_width, crop_height,
                destX, destY, crop_width, crop_height
            );

            // Create image from the resized data
            const resizedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4);

            // Convert back so that image has the same number of channels as before
            return resizedImage.convert(numChannels);

        } else {
            // Create sharp image from raw data
            let img = this.toSharp();

            if (width_offset >= 0 && height_offset >= 0) {
                // Cropped image lies entirely within the original image
                img = img.extract({
                    left: Math.floor(width_offset),
                    top: Math.floor(height_offset),
                    width: crop_width,
                    height: crop_height,
                })
            } else if (width_offset <= 0 && height_offset <= 0) {
                // Cropped image lies entirely outside the original image,
                // so we add padding
                const top = Math.floor(-height_offset);
                const left = Math.floor(-width_offset);
                img = img.extend({
                    top: top,
                    left: left,

                    // Ensures the resulting image has the desired dimensions
                    right: crop_width - this.width - left,
                    bottom: crop_height - this.height - top,
                });
            } else {
                // Cropped image lies partially outside the original image.
                // We first pad, then crop.

                let y_padding = [0, 0];
                let y_extract = 0;
                if (height_offset < 0) {
                    y_padding[0] = Math.floor(-height_offset);
                    y_padding[1] = crop_height - this.height - y_padding[0];
                } else {
                    y_extract = Math.floor(height_offset);
                }

                let x_padding = [0, 0];
                let x_extract = 0;
                if (width_offset < 0) {
                    x_padding[0] = Math.floor(-width_offset);
                    x_padding[1] = crop_width - this.width - x_padding[0];
                } else {
                    x_extract = Math.floor(width_offset);
                }

                img = img.extend({
                    top: y_padding[0],
                    bottom: y_padding[1],
                    left: x_padding[0],
                    right: x_padding[1],
                }).extract({
                    left: x_extract,
                    top: y_extract,
                    width: crop_width,
                    height: crop_height,
                })
            }

            return await loadImageFunction(img);
        }
    }
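    // Sketch of `center_crop` in the common "resize then center-crop" image
    // preprocessing pipeline (224x224 is an assumed target size, not mandated
    // by this module):
    //
    //   const resized = await image.resize(256, 256);
    //   const cropped = await resized.center_crop(224, 224);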
    async toBlob(type = 'image/png', quality = 1) {
        if (!IS_BROWSER_OR_WEBWORKER) {
            throw new Error('toBlob() is only supported in browser environments.')
        }

        const canvas = this.toCanvas();
        return await canvas.convertToBlob({ type, quality });
    }

    toTensor(channel_format = 'CHW') {
        let tensor = new Tensor(
            'uint8',
            new Uint8Array(this.data),
            [this.height, this.width, this.channels]
        );

        if (channel_format === 'HWC') {
            // Do nothing
        } else if (channel_format === 'CHW') { // hwc -> chw
            tensor = tensor.permute(2, 0, 1);
        } else {
            throw new Error(`Unsupported channel format: ${channel_format}`);
        }

        return tensor;
    }

    toCanvas() {
        if (!IS_BROWSER_OR_WEBWORKER) {
            throw new Error('toCanvas() is only supported in browser environments.')
        }

        // Clone, and convert data to RGBA before drawing to canvas.
        // This is because the canvas API only supports RGBA
        const cloned = this.clone().rgba();

        // Create canvas object for the cloned image
        const clonedCanvas = createCanvasFunction(cloned.width, cloned.height);

        // Draw image to context
        const data = new ImageDataClass(cloned.data, cloned.width, cloned.height);
        clonedCanvas.getContext('2d').putImageData(data, 0, 0);

        return clonedCanvas;
    }

    /**
     * Split this image into individual bands. This method returns an array of individual image bands from an image.
     * For example, splitting an "RGB" image creates three new images each containing a copy of one of the original bands (red, green, blue).
     *
     * Inspired by PIL's `Image.split()` [function](https://pillow.readthedocs.io/en/latest/reference/Image.html#PIL.Image.Image.split).
     * @returns {RawImage[]} An array containing bands.
     */
    split() {
        const { data, width, height, channels } = this;

        /** @type {typeof Uint8Array | typeof Uint8ClampedArray} */
        const data_type = /** @type {any} */(data.constructor);
        const per_channel_length = data.length / channels;

        // Pre-allocate buffers for each channel
        const split_data = Array.from(
            { length: channels },
            () => new data_type(per_channel_length),
        );

        // Write pixel data
        for (let i = 0; i < per_channel_length; ++i) {
            const data_offset = channels * i;
            for (let j = 0; j < channels; ++j) {
                split_data[j][i] = data[data_offset + j];
            }
        }

        return split_data.map((data) => new RawImage(data, width, height, 1));
    }
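    // Sketch of `split`: extracting one band as a single-channel image, e.g. to
    // reuse the alpha channel as a mask for `putAlpha`:
    //
    //   const [r, g, b, a] = image.rgba().split();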
    /**
     * Helper method to update the image data.
     * @param {Uint8ClampedArray} data The new image data.
     * @param {number} width The new width of the image.
     * @param {number} height The new height of the image.
     * @param {1|2|3|4|null} [channels] The new number of channels of the image.
     * @private
     */
    _update(data, width, height, channels = null) {
        this.data = data;
        this.width = width;
        this.height = height;
        if (channels !== null) {
            this.channels = channels;
        }
        return this;
    }

    /**
     * Clone the image
     * @returns {RawImage} The cloned image
     */
    clone() {
        return new RawImage(this.data.slice(), this.width, this.height, this.channels);
    }

    /**
     * Helper method for converting image to have a certain number of channels
     * @param {number} numChannels The number of channels. Must be 1, 3, or 4.
     * @returns {RawImage} `this` to support chaining.
     */
    convert(numChannels) {
        if (this.channels === numChannels) return this; // Already correct number of channels

        switch (numChannels) {
            case 1:
                this.grayscale();
                break;
            case 3:
                this.rgb();
                break;
            case 4:
                this.rgba();
                break;
            default:
                throw new Error(`Conversion failed due to unsupported number of channels: ${numChannels}`);
        }
        return this;
    }

    /**
     * Save the image to the given path.
     * @param {string} path The path to save the image to.
     */
    async save(path) {
        if (IS_BROWSER_OR_WEBWORKER) {
            if (apis.IS_WEBWORKER_ENV) {
                throw new Error('Unable to save an image from a Web Worker.')
            }

            const extension = path.split('.').pop().toLowerCase();
            const mime = CONTENT_TYPE_MAP.get(extension) ?? 'image/png';

            // Convert image to Blob
            const blob = await this.toBlob(mime);

            saveBlob(path, blob)

        } else if (!apis.IS_FS_AVAILABLE) {
            throw new Error('Unable to save the image because filesystem is disabled in this environment.')
        } else {
            const img = this.toSharp();
            return await img.toFile(path);
        }
    }

    toSharp() {
        if (IS_BROWSER_OR_WEBWORKER) {
            throw new Error('toSharp() is only supported in server-side environments.')
        }

        return sharp(this.data, {
            raw: {
                width: this.width,
                height: this.height,
                channels: this.channels
            }
        });
    }
}

/**
 * Helper function to load an image from a URL, path, etc.
 */
export const load_image = RawImage.read.bind(RawImage);
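// Sketch of `load_image` usage; it accepts the same input types as
// `RawImage.read` (RawImage, URL/path string, Blob, canvas), and the import
// path assumes this module is re-exported from the package root:
//
//   import { load_image } from '@huggingface/transformers';
//   const image = await load_image('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');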