Skip to content

Commit ce7789f

Browse files
committed
fix code style
1 parent c26a450 commit ce7789f

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

optimum/exporters/openvino/model_patcher.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3386,8 +3386,8 @@ class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
33863386
def __init__(
33873387
self,
33883388
config: OnnxConfig,
3389-
model: PreTrainedModel | TFPreTrainedModel,
3390-
model_kwargs: Dict[str, Any] | None = None,
3389+
model: Union[PreTrainedModel, TFPreTrainedModel],
3390+
model_kwargs: Dict[str, Any] = None,
33913391
):
33923392

33933393
model.__orig_forward = model.forward
@@ -3426,8 +3426,8 @@ class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
34263426
def __init__(
34273427
self,
34283428
config: OnnxConfig,
3429-
model: PreTrainedModel | TFPreTrainedModel,
3430-
model_kwargs: Dict[str, Any] | None = None,
3429+
model: Union[PreTrainedModel, TFPreTrainedModel],
3430+
model_kwargs: Dict[str, Any] = None,
34313431
):
34323432
model.__orig_forward = model.forward
34333433

optimum/exporters/openvino/utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,15 @@ def get_submodels(model):
216216
return custom_export, fn_get_submodels
217217

218218

219-
MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v", "qwen2-vl"]
219+
MULTI_MODAL_TEXT_GENERATION_MODELS = [
220+
"llava",
221+
"llava-next",
222+
"llava-qwen2",
223+
"internvl-chat",
224+
"minicpmv",
225+
"phi3-v",
226+
"qwen2-vl",
227+
]
220228

221229

222230
def save_config(config, save_dir):

optimum/intel/openvino/modeling_base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def __init__(
111111
for idx, key in enumerate(model.inputs):
112112
names = tuple(key.get_names())
113113
input_names[next((name for name in names if "/" not in name), names[0])] = idx
114-
input_dtypes[
115-
next((name for name in names if "/" not in name), names[0])
116-
] = key.get_element_type().get_type_name()
114+
input_dtypes[next((name for name in names if "/" not in name), names[0])] = (
115+
key.get_element_type().get_type_name()
116+
)
117117
self.input_names = input_names
118118
self.input_dtypes = input_dtypes
119119

@@ -122,9 +122,9 @@ def __init__(
122122
for idx, key in enumerate(model.outputs):
123123
names = tuple(key.get_names())
124124
output_names[next((name for name in names if "/" not in name), names[0])] = idx
125-
output_dtypes[
126-
next((name for name in names if "/" not in name), names[0])
127-
] = key.get_element_type().get_type_name()
125+
output_dtypes[next((name for name in names if "/" not in name), names[0])] = (
126+
key.get_element_type().get_type_name()
127+
)
128128

129129
self.output_names = output_names
130130
self.output_dtypes = output_dtypes

0 commit comments

Comments
 (0)