Skip to content

Commit e7c9daf

Browse files
added textual inversion
1 parent 95a80f0 commit e7c9daf

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

optimum/intel/openvino/modeling_diffusion.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from ...exporters.openvino import main_export
6666
from .configuration import OVConfig, OVQuantizationMethod, OVWeightQuantizationConfig
6767
from .modeling_base import OVBaseModel
68+
from .loaders import OVTextualInversionLoaderMixin
6869
from .utils import (
6970
ONNX_WEIGHTS_NAME,
7071
OV_TO_PT_TYPE,
@@ -1010,7 +1011,7 @@ def to(self, *args, **kwargs):
10101011
self.encoder.to(*args, **kwargs)
10111012

10121013

1013-
class OVStableDiffusionPipeline(OVDiffusionPipeline, StableDiffusionPipeline):
1014+
class OVStableDiffusionPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionPipeline):
10141015
"""
10151016
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion#diffusers.StableDiffusionPipeline).
10161017
"""
@@ -1020,7 +1021,9 @@ class OVStableDiffusionPipeline(OVDiffusionPipeline, StableDiffusionPipeline):
10201021
auto_model_class = StableDiffusionPipeline
10211022

10221023

1023-
class OVStableDiffusionImg2ImgPipeline(OVDiffusionPipeline, StableDiffusionImg2ImgPipeline):
1024+
class OVStableDiffusionImg2ImgPipeline(
1025+
OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionImg2ImgPipeline
1026+
):
10241027
"""
10251028
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_img2img#diffusers.StableDiffusionImg2ImgPipeline).
10261029
"""
@@ -1030,7 +1033,9 @@ class OVStableDiffusionImg2ImgPipeline(OVDiffusionPipeline, StableDiffusionImg2I
10301033
auto_model_class = StableDiffusionImg2ImgPipeline
10311034

10321035

1033-
class OVStableDiffusionInpaintPipeline(OVDiffusionPipeline, StableDiffusionInpaintPipeline):
1036+
class OVStableDiffusionInpaintPipeline(
1037+
OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionInpaintPipeline
1038+
):
10341039
"""
10351040
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_inpaint#diffusers.StableDiffusionInpaintPipeline).
10361041
"""
@@ -1040,7 +1045,7 @@ class OVStableDiffusionInpaintPipeline(OVDiffusionPipeline, StableDiffusionInpai
10401045
auto_model_class = StableDiffusionInpaintPipeline
10411046

10421047

1043-
class OVStableDiffusionXLPipeline(OVDiffusionPipeline, StableDiffusionXLPipeline):
1048+
class OVStableDiffusionXLPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionXLPipeline):
10441049
"""
10451050
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline).
10461051
"""
@@ -1063,7 +1068,9 @@ def _get_add_time_ids(
10631068
return add_time_ids
10641069

10651070

1066-
class OVStableDiffusionXLImg2ImgPipeline(OVDiffusionPipeline, StableDiffusionXLImg2ImgPipeline):
1071+
class OVStableDiffusionXLImg2ImgPipeline(
1072+
OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionXLImg2ImgPipeline
1073+
):
10671074
"""
10681075
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline).
10691076
"""
@@ -1100,7 +1107,9 @@ def _get_add_time_ids(
11001107
return add_time_ids, add_neg_time_ids
11011108

11021109

1103-
class OVStableDiffusionXLInpaintPipeline(OVDiffusionPipeline, StableDiffusionXLInpaintPipeline):
1110+
class OVStableDiffusionXLInpaintPipeline(
1111+
OVDiffusionPipeline, OVTextualInversionLoaderMixin, StableDiffusionXLInpaintPipeline
1112+
):
11041113
"""
11051114
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline).
11061115
"""
@@ -1137,7 +1146,9 @@ def _get_add_time_ids(
11371146
return add_time_ids, add_neg_time_ids
11381147

11391148

1140-
class OVLatentConsistencyModelPipeline(OVDiffusionPipeline, LatentConsistencyModelPipeline):
1149+
class OVLatentConsistencyModelPipeline(
1150+
OVDiffusionPipeline, OVTextualInversionLoaderMixin, LatentConsistencyModelPipeline
1151+
):
11411152
"""
11421153
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline).
11431154
"""
@@ -1147,7 +1158,9 @@ class OVLatentConsistencyModelPipeline(OVDiffusionPipeline, LatentConsistencyMod
11471158
auto_model_class = LatentConsistencyModelPipeline
11481159

11491160

1150-
class OVLatentConsistencyModelImg2ImgPipeline(OVDiffusionPipeline, LatentConsistencyModelImg2ImgPipeline):
1161+
class OVLatentConsistencyModelImg2ImgPipeline(
1162+
OVDiffusionPipeline, OVTextualInversionLoaderMixin, LatentConsistencyModelImg2ImgPipeline
1163+
):
11511164
"""
11521165
OpenVINO-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency_img2img#diffusers.LatentConsistencyModelImg2ImgPipeline).
11531166
"""

0 commit comments

Comments
 (0)