From e06d47ad5e6401deddb922a37375654472ce3f8a Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 15 Oct 2024 15:54:06 +0200 Subject: [PATCH 1/2] Fix compatibility with diffusers < 0.25.0 --- optimum/onnxruntime/modeling_diffusion.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 87fcb68c7e..5657a12e9e 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -26,7 +26,6 @@ import numpy as np import torch from diffusers.configuration_utils import ConfigMixin -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from diffusers.pipelines import ( AutoPipelineForImage2Image, AutoPipelineForInpainting, @@ -66,6 +65,7 @@ from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .utils import ( ONNX_WEIGHTS_NAME, + check_if_diffusers_greater, get_provider_for_device, np_to_pt_generators, parse_device, @@ -73,6 +73,12 @@ ) +if check_if_diffusers_greater("0.25.0"): + from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +else: + from diffusers.models.vae import DiagonalGaussianDistribution + + logger = logging.getLogger(__name__) From 222cbd2142b3917de3752e82595e2567f01cf81a Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 15 Oct 2024 17:54:22 +0200 Subject: [PATCH 2/2] fix import --- optimum/onnxruntime/modeling_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 5657a12e9e..3899a7b36b 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -51,6 +51,7 @@ from transformers.modeling_outputs import ModelOutput import onnxruntime as ort +from optimum.utils import check_if_diffusers_greater from ..exporters.onnx import main_export from ..onnx.utils import _get_model_external_data_paths @@ -65,7 +66,6 @@ from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .utils import ( ONNX_WEIGHTS_NAME, - check_if_diffusers_greater, get_provider_for_device, np_to_pt_generators, parse_device,