From c41cdbfd4b94f8ac05654c933de129f2b109dc76 Mon Sep 17 00:00:00 2001
From: h3110Fr13nd
Date: Fri, 20 Sep 2024 04:53:22 +0530
Subject: [PATCH 1/6] Add ORTModelForImageToImage for image-to-image task Swin2SR

---
 optimum/onnxruntime/__init__.py     |  2 +
 optimum/onnxruntime/modeling_ort.py | 74 +++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py
index 09a48ec955..1cb5b7c47b 100644
--- a/optimum/onnxruntime/__init__.py
+++ b/optimum/onnxruntime/__init__.py
@@ -44,6 +44,7 @@
         "ORTModelForSemanticSegmentation",
         "ORTModelForSequenceClassification",
         "ORTModelForTokenClassification",
+        "ORTModelForImageToImage",
     ],
     "modeling_seq2seq": [
         "ORTModelForSeq2SeqLM",
@@ -112,6 +113,7 @@
         ORTModelForCustomTasks,
         ORTModelForFeatureExtraction,
         ORTModelForImageClassification,
+        ORTModelForImageToImage,
         ORTModelForMaskedLM,
         ORTModelForMultipleChoice,
         ORTModelForQuestionAnswering,
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index 254b771e33..e39ec14006 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -34,6 +34,7 @@
     AutoModelForAudioXVector,
     AutoModelForCTC,
     AutoModelForImageClassification,
+    AutoModelForImageToImage,
     AutoModelForMaskedLM,
     AutoModelForMultipleChoice,
     AutoModelForQuestionAnswering,
@@ -47,6 +48,7 @@
     BaseModelOutput,
     CausalLMOutput,
     ImageClassifierOutput,
+    ImageSuperResolutionOutput,
     MaskedLMOutput,
     ModelOutput,
     MultipleChoiceModelOutput,
@@ -86,6 +88,7 @@

 _TOKENIZER_FOR_DOC = "AutoTokenizer"
 _FEATURE_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+_PROCESSOR_FOR_IMAGE = "AutoImageProcessor"
 _PROCESSOR_FOR_DOC = "AutoProcessor"

 ONNX_MODEL_END_DOCSTRING = r"""
@@ -2183,6 +2186,77 @@ def forward(
         return TokenClassifierOutput(logits=logits)


+IMAGE_TO_IMAGE_EXAMPLE = r"""
+    Example of image-to-image (Super Resolution):
+
+    ```python
+    >>> import torch
+    >>> from transformers import {processor_class}
+    >>> from optimum.onnxruntime import {model_class}
+    >>> from PIL import Image
+
+    >>> image = Image.open("path/to/image.jpg")
+
+    >>> image_processor = {processor_class}.from_pretrained("{checkpoint}")
+    >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+    >>> inputs = image_processor(images=image, return_tensors="pt")
+    >>> with torch.no_grad():
+    ...     reconstruction = model(**inputs).reconstruction
+    ```
+"""
+
+
+@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
+class ORTModelForImageToImage(ORTModel):
+    """
+    ONNX Model for image-to-image tasks. This class officially supports Swin2SR.
+    """
+
+    auto_model_class = AutoModelForImageToImage
+
+    @add_start_docstrings_to_model_forward(
+        ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+        + IMAGE_TO_IMAGE_EXAMPLE.format(
+            processor_class=_PROCESSOR_FOR_IMAGE,
+            model_class="ORTModelForImageToImage",
+            checkpoint="caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr",
+        )
+    )
+    def forward(
+        self,
+        pixel_values: Union[torch.Tensor, np.ndarray],
+        **kwargs,
+    ):
+        use_torch = isinstance(pixel_values, torch.Tensor)
+        self.raise_on_numpy_input_io_binding(use_torch)
+        if self.device.type == "cuda" and self.use_io_binding:
+            input_shapes = pixel_values.shape
+            io_binding, output_shapes, output_buffers = self.prepare_io_binding(
+                pixel_values,
+                ordered_input_names=self._ordered_input_names,
+                known_output_shapes={
+                    "reconstruction": [
+                        input_shapes[0],
+                        input_shapes[1],
+                        input_shapes[2] * self.config.upscale,
+                        input_shapes[3] * self.config.upscale,
+                    ]
+                },
+            )
+            io_binding.synchronize_inputs()
+            self.model.run_with_iobinding(io_binding)
+            io_binding.synchronize_outputs()
+            reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"])
+        else:
+            model_inputs = {"pixel_values": pixel_values}
+            onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+            onnx_outputs = self.model.run(None, onnx_inputs)
+            model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+            reconstruction = model_outputs["reconstruction"]
+        return ImageSuperResolutionOutput(reconstruction=reconstruction)
+
+
 CUSTOM_TASKS_EXAMPLE = r"""
     Example of custom tasks(e.g. a sentence transformers taking `pooler_output` as output):

From 30eb97e7039b4ff8a0aa5fb792eef4147b1cec0a Mon Sep 17 00:00:00 2001
From: h3110Fr13nd
Date: Fri, 20 Sep 2024 04:57:03 +0530
Subject: [PATCH 2/6] Added image-to-image task to optimum pipeline

---
 optimum/onnxruntime/runs/__init__.py | 6 +++---
 optimum/pipelines/pipelines_base.py  | 8 ++++++++
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py
index 1d98294934..d21db2a4ac 100644
--- a/optimum/onnxruntime/runs/__init__.py
+++ b/optimum/onnxruntime/runs/__init__.py
@@ -110,9 +110,9 @@ def __init__(self, run_config):
         model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task))
         self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"])

-        self.return_body[
-            "model_type"
-        ] = self.torch_model.config.model_type  # return_body is initialized in parent class
+        self.return_body["model_type"] = (
+            self.torch_model.config.model_type
+        )  # return_body is initialized in parent class

     def _launch_time(self, trial):
         batch_size = trial.suggest_categorical("batch_size", self.batch_sizes)
diff --git a/optimum/pipelines/pipelines_base.py b/optimum/pipelines/pipelines_base.py
index a08ab8782a..4b4ef310d7 100644
--- a/optimum/pipelines/pipelines_base.py
+++ b/optimum/pipelines/pipelines_base.py
@@ -24,6 +24,7 @@
     FillMaskPipeline,
     ImageClassificationPipeline,
     ImageSegmentationPipeline,
+    ImageToImagePipeline,
     ImageToTextPipeline,
     Pipeline,
     PreTrainedTokenizer,
@@ -55,6 +56,7 @@
     ORTModelForCausalLM,
     ORTModelForFeatureExtraction,
     ORTModelForImageClassification,
+    ORTModelForImageToImage,
     ORTModelForMaskedLM,
     ORTModelForQuestionAnswering,
     ORTModelForSemanticSegmentation,
@@ -157,6 +159,12 @@
             "default": "superb/hubert-base-superb-ks",
             "type": "audio",
         },
+        "image-to-image": {
+            "impl": ImageToImagePipeline,
+            "class": (ORTModelForImageToImage,),
+            "default": "h3110Fr13nd/swin2sr-lightweight-2x-onnx",
"h3110Fr13nd/swin2sr-lightweight-2x-onnx", + "type": "image", + }, } else: ORT_SUPPORTED_TASKS = {} From d762ea244220526abcf471f4a4780e59ceb1357e Mon Sep 17 00:00:00 2001 From: h3110Fr13nd Date: Tue, 24 Sep 2024 03:42:11 +0530 Subject: [PATCH 3/6] Add Tests fpr ORTModelForImageToImage for image-to-image task SwinSR --- optimum/onnxruntime/__init__.py | 2 +- optimum/onnxruntime/runs/__init__.py | 6 +- tests/onnxruntime/test_modeling.py | 136 ++++++++++++++++++- tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 4 files changed, 140 insertions(+), 5 deletions(-) diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 1cb5b7c47b..471fb65e05 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -15,7 +15,7 @@ from transformers.utils import OptionalDependencyNotAvailable, _LazyModule -from ..utils import is_diffusers_available +from optimum.utils import is_diffusers_available _import_structure = { diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index d21db2a4ac..1d98294934 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body["model_type"] = ( - self.torch_model.config.model_type - ) # return_body is initialized in parent class + self.return_body[ + "model_type" + ] = self.torch_model.config.model_type # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 199b96342e..2518adfe1e 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -42,6 +42,7 @@ AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageClassification, + AutoModelForImageToImage, AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForQuestionAnswering, @@ -57,7 +58,9 @@ PretrainedConfig, set_seed, ) +from transformers.modeling_outputs import ImageSuperResolutionOutput from transformers.modeling_utils import no_init_weights +from transformers.models.swin2sr.configuration_swin2sr import Swin2SRConfig from transformers.onnx.utils import get_preprocessor from transformers.testing_utils import get_gpu_count, require_torch_gpu, slow from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin @@ -79,6 +82,7 @@ ORTModelForCustomTasks, ORTModelForFeatureExtraction, ORTModelForImageClassification, + ORTModelForImageToImage, ORTModelForMaskedLM, ORTModelForMultipleChoice, ORTModelForPix2Struct, @@ -4704,6 +4708,136 @@ def test_compare_generation_to_io_binding( gc.collect() +class ORTModelForImageToImageIntegrationTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["swin2sr"] + + ORTMODEL_CLASS = ORTModelForImageToImage + + TASK = "image-to-image" + + def _get_sample_image(self): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + return image + + def _get_preprocessors(self, model_id): + image_processor = AutoImageProcessor.from_pretrained(model_id) + + return image_processor + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = ORTModelForImageToImage.from_pretrained(MODEL_NAMES["bert"], 
+
+        self.assertIn("only supports the tasks", str(context.exception))
+
+    @parameterized.expand(
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+    )
+    def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        self.assertIsInstance(onnx_model.config, Swin2SRConfig)
+        set_seed(SEED)
+
+        model_id_transformer = "caidas/swin2SR-lightweight-x2-64"
+        transformers_model = AutoModelForImageToImage.from_pretrained(model_id_transformer)
+        image_processor = self._get_preprocessors(model_id_transformer)
+
+        data = self._get_sample_image()
+        features = image_processor(data, return_tensors="pt")
+
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**features)
+
+        onnx_outputs = onnx_model(**features)
+        self.assertIsInstance(onnx_outputs, ImageSuperResolutionOutput)
+        self.assertTrue("reconstruction" in onnx_outputs)
+        self.assertIsInstance(onnx_outputs.reconstruction, torch.Tensor)
+        self.assertTrue(torch.allclose(onnx_outputs.reconstruction, transformers_outputs.reconstruction, atol=1e-4))
+
+        gc.collect()
+
+    @parameterized.expand(
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+    )
+    def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str, use_merged: str):
+        model_id = MODEL_NAMES[model_arch]
+        model = ORTModelForImageToImage.from_pretrained(model_id)
+        image_processor = self._get_preprocessors(model_id)
+
+        data = self._get_sample_image()
+        features = image_processor(data, return_tensors="pt")
+
+        outputs = model(**features)
+        self.assertIsInstance(outputs, ImageSuperResolutionOutput)
+
+        gc.collect()
+
+    @parameterized.expand(
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+    )
+    def test_pipeline_image_to_image(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        image_processor = self._get_preprocessors(model_id)
+
+        pipe = pipeline(
+            "image-to-image",
+            model=onnx_model,
+            feature_extractor=image_processor,
+        )
+        data = self._get_sample_image()
+        outputs = pipe(data)
+        self.assertEqual(pipe.device, onnx_model.device)
+        self.assertIsInstance(outputs, Image.Image)
+
+        gc.collect()
+
+    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False]}))
+    @require_torch_gpu
+    @pytest.mark.cuda_ep_test
+    def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool):
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        image_processor = self._get_preprocessors(model_id)
+        pipe = pipeline(
+            "image-to-image",
+            model=onnx_model,
+            feature_extractor=image_processor,
+            device=0,
+        )
+
+        data = self._get_sample_image()
+        outputs = pipe(data)
+
+        self.assertEqual(pipe.model.device.type.lower(), "cuda")
+        self.assertIsInstance(outputs, Image.Image)
+
+    @parameterized.expand(
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+    )
+    @require_torch_gpu
+    @require_ort_rocm
+    @pytest.mark.rocm_ep_test
+    def test_pipeline_on_rocm(self, test_name: str, model_arch: str, use_cache: bool):
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        image_processor = self._get_preprocessors(model_id)
+        pipe = pipeline(
+            "image-to-image",
+            model=onnx_model,
+            feature_extractor=image_processor,
+            device=0,
+        )
+
+        data = self._get_sample_image()
+        outputs = pipe(data)
+
+        self.assertEqual(pipe.model.device.type.lower(), "cuda")
+        self.assertIsInstance(outputs, Image.Image)
+
+
 class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin):
     SUPPORTED_ARCHITECTURES = ["vision-encoder-decoder", "trocr", "donut"]

@@ -4831,7 +4965,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
                 len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
             )
             for i in range(len(onnx_outputs["past_key_values"])):
-                print(onnx_outputs["past_key_values"][i])
                 for ort_pkv, trfs_pkv in zip(
                     onnx_outputs["past_key_values"][i], transformers_outputs["past_key_values"][i]
                 ):
@@ -5517,6 +5650,7 @@ class TestBothExportersORTModel(unittest.TestCase):
             ["automatic-speech-recognition", ORTModelForCTCIntegrationTest],
             ["audio-xvector", ORTModelForAudioXVectorIntegrationTest],
             ["audio-frame-classification", ORTModelForAudioFrameClassificationIntegrationTest],
+            ["image-to-image", ORTModelForImageToImageIntegrationTest],
         ]
     )
     def test_find_untested_architectures(self, task: str, test_class):
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index bb6935461d..f64f1e6c06 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -144,6 +144,7 @@
     "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
     "swin": "hf-internal-testing/tiny-random-SwinModel",
     "swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",
+    "swin2sr": "h3110Fr13nd/swin2sr-lightweight-2x-onnx",
     "t5": "hf-internal-testing/tiny-random-t5",
     "table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
     "trocr": "microsoft/trocr-small-handwritten",

From 074dbca08f58e471de86c512c163f08a661f043c Mon Sep 17 00:00:00 2001
From: h3110Fr13nd
Date: Tue, 24 Sep 2024 12:31:26 +0530
Subject: [PATCH 4/6] Use export=True for models from transformers, self._setup and more

---
 optimum/onnxruntime/__init__.py              |  2 +-
 optimum/onnxruntime/modeling_ort.py          |  2 +-
 tests/onnxruntime/test_modeling.py           | 41 +++++++++++---------
 tests/onnxruntime/utils_onnxruntime_tests.py |  2 +-
 4 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py
index 471fb65e05..1cb5b7c47b 100644
--- a/optimum/onnxruntime/__init__.py
+++ b/optimum/onnxruntime/__init__.py
@@ -15,7 +15,7 @@

 from transformers.utils import OptionalDependencyNotAvailable, _LazyModule

-from optimum.utils import is_diffusers_available
+from ..utils import is_diffusers_available


 _import_structure = {
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index e39ec14006..dd152aca9e 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -2218,7 +2218,7 @@ class ORTModelForImageToImage(ORTModel):
     @add_start_docstrings_to_model_forward(
         ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
         + IMAGE_TO_IMAGE_EXAMPLE.format(
-            processor_class=_PROCESSOR_FOR_IMAGE,
+            processor_class=_PROCESSOR_FOR_DOC,
             model_class="ORTModelForImageToImage",
             checkpoint="caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr",
         )
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 2518adfe1e..f4f05a42f3 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -4732,17 +4732,17 @@ def test_load_vanilla_transformers_which_is_not_supported(self):

         self.assertIn("only supports the tasks", str(context.exception))

     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
     )
-    def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
+    def test_compare_to_transformers(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
-        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        self._setup({"test_name": test_name, "model_arch": model_arch})
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True)
         self.assertIsInstance(onnx_model.config, Swin2SRConfig)
         set_seed(SEED)

-        model_id_transformer = "caidas/swin2SR-lightweight-x2-64"
-        transformers_model = AutoModelForImageToImage.from_pretrained(model_id_transformer)
-        image_processor = self._get_preprocessors(model_id_transformer)
+        transformers_model = AutoModelForImageToImage.from_pretrained(model_id)
+        image_processor = self._get_preprocessors(model_id)

         data = self._get_sample_image()
         features = image_processor(data, return_tensors="pt")
@@ -4759,11 +4759,12 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
         gc.collect()

     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
     )
-    def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str, use_merged: str):
+    def test_generate_utils(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
-        model = ORTModelForImageToImage.from_pretrained(model_id)
+        self._setup({"test_name": test_name, "model_arch": model_arch})
+        model = ORTModelForImageToImage.from_pretrained(model_id, export=True)
         image_processor = self._get_preprocessors(model_id)

         data = self._get_sample_image()
@@ -4775,13 +4776,13 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str, u
         gc.collect()

     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
     )
-    def test_pipeline_image_to_image(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
+    def test_pipeline_image_to_image(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
-        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        self._setup({"test_name": test_name, "model_arch": model_arch})
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True)
         image_processor = self._get_preprocessors(model_id)
-
         pipe = pipeline(
             "image-to-image",
             model=onnx_model,
@@ -4794,12 +4795,13 @@ def test_pipeline_image_to_image(self, test_name: str, model_arch: str, use_cach

         gc.collect()

-    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False]}))
+    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES}))
     @require_torch_gpu
     @pytest.mark.cuda_ep_test
-    def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool):
+    def test_pipeline_on_gpu(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
-        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        self._setup({"test_name": test_name, "model_arch": model_arch})
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True)
         image_processor = self._get_preprocessors(model_id)
         pipe = pipeline(
             "image-to-image",
@@ -4815,14 +4817,15 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool)
         self.assertIsInstance(outputs, Image.Image)

     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False], "use_merged": [False]})
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
     )
     @require_torch_gpu
     @require_ort_rocm
     @pytest.mark.rocm_ep_test
-    def test_pipeline_on_rocm(self, test_name: str, model_arch: str, use_cache: bool):
+    def test_pipeline_on_rocm(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
-        onnx_model = ORTModelForImageToImage.from_pretrained(model_id)
+        self._setup({"test_name": test_name, "model_arch": model_arch})
+        onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True)
         image_processor = self._get_preprocessors(model_id)
         pipe = pipeline(
             "image-to-image",
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index f64f1e6c06..0790f6329d 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -144,7 +144,7 @@
     "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
     "swin": "hf-internal-testing/tiny-random-SwinModel",
     "swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",
-    "swin2sr": "h3110Fr13nd/swin2sr-lightweight-2x-onnx",
+    "swin2sr": "hf-internal-testing/tiny-random-Swin2SRForImageSuperResolution",
     "t5": "hf-internal-testing/tiny-random-t5",
     "table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
     "trocr": "microsoft/trocr-small-handwritten",

From b1b3e905e1896418785ebcd6d5b88caf53dec3e8 Mon Sep 17 00:00:00 2001
From: h3110Fr13nd
Date: Tue, 24 Sep 2024 12:36:11 +0530
Subject: [PATCH 5/6] Code Refactor

---
 optimum/onnxruntime/modeling_ort.py |  1 -
 tests/onnxruntime/test_modeling.py  | 16 ++++------------
 2 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index dd152aca9e..9166f7c2cb 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -88,7 +88,6 @@

 _TOKENIZER_FOR_DOC = "AutoTokenizer"
 _FEATURE_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
-_PROCESSOR_FOR_IMAGE = "AutoImageProcessor"
 _PROCESSOR_FOR_DOC = "AutoProcessor"

 ONNX_MODEL_END_DOCSTRING = r"""
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index f4f05a42f3..f6765c600e 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -4731,9 +4731,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):

         self.assertIn("only supports the tasks", str(context.exception))

-    @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
-    )
+    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES}))
     def test_compare_to_transformers(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         self._setup({"test_name": test_name, "model_arch": model_arch})
@@ -4758,9 +4756,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str):

         gc.collect()

-    @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
-    )
+    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES}))
     def test_generate_utils(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         self._setup({"test_name": test_name, "model_arch": model_arch})
@@ -4775,9 +4771,7 @@ def test_generate_utils(self, test_name: str, model_arch: str):

         gc.collect()

-    @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
-    )
+    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES}))
     def test_pipeline_image_to_image(self, test_name: str, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         self._setup({"test_name": test_name, "model_arch": model_arch})
@@ -4816,9 +4810,7 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str):
         self.assertEqual(pipe.model.device.type.lower(), "cuda")
         self.assertIsInstance(outputs, Image.Image)

-    @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})
-    )
+    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES}))
     @require_torch_gpu
     @require_ort_rocm
     @pytest.mark.rocm_ep_test

From 31d74639c1a6e25562f2cda95542b0785b98caf6 Mon Sep 17 00:00:00 2001
From: h3110Fr13nd
Date: Tue, 24 Sep 2024 13:21:34 +0530
Subject: [PATCH 6/6] Refactor ORTModelForImageToImageIntegrationTest

---
 optimum/pipelines/pipelines_base.py |  2 +-
 tests/onnxruntime/test_modeling.py  | 47 ++++++++++++++++-------------
 2 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/optimum/pipelines/pipelines_base.py b/optimum/pipelines/pipelines_base.py
index 4b4ef310d7..7690143f13 100644
--- a/optimum/pipelines/pipelines_base.py
+++ b/optimum/pipelines/pipelines_base.py
@@ -162,7 +162,7 @@
         "image-to-image": {
             "impl": ImageToImagePipeline,
             "class": (ORTModelForImageToImage,),
-            "default": "h3110Fr13nd/swin2sr-lightweight-2x-onnx",
+            "default": "caidas/swin2SR-classical-sr-x2-64",
             "type": "image",
         },
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index f6765c600e..f6771ce761 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -4731,11 +4731,12 @@ def test_load_vanilla_transformers_which_is_not_supported(self):

         self.assertIn("only supports the tasks", str(context.exception))

-    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES}))
-    def test_compare_to_transformers(self, test_name: str, model_arch: str):
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_compare_to_transformers(self, model_arch: str):
+        model_args = {"test_name": model_arch, "model_arch": model_arch}
+        self._setup(model_args)
         model_id = MODEL_NAMES[model_arch]
-        self._setup({"test_name": test_name, "model_arch": model_arch})
-        onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True)
+        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
         self.assertIsInstance(onnx_model.config, Swin2SRConfig)
         set_seed(SEED)

@@ -4756,26 +4757,28 @@ def test_compare_to_transformers(self, model_arch: str):

         gc.collect()

-    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES}))
-    def test_generate_utils(self, test_name: str, model_arch: str):
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_generate_utils(self, model_arch: str):
+        model_args = {"test_name": model_arch, "model_arch": model_arch}
+        self._setup(model_args)
         model_id = MODEL_NAMES[model_arch]
-        self._setup({"test_name": test_name, "model_arch": model_arch})
"model_arch": model_arch}) - model = ORTModelForImageToImage.from_pretrained(model_id, export=True) + onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch]) image_processor = self._get_preprocessors(model_id) data = self._get_sample_image() features = image_processor(data, return_tensors="pt") - outputs = model(**features) + outputs = onnx_model(**features) self.assertIsInstance(outputs, ImageSuperResolutionOutput) gc.collect() - @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})) - def test_pipeline_image_to_image(self, test_name: str, model_arch: str): + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_image_to_image(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) model_id = MODEL_NAMES[model_arch] - self._setup({"test_name": test_name, "model_arch": model_arch}) - onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True) + onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch]) image_processor = self._get_preprocessors(model_id) pipe = pipeline( "image-to-image", @@ -4789,13 +4792,14 @@ def test_pipeline_image_to_image(self, test_name: str, model_arch: str): gc.collect() - @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})) + @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_torch_gpu @pytest.mark.cuda_ep_test - def test_pipeline_on_gpu(self, test_name: str, model_arch: str): + def test_pipeline_on_gpu(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) model_id = MODEL_NAMES[model_arch] - self._setup({"test_name": test_name, "model_arch": model_arch}) - onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True) + onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch]) image_processor = self._get_preprocessors(model_id) pipe = pipeline( "image-to-image", @@ -4810,14 +4814,15 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str): self.assertEqual(pipe.model.device.type.lower(), "cuda") self.assertIsInstance(outputs, Image.Image) - @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES})) + @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_torch_gpu @require_ort_rocm @pytest.mark.rocm_ep_test - def test_pipeline_on_rocm(self, test_name: str, model_arch: str): + def test_pipeline_on_rocm(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) model_id = MODEL_NAMES[model_arch] - self._setup({"test_name": test_name, "model_arch": model_arch}) - onnx_model = ORTModelForImageToImage.from_pretrained(model_id, export=True) + onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch]) image_processor = self._get_preprocessors(model_id) pipe = pipeline( "image-to-image",