Skip to content

Commit 40ef93f

Browse files
committed
use optimum from branch
1 parent 31507cc commit 40ef93f

File tree

2 files changed

+47
-61
lines changed

2 files changed

+47
-61
lines changed

optimum/exporters/openvino/model_configs.py

+46-60
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def init_model_configs():
145145
# for model registration in auto transformers classes
146146
if importlib.util.find_spec("janus") is not None:
147147
try:
148-
from janus.models import MultiModalityCausalLM, VLChatProcessor
148+
from janus.models import MultiModalityCausalLM # noqa: F401
149149
except ImportError:
150150
pass
151151

@@ -1352,9 +1352,7 @@ def patch_model_for_export(
13521352

13531353

13541354
class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
1355-
def __init__(
1356-
self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False
1357-
):
1355+
def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False):
13581356
self.orig_export_config = export_config
13591357
if dummy_input_generator is not None:
13601358
export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -1373,15 +1371,16 @@ def __init__(
13731371
def patch_model_for_export(
13741372
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
13751373
) -> "ModelPatcher":
1374+
13761375
if self.patcher_cls is not None:
13771376
patcher = self.patcher_cls(self, model, model_kwargs=model_kwargs)
13781377
# Refer to DecoderModelPatcher.
1379-
else:
1378+
else:
13801379
patcher = self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)
1381-
1380+
13821381
if self.remove_lm_head:
13831382
patcher = RemoveLMHeadPatcherHelper(self, model, model_kwargs, patcher)
1384-
1383+
13851384
return patcher
13861385

13871386
@property
@@ -1390,7 +1389,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
13901389
if self.remove_lm_head:
13911390
logits_info = outputs.pop("logits")
13921391
updated_outputs = {"last_hidden_state": logits_info}
1393-
return {**updated_outputs, **outputs}
1392+
return {**updated_outputs, **outputs}
13941393
return outputs
13951394

13961395
@property
@@ -1479,15 +1478,15 @@ def get_vlm_text_generation_config(
14791478
model_patcher=None,
14801479
dummy_input_generator=None,
14811480
inputs_update=None,
1482-
remove_lm_head=False,
1481+
remove_lm_head=False
14831482
):
14841483
internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
14851484
export_config = LMInputEmbedsConfigHelper(
14861485
internal_export_config,
14871486
patcher_cls=model_patcher,
14881487
dummy_input_generator=dummy_input_generator,
14891488
inputs_update=inputs_update,
1490-
remove_lm_head=remove_lm_head,
1489+
remove_lm_head=remove_lm_head
14911490
)
14921491
export_config._normalized_config = internal_export_config._normalized_config
14931492
return export_config
@@ -2812,60 +2811,45 @@ class JanusConfigBehavior(str, enum.Enum):
28122811

28132812

28142813
class JanusDummyVisionGenInputGenerator(DummyInputGenerator):
2815-
SUPPORTED_INPUT_NAMES = ("pixel_values", "image_ids", "code_b", "image_shape", "lm_hidden_state", "hidden_state")
2814+
SUPPORTED_INPUT_NAMES = (
2815+
"pixel_values",
2816+
"image_ids",
2817+
"code_b",
2818+
"image_shape",
2819+
"lm_hidden_state",
2820+
"hidden_state"
2821+
)
28162822

28172823
def __init__(
2818-
self,
2819-
task: str,
2820-
normalized_config: NormalizedConfig,
2821-
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
2822-
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
2823-
**kwargs,
2824-
):
2825-
self.task = task
2826-
self.batch_size = batch_size
2827-
self.sequence_length = sequence_length
2828-
self.normalized_config = normalized_config
2829-
2824+
self,
2825+
task: str,
2826+
normalized_config: NormalizedConfig,
2827+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
2828+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
2829+
**kwargs,
2830+
):
2831+
self.task = task
2832+
self.batch_size = batch_size
2833+
self.sequence_length = sequence_length
2834+
self.normalized_config = normalized_config
2835+
28302836
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
28312837
if input_name == "pixel_values":
2832-
return self.random_float_tensor(
2833-
[
2834-
self.batch_size,
2835-
1,
2836-
3,
2837-
self.normalized_config.config.params.image_size,
2838-
self.normalized_config.config.params.image_size,
2839-
]
2840-
)
2841-
2838+
return self.random_float_tensor([self.batch_size, 1, 3, self.normalized_config.config.params.image_size, self.normalized_config.config.params.image_size])
2839+
28422840
if input_name == "image_ids":
2843-
return self.random_int_tensor(
2844-
[self.sequence_length],
2845-
max_value=self.normalized_config.config.params.image_token_size,
2846-
framework=framework,
2847-
dtype=int_dtype,
2848-
)
2841+
return self.random_int_tensor([self.sequence_length], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
28492842
if input_name == "code_b":
2850-
return self.random_int_tensor(
2851-
[self.batch_size, 576],
2852-
max_value=self.normalized_config.config.params.image_token_size,
2853-
framework=framework,
2854-
dtype=int_dtype,
2855-
)
2843+
return self.random_int_tensor([self.batch_size, 576], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
28562844
if input_name == "image_shape":
28572845
import torch
2858-
2859-
return torch.tensor(
2860-
[self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64
2861-
)
2846+
return torch.tensor([self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64)
28622847
if input_name == "hidden_state":
2863-
return self.random_float_tensor(
2864-
[self.batch_size, self.sequence_length, self.normalized_config.hidden_size]
2865-
)
2848+
return self.random_float_tensor([self.batch_size, self.sequence_length, self.normalized_config.hidden_size])
28662849
if input_name == "lm_hidden_state":
28672850
return self.random_float_tensor([self.sequence_length, self.normalized_config.hidden_size])
28682851
return super().generate(input_name, framework, int_dtype, float_dtype)
2852+
28692853

28702854

28712855
@register_in_tasks_manager("multi-modality", *["image-text-to-text", "any-to-any"], library_name="transformers")
@@ -2883,7 +2867,7 @@ def __init__(
28832867
float_dtype: str = "fp32",
28842868
behavior: JanusConfigBehavior = JanusConfigBehavior.VISION_EMBEDDINGS,
28852869
preprocessors: Optional[List[Any]] = None,
2886-
**kwargs,
2870+
**kwargs
28872871
):
28882872
super().__init__(
28892873
config=config,
@@ -2897,9 +2881,7 @@ def __init__(
28972881
if self._behavior == JanusConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
28982882
self._config = config.vision_config
28992883
self._normalized_config = NormalizedVisionConfig(self._config)
2900-
if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(
2901-
config, "language_config"
2902-
):
2884+
if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(config, "language_config"):
29032885
self._config = config.language_config
29042886
self._normalized_config = NormalizedTextConfig(self._config)
29052887
if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS and hasattr(config, "gen_head_config"):
@@ -2929,7 +2911,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
29292911
return {"last_hidden_state": {0: "batch_size"}}
29302912
if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
29312913
return {"last_hidden_state": {0: "num_tokens"}}
2932-
2914+
29332915
if self._behavior == JanusConfigBehavior.LM_HEAD:
29342916
return {"logits": {0: "batch_size", 1: "sequence_length"}}
29352917

@@ -2996,6 +2978,7 @@ def with_behavior(
29962978
preprocessors=self._preprocessors,
29972979
)
29982980

2981+
29992982
if behavior == JanusConfigBehavior.VISION_EMBEDDINGS:
30002983
return self.__class__(
30012984
self._orig_config,
@@ -3005,7 +2988,7 @@ def with_behavior(
30052988
behavior=behavior,
30062989
preprocessors=self._preprocessors,
30072990
)
3008-
2991+
30092992
if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
30102993
return self.__class__(
30112994
self._orig_config,
@@ -3016,6 +2999,7 @@ def with_behavior(
30162999
preprocessors=self._preprocessors,
30173000
)
30183001

3002+
30193003
def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior]):
30203004
if isinstance(behavior, str) and not isinstance(behavior, JanusConfigBehavior):
30213005
behavior = JanusConfigBehavior(behavior)
@@ -3038,7 +3022,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
30383022

30393023
if behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
30403024
return model
3041-
3025+
30423026
if behavior == JanusConfigBehavior.VISION_GEN_HEAD:
30433027
gen_head = model.gen_head
30443028
gen_head.config = model.language_model.config
@@ -3047,6 +3031,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
30473031
if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
30483032
return model.gen_vision_model
30493033

3034+
30503035
def patch_model_for_export(
30513036
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
30523037
):
@@ -3059,6 +3044,7 @@ def patch_model_for_export(
30593044
return JanusVisionGenDecoderModelPatcher(self, model, model_kwargs)
30603045
return super().patch_model_for_export(model, model_kwargs)
30613046

3047+
30623048
def rename_ambiguous_inputs(self, inputs):
30633049
if self._behavior == JanusConfigBehavior.VISION_GEN_HEAD:
30643050
data = inputs.pop("lm_hidden_state")
@@ -3069,4 +3055,4 @@ def rename_ambiguous_inputs(self, inputs):
30693055
if self._behavior == JanusConfigBehavior.VISION_GEN_DECODER:
30703056
data = inputs.pop("image_shape")
30713057
inputs["shape"] = data
3072-
return inputs
3058+
return inputs

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
INSTALL_REQUIRE = [
3030
"torch>=1.11",
31-
"optimum~=1.24",
31+
"optimum @ git+https://github.com/eaidova/optimum.git@ea/avoid_lib_guessing_in_standartize_args",
3232
"transformers>=4.36,<4.48",
3333
"datasets>=1.4.0",
3434
"sentencepiece",

0 commit comments

Comments
 (0)