diff --git a/README.md b/README.md
index 84b8c5992..b5918d98f 100644
--- a/README.md
+++ b/README.md
@@ -328,6 +328,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
+1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
 1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet
index 608c46042..53aa4190e 100644
--- a/docs/snippets/6_supported-models.snippet
+++ b/docs/snippets/6_supported-models.snippet
@@ -64,6 +64,7 @@
 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
+1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
 1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 8a114bb5b..856f68e44 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -743,6 +743,15 @@
             'google/vit-base-patch16-224',
         ],
     },
+    'vitmatte': {
+        # Image matting
+        'image-matting': [
+            'hustvl/vitmatte-small-distinctions-646',
+            'hustvl/vitmatte-base-distinctions-646',
+            'hustvl/vitmatte-small-composition-1k',
+            'hustvl/vitmatte-base-composition-1k',
+        ],
+    },
     'wav2vec2': {
         # Feature extraction # NOTE: requires --task feature-extraction
         'feature-extraction': [
diff --git a/src/models.js b/src/models.js
index ae1f5814e..20b750e0b 100644
--- a/src/models.js
+++ b/src/models.js
@@ -3310,6 +3310,74 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
 }
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+export class VitMattePreTrainedModel extends PreTrainedModel { }
+
+/**
+ * ViTMatte framework leveraging a plain Vision Transformer backbone for image matting.
+ *
+ * **Example:** Perform image matting with a `VitMatteForImageMatting` model.
+ * ```javascript
+ * import { AutoProcessor, VitMatteForImageMatting, RawImage } from '@xenova/transformers';
+ *
+ * // Load processor and model
+ * const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');
+ * const model = await VitMatteForImageMatting.from_pretrained('Xenova/vitmatte-small-distinctions-646');
+ *
+ * // Load image and trimap
+ * const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png');
+ * const trimap = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png');
+ *
+ * // Prepare image + trimap for the model
+ * const inputs = await processor(image, trimap);
+ *
+ * // Predict alpha matte
+ * const { alphas } = await model(inputs);
+ * // Tensor {
+ * //   dims: [ 1, 1, 640, 960 ],
+ * //   type: 'float32',
+ * //   size: 614400,
+ * //   data: Float32Array(614400) [ 0.9894027709960938, 0.9970508813858032, ... ]
+ * // }
+ * ```
+ *
+ * You can visualize the alpha matte as follows:
+ * ```javascript
+ * import { Tensor, cat, RawImage } from '@xenova/transformers';
+ *
+ * // Visualize predicted alpha matte
+ * const imageTensor = new Tensor(
+ *     'uint8',
+ *     new Uint8Array(image.data),
+ *     [image.height, image.width, image.channels]
+ * ).transpose(2, 0, 1);
+ *
+ * // Convert float (0-1) alpha matte to uint8 (0-255)
+ * const alphaChannel = alphas
+ *     .squeeze(0)
+ *     .mul_(255)
+ *     .clamp_(0, 255)
+ *     .round_()
+ *     .to('uint8');
+ *
+ * // Concatenate original image with predicted alpha
+ * const imageData = cat([imageTensor, alphaChannel], 0);
+ *
+ * // Save output image
+ * const outputImage = RawImage.fromTensor(imageData);
+ * outputImage.save('output.png');
+ * ```
+ */
+export class VitMatteForImageMatting extends VitMattePreTrainedModel {
+    /**
+     * @param {any} model_inputs
+     */
+    async _call(model_inputs) {
+        return new ImageMattingOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
 //////////////////////////////////////////////////
 export class MobileViTPreTrainedModel extends PreTrainedModel { }
 export class MobileViTModel extends MobileViTPreTrainedModel { }
@@ -4687,7 +4755,9 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
 ]);
 
-
+const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
+    ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
+]);
 
 const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
     ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
@@ -4713,6 +4783,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
     [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -4918,6 +4989,10 @@ export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
 }
 
+export class AutoModelForImageMatting extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES];
+}
+
 export class AutoModelForImageToImage extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES];
 }
@@ -5037,3 +5112,14 @@ export class CausalLMOutputWithPast extends ModelOutput {
         this.past_key_values = past_key_values;
     }
 }
+
+export class ImageMattingOutput extends ModelOutput {
+    /**
+     * @param {Object} output The output of the model.
+     * @param {Tensor} output.alphas Estimated alpha values, of shape `(batch_size, num_channels, height, width)`.
+     */
+    constructor({ alphas }) {
+        super();
+        this.alphas = alphas;
+    }
+}
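The patch above registers `AutoModelForImageMatting` but only demonstrates the concrete `VitMatteForImageMatting` class. As a minimal sketch (not part of the patch, and assuming the new auto class is re-exported from the package entry point like the existing auto classes), the same docstring example can be routed through the auto class:

```javascript
import { AutoProcessor, AutoModelForImageMatting, RawImage } from '@xenova/transformers';

// 'vitmatte' resolves to VitMatteForImageMatting via MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES.
const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');
const model = await AutoModelForImageMatting.from_pretrained('Xenova/vitmatte-small-distinctions-646');

const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png');
const trimap = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png');

const inputs = await processor(image, trimap);
const { alphas } = await model(inputs); // ImageMattingOutput; alphas has dims [1, 1, height, width]
```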
diff --git a/src/processors.js b/src/processors.js
index 8750894a2..d38128ab4 100644
--- a/src/processors.js
+++ b/src/processors.js
@@ -36,7 +36,7 @@ import {
 } from './utils/maths.js';
 
-import { Tensor, transpose, cat, interpolate } from './utils/tensor.js';
+import { Tensor, transpose, cat, interpolate, stack } from './utils/tensor.js';
 
 import { RawImage } from './utils/image.js';
 
 import {
@@ -222,7 +222,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
         this.do_resize = this.config.do_resize;
         this.do_thumbnail = this.config.do_thumbnail;
         this.size = this.config.size;
-        this.size_divisor = this.config.size_divisor;
+        this.size_divisibility = this.config.size_divisibility ?? this.config.size_divisor;
 
         this.do_center_crop = this.config.do_center_crop;
         this.crop_size = this.config.crop_size;
@@ -232,7 +232,7 @@
         this.pad_size = this.config.pad_size;
         this.do_pad = this.config.do_pad;
 
-        if (this.do_pad && !this.pad_size && this.size.width !== undefined && this.size.height !== undefined) {
+        if (this.do_pad && !this.pad_size && this.size && this.size.width !== undefined && this.size.height !== undefined) {
             // Should pad, but no pad size specified
             // We infer the pad size from the resize size
             this.pad_size = this.size
@@ -407,10 +407,15 @@
      * Preprocesses the given image.
      *
      * @param {RawImage} image The image to preprocess.
+     * @param {Object} overrides The overrides for the preprocessing options.
      * @returns {Promise} The preprocessed image.
      */
-    async preprocess(image) {
-
+    async preprocess(image, {
+        do_normalize = null,
+        do_pad = null,
+        do_convert_rgb = null,
+        do_convert_grayscale = null,
+    } = {}) {
         if (this.do_crop_margin) {
             // NOTE: Specific to nougat processors. This is done before resizing,
             // and can be interpreted as a pre-preprocessing step.
@@ -421,8 +426,10 @@
         const srcHeight = image.height; // original height
 
         // Convert image to RGB if specified in config.
-        if (this.do_convert_rgb) {
+        if (do_convert_rgb ?? this.do_convert_rgb) {
             image = image.rgb();
+        } else if (do_convert_grayscale) {
+            image = image.grayscale();
         }
 
         // Resize all images
@@ -485,10 +492,10 @@
                 resample: this.resample,
             });
 
-        } else if (this.size_divisor !== undefined) {
-            // Rounds the height and width down to the closest multiple of size_divisor
-            const newWidth = Math.floor(srcWidth / this.size_divisor) * this.size_divisor;
-            const newHeight = Math.floor(srcHeight / this.size_divisor) * this.size_divisor;
+        } else if (this.size_divisibility !== undefined) {
+            // Rounds the height and width down to the closest multiple of size_divisibility
+            const newWidth = Math.floor(srcWidth / this.size_divisibility) * this.size_divisibility;
+            const newHeight = Math.floor(srcHeight / this.size_divisibility) * this.size_divisibility;
             image = await image.resize(newWidth, newHeight, {
                 resample: this.resample,
             });
@@ -530,7 +537,7 @@
             }
         }
 
-        if (this.do_normalize) {
+        if (do_normalize ?? this.do_normalize) {
             let image_mean = this.image_mean;
             if (!Array.isArray(this.image_mean)) {
                 image_mean = new Array(image.channels).fill(image_mean);
@@ -553,7 +560,7 @@
         }
 
         // do padding after rescaling/normalizing
-        if (this.do_pad && this.pad_size) {
+        if (do_pad ?? (this.do_pad && this.pad_size)) {
             const padded = this.pad_image(pixelData, [image.width, image.height, image.channels], this.pad_size);
             [pixelData, imgDims] = padded; // Update pixel data and image dimensions
         }
@@ -572,10 +579,10 @@
     }
 
     /**
-     * Calls the feature extraction process on an array of image
-     * URLs, preprocesses each image, and concatenates the resulting
+     * Calls the feature extraction process on an array of images,
+     * preprocesses each image, and concatenates the resulting
      * features into a single Tensor.
-     * @param {any[]} images The URL(s) of the image(s) to extract features from.
+     * @param {RawImage[]} images The image(s) to extract features from.
      * @param {...any} args Additional arguments.
     * @returns {Promise} An object containing the concatenated pixel values (and other metadata) of the preprocessed images.
      */
@@ -586,12 +593,8 @@
         /** @type {PreprocessedImage[]} */
         const imageData = await Promise.all(images.map(x => this.preprocess(x)));
 
-        // TODO:
-
-        // Concatenate pixel values
-        // TEMP: Add batch dimension so that concat works
-        imageData.forEach(x => x.pixel_values.dims = [1, ...x.pixel_values.dims]);
-        const pixel_values = cat(imageData.map(x => x.pixel_values));
+        // Stack pixel values
+        const pixel_values = stack(imageData.map(x => x.pixel_values), 0);
 
         return {
             pixel_values: pixel_values,
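The new `preprocess` overrides introduced above follow the `do_X ?? this.do_X` pattern, so a per-call override always wins and the processor config is only consulted when no override is given. A minimal illustrative sketch (not part of the patch; it calls the internal `preprocess` method directly on a processor's feature extractor, using the converted checkpoint and trimap fixture from the docstring example):

```javascript
import { AutoProcessor, RawImage } from '@xenova/transformers';

const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');
const trimap = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png');

// Preprocess a mask-like input: skip normalization and padding, and reduce it to a single channel.
// Each override takes precedence over the corresponding `do_*` flag in the processor config.
const { pixel_values } = await processor.feature_extractor.preprocess(trimap, {
    do_normalize: false,
    do_pad: false,
    do_convert_rgb: false,
    do_convert_grayscale: true,
});
// pixel_values is an unbatched [1, height, width] tensor (one channel after grayscale conversion).
```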
@@ -632,7 +635,7 @@ export class DonutFeatureExtractor extends ImageFeatureExtractor {
         }
 
         let image_std = this.image_std;
-        if (!Array.isArray(this.image_std)) {
+        if (!Array.isArray(image_std)) {
             image_std = new Array(imageChannels).fill(image_mean);
         }
 
@@ -663,13 +666,13 @@ } // NOTE exte
 */
 export class DetrFeatureExtractor extends ImageFeatureExtractor {
     /**
-     * Calls the feature extraction process on an array of image URLs, preprocesses
+     * Calls the feature extraction process on an array of images, preprocesses
      * each image, and concatenates the resulting features into a single Tensor.
-     * @param {any[]} urls The URL(s) of the image(s) to extract features from.
+     * @param {RawImage[]} images The image(s) to extract features from.
      * @returns {Promise} An object containing the concatenated pixel values of the preprocessed images.
      */
-    async _call(urls) {
-        const result = await super._call(urls);
+    async _call(images) {
+        const result = await super._call(images);
 
         // TODO support differently-sized images, for now assume all images are the same size.
         // TODO support different mask sizes (not just 64x64)
@@ -991,7 +994,7 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor {
 
 export class SamImageProcessor extends ImageFeatureExtractor {
     /**
-     * @param {any[]} images The URL(s) of the image(s) to extract features from.
+     * @param {RawImage[]} images The image(s) to extract features from.
      * @param {*} input_points A 3D or 4D array, representing the input points provided by the user.
      * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
      * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
@@ -1146,6 +1149,47 @@ export class Swin2SRImageProcessor extends ImageFeatureExtractor {
     }
 }
 
+export class VitMatteImageProcessor extends ImageFeatureExtractor {
+    /**
+     * Calls the feature extraction process on an array of images, preprocesses
+     * each image, and concatenates the resulting features into a single Tensor.
+     * @param {RawImage[]} images The image(s) to extract features from.
+     * @param {RawImage[]} trimaps The trimap(s) to extract features from.
+     * @returns {Promise} An object containing the concatenated pixel values of the preprocessed images.
+     */
+    async _call(images, trimaps) {
+        if (!Array.isArray(images)) {
+            images = [images];
+        }
+        if (!Array.isArray(trimaps)) {
+            trimaps = [trimaps];
+        }
+
+        const imageData = await Promise.all(images.map(x => this.preprocess(x)));
+        const trimapData = await Promise.all(trimaps.map(x => this.preprocess(x, {
+            do_normalize: false,
+            do_convert_rgb: false,
+            do_convert_grayscale: true,
+        })));
+
+
+        // Stack pixel values
+        const pixel_values = stack(imageData.map(
+            // Concatenate images and trimaps
+            (x, i) => cat([x.pixel_values, trimapData[i].pixel_values], 0)
+        ), 0);
+
+        return {
+            pixel_values: pixel_values,
+
+            // Original sizes of images
+            original_sizes: imageData.map(x => x.original_size),
+
+            // Reshaped sizes of images, before padding or cropping
+            reshaped_input_sizes: imageData.map(x => x.reshaped_input_size),
+        }
+    }
+}
 
 
 export class WhisperFeatureExtractor extends FeatureExtractor {
@@ -1549,7 +1593,7 @@ export class Processor extends Callable {
      * @returns {Promise} A Promise that resolves with the extracted features.
      */
     async _call(input, ...args) {
-        return await this.feature_extractor(input);
+        return await this.feature_extractor(input, ...args);
     }
 }
 
@@ -1663,6 +1707,7 @@ export class AutoProcessor {
         DonutFeatureExtractor,
         NougatImageProcessor,
+        VitMatteImageProcessor,
         SamImageProcessor,
         Swin2SRImageProcessor,
         Wav2Vec2FeatureExtractor,
diff --git a/tests/processors.test.js b/tests/processors.test.js
index 73d5b3e86..ed83f775d 100644
--- a/tests/processors.test.js
+++ b/tests/processors.test.js
@@ -43,6 +43,7 @@ describe('Processors', () => {
         nougat: 'facebook/nougat-small',
         owlvit: 'google/owlvit-base-patch32',
         clip: 'openai/clip-vit-base-patch16',
+        vitmatte: 'hustvl/vitmatte-small-distinctions-646',
         dinov2: 'facebook/dinov2-small-imagenet1k-1-layer',
     }
 
@@ -56,6 +57,9 @@
 
         // grayscale image
         skateboard: 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png',
+
+        vitmatte_image: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png',
+        vitmatte_trimap: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png',
     }
 
     // Swin2SRImageProcessor
@@ -338,6 +342,33 @@ describe('Processors', () => {
         }
     }, MAX_TEST_EXECUTION_TIME);
 
+    // VitMatteImageProcessor
+    //  - tests custom overrides
+    //  - tests multiple inputs
+    it(MODELS.vitmatte, async () => {
+        const processor = await AutoProcessor.from_pretrained(m(MODELS.vitmatte))
+
+        {
+            const image = await load_image(TEST_IMAGES.vitmatte_image);
+            const image2 = await load_image(TEST_IMAGES.vitmatte_trimap);
+            const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image, image2);
+
+            compare(pixel_values.dims, [1, 4, 640, 960]);
+            expect(avg(pixel_values.data)).toBeCloseTo(-0.4028555154800415);
+            expect(pixel_values.data[0]).toBeCloseTo(-0.9921568632125854);
+            expect(pixel_values.data[1]).toBeCloseTo(-0.9921568632125854);
+            expect(pixel_values.data[5]).toBeCloseTo(-1.0);
+            expect(pixel_values.data[640]).toBeCloseTo(-0.6784313917160034);
+            expect(pixel_values.data[641]).toBeCloseTo(-0.6705882549285889);
+            expect(pixel_values.data[640 * 960]).toBeCloseTo(-1.0);
+            expect(pixel_values.data[640 * 960 + 1]).toBeCloseTo(-1.0);
+            expect(pixel_values.data.at(-1)).toBeCloseTo(0.0);
+
+            compare(original_sizes, [[640, 960]]);
+            compare(reshaped_input_sizes, [[640, 960]]);
+        }
+    }, MAX_TEST_EXECUTION_TIME);
+
     // BitImageProcessor
     it(MODELS.dinov2, async () => {
         const processor = await AutoProcessor.from_pretrained(m(MODELS.dinov2))
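The test above covers a single image/trimap pair; `VitMatteImageProcessor._call` also accepts arrays. A minimal sketch of a batched call (illustrative only, reusing the converted checkpoint and fixtures from the docstring example; all pairs must share the same dimensions, since the per-image tensors are stacked):

```javascript
import { AutoProcessor, RawImage } from '@xenova/transformers';

const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');

const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png');
const trimap = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png');

// Each image is concatenated with its trimap along the channel dimension (3 + 1 = 4 channels),
// then the per-image tensors are stacked into a batch.
const { pixel_values } = await processor([image, image], [trimap, trimap]);
// pixel_values.dims === [2, 4, 640, 960]
```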