@@ -98,8 +98,6 @@ def __init__(
98
98
ov_config : Optional [Dict [str , str ]] = None ,
99
99
model_save_dir : Optional [Union [str , Path , TemporaryDirectory ]] = None ,
100
100
quantization_config : Optional [Union [OVWeightQuantizationConfig , Dict ]] = None ,
101
- vae_decoder_ov_config : Optional [Dict [str , str ]] = None ,
102
- vae_encoder_ov_config : Optional [Dict [str , str ]] = None ,
103
101
** kwargs ,
104
102
):
105
103
self ._internal_dict = config
@@ -118,23 +116,15 @@ def __init__(
118
116
else :
119
117
self ._model_save_dir = model_save_dir
120
118
121
- default_vae_ov_config = deepcopy (self .ov_config )
122
- if "GPU" in self ._device :
123
- default_vae_ov_config .update ({"INFERENCE_PRECISION_HINT" : "f32" })
124
-
125
- self .vae_decoder = OVModelVaeDecoder (vae_decoder , self , vae_decoder_ov_config or default_vae_ov_config )
119
+ self .vae_decoder = OVModelVaeDecoder (vae_decoder , self )
126
120
self .unet = OVModelUnet (unet , self )
127
121
self .text_encoder = OVModelTextEncoder (text_encoder , self ) if text_encoder is not None else None
128
122
self .text_encoder_2 = (
129
123
OVModelTextEncoder (text_encoder_2 , self , model_name = DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER )
130
124
if text_encoder_2 is not None
131
125
else None
132
126
)
133
- self .vae_encoder = (
134
- OVModelVaeEncoder (vae_encoder , self , vae_encoder_ov_config or default_vae_ov_config )
135
- if vae_encoder is not None
136
- else None
137
- )
127
+ self .vae_encoder = OVModelVaeEncoder (vae_encoder , self ) if vae_encoder is not None else None
138
128
139
129
if "block_out_channels" in self .vae_decoder .config :
140
130
self .vae_scale_factor = 2 ** (len (self .vae_decoder .config ["block_out_channels" ]) - 1 )
@@ -726,6 +716,11 @@ def __call__(self, latent_sample: np.ndarray):
726
716
outputs = self .request (inputs , share_inputs = True )
727
717
return list (outputs .values ())
728
718
719
def _compile(self):
    """Compile this model part, defaulting GPU inference precision to f32.

    When the target device string contains "GPU" and ``ov_config`` does not
    already carry an ``INFERENCE_PRECISION_HINT`` entry, the hint is set to
    ``"f32"`` before compilation is delegated to the parent class. A hint
    that is already present in ``ov_config`` is left untouched.
    """
    targets_gpu = "GPU" in self.device
    if targets_gpu:
        # setdefault only writes when the key is absent, so an explicitly
        # configured precision hint always wins over this default.
        self.ov_config.setdefault("INFERENCE_PRECISION_HINT", "f32")
    super()._compile()
723
+
729
724
730
725
class OVModelVaeEncoder (OVModelPart ):
731
726
def __init__ (
@@ -742,6 +737,11 @@ def __call__(self, sample: np.ndarray):
742
737
outputs = self .request (inputs , share_inputs = True )
743
738
return list (outputs .values ())
744
739
740
def _compile(self):
    """Compile this model part, defaulting GPU inference precision to f32.

    When the target device string contains "GPU" and ``ov_config`` does not
    already carry an ``INFERENCE_PRECISION_HINT`` entry, the hint is set to
    ``"f32"`` before compilation is delegated to the parent class. A hint
    that is already present in ``ov_config`` is left untouched.
    """
    targets_gpu = "GPU" in self.device
    if targets_gpu:
        # setdefault only writes when the key is absent, so an explicitly
        # configured precision hint always wins over this default.
        self.ov_config.setdefault("INFERENCE_PRECISION_HINT", "f32")
    super()._compile()
744
+
745
745
746
746
class OVStableDiffusionPipeline (OVStableDiffusionPipelineBase , StableDiffusionPipelineMixin ):
747
747
def __call__ (
0 commit comments