Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added image-to-image task for ORT Pipeline #2031

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add ORTModelForImageToImage for image-to-image task SwinSR
h3110Fr13nd committed Sep 19, 2024
commit c41cdbfd4b94f8ac05654c933de129f2b109dc76
2 changes: 2 additions & 0 deletions optimum/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@
"ORTModelForSemanticSegmentation",
"ORTModelForSequenceClassification",
"ORTModelForTokenClassification",
"ORTModelForImageToImage",
],
"modeling_seq2seq": [
"ORTModelForSeq2SeqLM",
@@ -112,6 +113,7 @@
ORTModelForCustomTasks,
ORTModelForFeatureExtraction,
ORTModelForImageClassification,
ORTModelForImageToImage,
ORTModelForMaskedLM,
ORTModelForMultipleChoice,
ORTModelForQuestionAnswering,
74 changes: 74 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@
AutoModelForAudioXVector,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForImageToImage,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
@@ -47,6 +48,7 @@
BaseModelOutput,
CausalLMOutput,
ImageClassifierOutput,
ImageSuperResolutionOutput,
MaskedLMOutput,
ModelOutput,
MultipleChoiceModelOutput,
@@ -86,6 +88,7 @@

_TOKENIZER_FOR_DOC = "AutoTokenizer"
_FEATURE_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
_PROCESSOR_FOR_IMAGE = "AutoImageProcessor"
_PROCESSOR_FOR_DOC = "AutoProcessor"

ONNX_MODEL_END_DOCSTRING = r"""
@@ -2183,6 +2186,77 @@ def forward(
return TokenClassifierOutput(logits=logits)


IMAGE_TO_IMAGE_EXAMPLE = r"""
Example of image-to-image (Super Resolution):

```python
>>> from transformers import {processor_class}
>>> from optimum.onnxruntime import {model_class}
>>> from PIL import Image

>>> image = Image.open("path/to/image.jpg")

>>> image_processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")

>>> inputs = image_processor(images=image, return_tensors="pt")

>>> with torch.no_grad():
... logits = model(**inputs).logits
```
"""


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageToImage(ORTModel):
"""
ONNX Model for image-to-image tasks. This class officially supports pix2pix, cyclegan, wav2vec2, wav2vec2-conformer.
"""

auto_model_class = AutoModelForImageToImage

@add_start_docstrings_to_model_forward(
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+ IMAGE_TO_IMAGE_EXAMPLE.format(
processor_class=_PROCESSOR_FOR_IMAGE,
model_class="ORTModelForImgageToImage",
checkpoint="caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr",
)
)
def forward(
self,
pixel_values: Union[torch.Tensor, np.ndarray],
**kwargs,
):
use_torch = isinstance(pixel_values, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)
if self.device.type == "cuda" and self.use_io_binding:
input_shapes = pixel_values.shape
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
pixel_values,
ordered_input_names=self._ordered_input_names,
known_output_shapes={
"reconstruction": [
input_shapes[0],
input_shapes[1],
input_shapes[2] * self.config.upscale,
input_shapes[3] * self.config.upscale,
]
},
)
io_binding.synchronize_inputs()
self.model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"])
else:
model_inputs = {"pixel_values": pixel_values}
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
reconstruction = model_outputs["reconstruction"]
return ImageSuperResolutionOutput(reconstruction=reconstruction)


CUSTOM_TASKS_EXAMPLE = r"""
Example of custom tasks(e.g. a sentence transformers taking `pooler_output` as output):