
Commit f35aa15
fix style

1 parent: db22a52
6 files changed: +32 -14 lines

optimum/intel/openvino/modeling_base.py

+12 -5

@@ -150,12 +150,10 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
 
         self._save_openvino_config(save_directory)
 
-
     def _save_openvino_config(self, save_directory: Union[str, Path]):
         if self._openvino_config is not None:
             self._openvino_config.save_pretrained(save_directory)
 
-
     @classmethod
     def _from_pretrained(
         cls,
@@ -216,12 +214,21 @@ def _from_pretrained(
             local_files_only=local_files_only,
         )
 
-        quantization_config = self._prepare_quantization_config(quantization_config, load_in_8bit)
+        quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
 
         model = cls.load_model(model_cache_path, quantization_config=quantization_config)
-        return cls(model, config=config, model_save_dir=model_cache_path.parent, quantization_config=quantization_config, **kwargs)
+        return cls(
+            model,
+            config=config,
+            model_save_dir=model_cache_path.parent,
+            quantization_config=quantization_config,
+            **kwargs,
+        )
 
-    def _prepare_quantization_config(quantization_config : Optional[Union[OVWeightQuantizationConfig, Dict]] = None, load_in_8bit:bool= False):
+    @staticmethod
+    def _prepare_quantization_config(
+        quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, load_in_8bit: bool = False
+    ):
         # Give default quantization config if not provided and load_in_8bit=True
         if not quantization_config and load_in_8bit:
             quantization_config = OVWeightQuantizationConfig(bits=8)
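
A note on this hunk: `_from_pretrained` is a `@classmethod`, so `self` is not defined in its scope and the previous `self._prepare_quantization_config(...)` call referenced an undefined name; routing the call through `cls` and marking `_prepare_quantization_config` as a `@staticmethod` makes the helper callable from the class and its subclasses. Below is a minimal, self-contained sketch of that pattern, not the library code: `ExampleModel`, `prepare_config`, and `from_pretrained_like` are hypothetical stand-ins, and plain dicts stand in for `OVWeightQuantizationConfig`.

# Minimal sketch (hypothetical names; plain dicts instead of OVWeightQuantizationConfig).
from typing import Optional


class ExampleModel:
    @staticmethod
    def prepare_config(quantization_config: Optional[dict] = None, load_in_8bit: bool = False) -> Optional[dict]:
        # Same fallback as the diff: default to an 8-bit config only when none was
        # provided and load_in_8bit=True.
        if not quantization_config and load_in_8bit:
            quantization_config = {"bits": 8}
        return quantization_config

    @classmethod
    def from_pretrained_like(cls, quantization_config: Optional[dict] = None, load_in_8bit: bool = False):
        # Inside a classmethod there is no bound instance, so `self.prepare_config(...)`
        # would raise NameError; `cls.prepare_config(...)` resolves the staticmethod.
        return cls.prepare_config(quantization_config, load_in_8bit)


assert ExampleModel.from_pretrained_like(load_in_8bit=True) == {"bits": 8}
assert ExampleModel.from_pretrained_like() is None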

optimum/intel/openvino/modeling_base_seq2seq.py

+1 -1

@@ -164,7 +164,7 @@ def _from_pretrained(
         decoder_with_past_file_name = decoder_with_past_file_name or default_decoder_with_past_file_name
         decoder_with_past = None
 
-        quantization_config = self._prepare_quantization_config(quantization_config, load_in_8bit)
+        quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
 
         # Load model from a local directory
         if os.path.isdir(model_id):

optimum/intel/openvino/modeling_decoder.py

+8 -3

@@ -581,9 +581,9 @@ def _from_pretrained(
         )
 
         if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
-            quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
+            quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
 
-        quantization_config = self._prepare_quantization_config(quantization_config, load_in_8bit)
+        quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
 
         load_in_4bit = quantization_config.bits == 4 if quantization_config else False
         model = cls.load_model(model_cache_path, quantization_config=None if load_in_4bit else quantization_config)
@@ -602,7 +602,12 @@ def _from_pretrained(
 
         enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
         causal_model = init_cls(
-            model=model, config=config, model_save_dir=model_cache_path.parent, compile=enable_compilation, quantization_config=quantization_config, **kwargs
+            model=model,
+            config=config,
+            model_save_dir=model_cache_path.parent,
+            compile=enable_compilation,
+            quantization_config=quantization_config,
+            **kwargs,
         )
 
         if load_in_4bit:
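
Beyond the `self` to `cls` fix, this file keeps an extra selection step: a bare `{"bits": 4}` dict is swapped for a model-specific default from `_DEFAULT_4BIT_CONFIGS` (keyed by `config.name_or_path`) before the common fallback runs, and `load_in_4bit` then decides whether `load_model` should receive the config at all. A simplified sketch of that selection logic follows; the defaults table and its entry are made up for illustration and do not reproduce the real `_DEFAULT_4BIT_CONFIGS` contents.

# Simplified sketch of the 4-bit config selection; placeholder defaults table.
from typing import Optional

_EXAMPLE_4BIT_DEFAULTS = {
    # Hypothetical entry; the real table maps known model ids to tuned 4-bit settings.
    "some-org/some-model": {"bits": 4, "sym": False, "group_size": 128},
}


def resolve_quantization_config(
    quantization_config: Optional[dict] = None,
    name_or_path: str = "",
    load_in_8bit: bool = False,
) -> Optional[dict]:
    # A bare {"bits": 4} request is replaced by a per-model default when one exists.
    if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
        quantization_config = _EXAMPLE_4BIT_DEFAULTS.get(name_or_path, quantization_config)
    # Common fallback, mirroring _prepare_quantization_config.
    if not quantization_config and load_in_8bit:
        quantization_config = {"bits": 8}
    return quantization_config


cfg = resolve_quantization_config({"bits": 4}, name_or_path="some-org/some-model")
load_in_4bit = cfg["bits"] == 4 if cfg else False
print(cfg, load_in_4bit)  # tuned defaults for the known id, and the 4-bit branch is taken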

optimum/intel/openvino/modeling_diffusion.py

+9 -3

@@ -145,7 +145,6 @@ def __init__(
         if quantization_config:
             self._openvino_config = OVConfig(quantization_config=quantization_config)
 
-
     def _save_pretrained(self, save_directory: Union[str, Path]):
         """
         Saves the model to the OpenVINO IR format so that it can be re-loaded using the
@@ -265,7 +264,7 @@ def _from_pretrained(
             else:
                 kwargs[name] = load_method(new_model_save_dir)
 
-        quantization_config = self._prepare_quantization_config(quantization_config, load_in_8bit)
+        quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
         unet = cls.load_model(
             new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, quantization_config
         )
@@ -283,7 +282,14 @@ def _from_pretrained(
         if model_save_dir is None:
             model_save_dir = new_model_save_dir
 
-        return cls(unet=unet, config=config, model_save_dir=model_save_dir, quantization_config=quantization_config, **components, **kwargs)
+        return cls(
+            unet=unet,
+            config=config,
+            model_save_dir=model_save_dir,
+            quantization_config=quantization_config,
+            **components,
+            **kwargs,
+        )
 
     @classmethod
     def _from_transformers(
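
For context, these `_from_pretrained` paths are reached through the public `from_pretrained` API. The usage sketch below assumes the `OVModelForCausalLM` and `OVWeightQuantizationConfig` entry points exported by `optimum.intel`; the model id is a placeholder, and depending on the checkpoint format additional arguments (for example `export=True`) may be required.

# Usage sketch (placeholder model id; exact keyword support depends on the installed
# optimum-intel version).
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# Shortcut flag: no explicit config, so the default OVWeightQuantizationConfig(bits=8) applies.
model_8bit = OVModelForCausalLM.from_pretrained("some-org/some-model", load_in_8bit=True)

# Explicit config object, e.g. 4-bit weights; this is the quantization_config that
# _from_pretrained forwards to load_model and to the returned model instance.
model_4bit = OVModelForCausalLM.from_pretrained(
    "some-org/some-model",
    quantization_config=OVWeightQuantizationConfig(bits=4),
)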

optimum/intel/openvino/quantization.py

+1 -1

@@ -44,7 +44,7 @@
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
 from ..utils.modeling_utils import get_model_device
-from .configuration import OVConfig, OVWeightQuantizationConfig, DEFAULT_QUANTIZATION_CONFIG
+from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
 from .utils import (
     MAX_ONNX_OPSET,

optimum/intel/openvino/trainer.py

+1 -1

@@ -89,7 +89,7 @@
 
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_transformers_version
-from .configuration import OVConfig, DEFAULT_QUANTIZATION_CONFIG
+from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig
 from .quantization import OVDataLoader
 from .training_args import OVTrainingArguments
 from .utils import (
