Skip to content

Commit 9d9b574

Browse files
Simplify usage
1 parent 3400e81 commit 9d9b574

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

optimum/intel/openvino/modeling_diffusion.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,6 @@ def __init__(
9898
ov_config: Optional[Dict[str, str]] = None,
9999
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
100100
quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
101-
vae_decoder_ov_config: Optional[Dict[str, str]] = None,
102-
vae_encoder_ov_config: Optional[Dict[str, str]] = None,
103101
**kwargs,
104102
):
105103
self._internal_dict = config
@@ -118,23 +116,15 @@ def __init__(
118116
else:
119117
self._model_save_dir = model_save_dir
120118

121-
default_vae_ov_config = deepcopy(self.ov_config)
122-
if "GPU" in self._device:
123-
default_vae_ov_config.update({"INFERENCE_PRECISION_HINT": "f32"})
124-
125-
self.vae_decoder = OVModelVaeDecoder(vae_decoder, self, vae_decoder_ov_config or default_vae_ov_config)
119+
self.vae_decoder = OVModelVaeDecoder(vae_decoder, self)
126120
self.unet = OVModelUnet(unet, self)
127121
self.text_encoder = OVModelTextEncoder(text_encoder, self) if text_encoder is not None else None
128122
self.text_encoder_2 = (
129123
OVModelTextEncoder(text_encoder_2, self, model_name=DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER)
130124
if text_encoder_2 is not None
131125
else None
132126
)
133-
self.vae_encoder = (
134-
OVModelVaeEncoder(vae_encoder, self, vae_encoder_ov_config or default_vae_ov_config)
135-
if vae_encoder is not None
136-
else None
137-
)
127+
self.vae_encoder = OVModelVaeEncoder(vae_encoder, self) if vae_encoder is not None else None
138128

139129
if "block_out_channels" in self.vae_decoder.config:
140130
self.vae_scale_factor = 2 ** (len(self.vae_decoder.config["block_out_channels"]) - 1)
@@ -726,6 +716,11 @@ def __call__(self, latent_sample: np.ndarray):
726716
outputs = self.request(inputs, share_inputs=True)
727717
return list(outputs.values())
728718

719+
def _compile(self):
720+
if "GPU" in self.device and "INFERENCE_PRECISION_HINT" not in self.ov_config:
721+
self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"})
722+
super()._compile()
723+
729724

730725
class OVModelVaeEncoder(OVModelPart):
731726
def __init__(
@@ -742,6 +737,11 @@ def __call__(self, sample: np.ndarray):
742737
outputs = self.request(inputs, share_inputs=True)
743738
return list(outputs.values())
744739

740+
def _compile(self):
741+
if "GPU" in self.device and "INFERENCE_PRECISION_HINT" not in self.ov_config:
742+
self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"})
743+
super()._compile()
744+
745745

746746
class OVStableDiffusionPipeline(OVStableDiffusionPipelineBase, StableDiffusionPipelineMixin):
747747
def __call__(

0 commit comments

Comments
 (0)