transformers-fork
State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
import {
ImageProcessor,
} from "../../base/image_processors_utils.js";
import { cat, full, interpolate_4d, stack } from "../../utils/tensor.js";
export class Idefics3ImageProcessor extends ImageProcessor {
constructor(config) {
super(config);
this.do_image_splitting = config.do_image_splitting ?? true;
this.max_image_size = config.max_image_size;
}
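// For reference, a typical Idefics3-style config provides these fields (illustrative
// values; the actual values come from a checkpoint's preprocessor_config.json —
// `longest_edge: 364` matches common Idefics3 checkpoints):
// {
//   "do_image_splitting": true,
//   "max_image_size": { "longest_edge": 364 }
// }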
/**
* @typedef {import('../../utils/image.js').RawImage} RawImage
* @typedef {import('../../utils/tensor.js').Tensor} Tensor
*/
/**
* Calculate the size to resize an image to, such that both dimensions are multiples of
* `vision_encoder_max_size`, approximately preserving the aspect ratio.
* @param {Tensor} pixel_values Tensor of the image to resize.
* @param {number} vision_encoder_max_size Maximum edge length of each patch. If the resized image
* is larger than this size, it will be split into patches of this size, and the original image,
* resized down to this size, will be concatenated with the patches.
* @returns {{height: number, width: number}} The target height and width.
*/
get_resize_for_vision_encoder(pixel_values, vision_encoder_max_size) {
let [height, width] = pixel_values.dims.slice(-2);
const aspect_ratio = width / height;
if (width >= height) {
width = Math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size;
height = Math.floor(width / aspect_ratio);
height = Math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size;
} else {
height = Math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size;
width = Math.floor(height * aspect_ratio);
width = Math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size;
}
return { height, width };
}
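// Worked example (illustrative numbers): with vision_encoder_max_size = 364 and a
// 300 (height) × 700 (width) input, width >= height, so:
//   width  = ceil(700 / 364) * 364 = 728
//   height = floor(728 / (700 / 300)) = 312  →  ceil(312 / 364) * 364 = 364
// giving { height: 364, width: 728 }, i.e. a 1 × 2 grid of 364 × 364 patches.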
/** @param {RawImage|RawImage[]|RawImage[][]} images */
async _call(images, {
do_image_splitting = null,
return_row_col_info = false,
} = {}) {
/** @type {RawImage[][]} */
let batched_2d_images;
if (!Array.isArray(images)) {
batched_2d_images = [[images]];
} else {
if (images.length === 0 || !images[0]) {
throw new Error("No images provided.");
}
if (!Array.isArray(images[0])) {
batched_2d_images = [/** @type {RawImage[]} */(images)];
} else {
batched_2d_images = /** @type {RawImage[][]} */(images);
}
}
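// At this point, batched_2d_images is always a RawImage[][]: one inner list of
// images per sample in the batch.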
// List of tensors, each with shape [patches, channels, height, width]
let all_pixel_values = [];
let images_list_rows = [];
let images_list_cols = [];
const original_sizes = [];
const reshaped_input_sizes = [];
for (const image_batch of batched_2d_images) {
let images_list = await Promise.all(image_batch.map(x => this.preprocess(x)));
// Original sizes of images
original_sizes.push(...images_list.map(x => x.original_size));
// Reshaped sizes of images, before padding or cropping
reshaped_input_sizes.push(...images_list.map(x => x.reshaped_input_size));
// Convert images to 4D tensors for easier processing
images_list.forEach(x => x.pixel_values.unsqueeze_(0));
const { longest_edge } = this.max_image_size;
/** @type {Tensor[]} */
let images_tensor;
if (do_image_splitting ?? this.do_image_splitting) {
let image_rows = new Array(images_list.length);
let image_cols = new Array(images_list.length);
// We first resize both height and width of each image up to the nearest multiple of max_image_size;
// because of the rounding, the aspect ratio is only approximately preserved
images_tensor = await Promise.all(images_list.map(async (x, i) => {
const new_size = this.get_resize_for_vision_encoder(x.pixel_values, longest_edge);
const resized = await interpolate_4d(x.pixel_values, {
size: [new_size.height, new_size.width],
});
const { frames, num_splits_h, num_splits_w } = await this.split_image(resized, this.max_image_size);
image_rows[i] = num_splits_h;
image_cols[i] = num_splits_w;
return cat(frames, 0);
}));
images_list_rows.push(image_rows);
images_list_cols.push(image_cols);
} else {
/** @type {[number, number]} */
const size = [longest_edge, longest_edge];
images_tensor = await Promise.all(
images_list.map(x => interpolate_4d(x.pixel_values, { size }))
);
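// A row/column count of 0 signals that the image was not split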
images_list_rows.push(new Array(images_list.length).fill(0));
images_list_cols.push(new Array(images_list.length).fill(0));
}
all_pixel_values.push(cat(images_tensor, 0));
}
const batch_size = all_pixel_values.length;
const [n, c, h, w] = all_pixel_values[0].dims;
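// All patches share the same [c, h, w] dims (each is longest_edge × longest_edge),
// so only the patch count can differ across batch items.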
// Stack pixel values
let pixel_values;
let pixel_attention_mask;
if (batch_size === 1) {
pixel_values = all_pixel_values[0].unsqueeze_(0);
pixel_attention_mask = full([batch_size, n, h, w], true);
} else {
// Add padding (if necessary) to images with fewer patches than the maximum number of patches
const max_num_patches = Math.max(...all_pixel_values.map(x => x.dims.at(0)));
pixel_attention_mask = full([batch_size, max_num_patches, h, w], true);
const pixel_attention_mask_data = pixel_attention_mask.data;
const pixel_attention_mask_stride = max_num_patches * h * w;
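// The mask data is a flat buffer with logical shape [batch, patches, h, w]; each batch
// item occupies pixel_attention_mask_stride elements. Positions covered by the
// zero-padded patches are cleared below.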
for (let i = 0; i < batch_size; ++i) {
const num_patches = all_pixel_values[i].dims[0];
if (num_patches < max_num_patches) {
all_pixel_values[i] = cat([
all_pixel_values[i],
full([max_num_patches - num_patches, c, h, w], 0),
], 0);
const start_offset = i * pixel_attention_mask_stride + num_patches * h * w;
const end_offset = (i + 1) * pixel_attention_mask_stride;
pixel_attention_mask_data.fill(false, start_offset, end_offset);
}
}
pixel_values = stack(all_pixel_values, 0);
}
return {
pixel_values,
pixel_attention_mask,
original_sizes,
reshaped_input_sizes,
...(
return_row_col_info
? { rows: images_list_rows, cols: images_list_cols }
: {}
),
};
}
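// _call returns pixel_values of shape [batch, max_patches, channels, height, width]
// and pixel_attention_mask of shape [batch, max_patches, height, width].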
async split_image(pixel_values, { longest_edge }) {
const max_height = longest_edge;
const max_width = longest_edge;
const frames = [];
const [height, width] = pixel_values.dims.slice(-2);
let num_splits_h = 0, num_splits_w = 0;
if (height > max_height || width > max_width) {
// Calculate the number of splits
num_splits_h = Math.ceil(height / max_height);
num_splits_w = Math.ceil(width / max_width);
// Calculate the optimal width and height for the sub-images
const optimal_height = Math.ceil(height / num_splits_h);
const optimal_width = Math.ceil(width / num_splits_w);
// Iterate through each row and column
for (let r = 0; r < num_splits_h; r++) {
for (let c = 0; c < num_splits_w; c++) {
// Calculate the starting point of the crop
const start_x = c * optimal_width;
const start_y = r * optimal_height;
// Calculate the ending point of the crop
const end_x = Math.min(start_x + optimal_width, width);
const end_y = Math.min(start_y + optimal_height, height);
// Crop the image
frames.push(pixel_values.slice(null, null, [start_y, end_y], [start_x, end_x]));
}
}
// Resize the global image to match max dimensions for memory efficiency
const global_image_height = max_height;
const global_image_width = max_width;
if (height !== global_image_height || width !== global_image_width) {
pixel_values = await interpolate_4d(pixel_values, {
size: [global_image_height, global_image_width],
})
}
}
frames.push(pixel_values);
return { frames, num_splits_h, num_splits_w };
}
}
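// Minimal usage sketch (assumptions: this fork exposes AutoProcessor and RawImage as in
// upstream transformers.js, and the package name and model id below are illustrative):
//
//   import { AutoProcessor, RawImage } from "@huggingface/transformers";
//
//   const processor = await AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3");
//   const image = await RawImage.read("https://example.com/cat.png");
//   const { pixel_values, pixel_attention_mask, rows, cols } =
//       await processor.image_processor(image, { return_row_col_info: true });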