Skip to content

Commit fbd04f3

Browse files
committed
finalize janus support
1 parent 7dc3257 commit fbd04f3

File tree

7 files changed

+288
-70
lines changed

7 files changed

+288
-70
lines changed

optimum/exporters/openvino/__main__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import gc
16-
import importlib
1716
import logging
1817
import operator
1918
import warnings
@@ -40,13 +39,13 @@
4039
_infer_library_from_model_name_or_path,
4140
_OpenClipForZeroShotImageClassification,
4241
)
43-
from optimum.utils.save_utils import maybe_load_preprocessors
4442

4543
from .utils import (
4644
_MAX_UNCOMPRESSED_SIZE,
4745
MULTI_MODAL_TEXT_GENERATION_MODELS,
4846
clear_class_registry,
4947
deduce_diffusers_dtype,
48+
load_preprocessors,
5049
)
5150

5251

@@ -193,6 +192,7 @@ def main_export(
193192
```
194193
"""
195194
from optimum.exporters.openvino.convert import export_from_model
195+
196196
if use_auth_token is not None:
197197
warnings.warn(
198198
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
@@ -214,7 +214,7 @@ def main_export(
214214
revision=revision,
215215
cache_dir=cache_dir,
216216
token=token,
217-
library_name=library_name
217+
library_name=library_name,
218218
)
219219
if library_name == "sentence_transformers":
220220
logger.warning(
@@ -434,7 +434,7 @@ class StoreAttr(object):
434434
possible_synonyms = ""
435435
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
436436

437-
preprocessors = maybe_load_preprocessors(
437+
preprocessors = load_preprocessors(
438438
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
439439
)
440440

optimum/exporters/openvino/convert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ def export_from_model(
757757
logger.warning(
758758
f"The generation config will not be saved, saving failed with following error:\n{exception}"
759759
)
760-
760+
logger.warn(preprocessors)
761761
save_preprocessors(preprocessors, model.config, output, trust_remote_code)
762762

763763
files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]

optimum/exporters/openvino/model_configs.py

+60-47
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,10 @@ def init_model_configs():
145145
# for model registration in auto transformers classes
146146
if importlib.util.find_spec("janus") is not None:
147147
try:
148-
from janus.models import MultiModalityCausalLM
148+
from janus.models import MultiModalityCausalLM, VLChatProcessor
149149
except ImportError:
150150
pass
151151

152-
153152
if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
154153
TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
155154
TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["fill"] = {"flux": "FluxFillPipeline"}
@@ -1353,7 +1352,9 @@ def patch_model_for_export(
13531352

13541353

13551354
class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
1356-
def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False):
1355+
def __init__(
1356+
self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False
1357+
):
13571358
self.orig_export_config = export_config
13581359
if dummy_input_generator is not None:
13591360
export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -1372,16 +1373,15 @@ def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None,
13721373
def patch_model_for_export(
13731374
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
13741375
) -> "ModelPatcher":
1375-
13761376
if self.patcher_cls is not None:
13771377
patcher = self.patcher_cls(self, model, model_kwargs=model_kwargs)
13781378
# Refer to DecoderModelPatcher.
1379-
else:
1379+
else:
13801380
patcher = self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)
1381-
1381+
13821382
if self.remove_lm_head:
13831383
patcher = RemoveLMHeadPatcherHelper(self, model, model_kwargs, patcher)
1384-
1384+
13851385
return patcher
13861386

13871387
@property
@@ -1390,7 +1390,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
13901390
if self.remove_lm_head:
13911391
logits_info = outputs.pop("logits")
13921392
updated_outputs = {"last_hidden_state": logits_info}
1393-
return {**updated_outputs, **outputs}
1393+
return {**updated_outputs, **outputs}
13941394
return outputs
13951395

13961396
@property
@@ -1479,15 +1479,15 @@ def get_vlm_text_generation_config(
14791479
model_patcher=None,
14801480
dummy_input_generator=None,
14811481
inputs_update=None,
1482-
remove_lm_head=False
1482+
remove_lm_head=False,
14831483
):
14841484
internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
14851485
export_config = LMInputEmbedsConfigHelper(
14861486
internal_export_config,
14871487
patcher_cls=model_patcher,
14881488
dummy_input_generator=dummy_input_generator,
14891489
inputs_update=inputs_update,
1490-
remove_lm_head=remove_lm_head
1490+
remove_lm_head=remove_lm_head,
14911491
)
14921492
export_config._normalized_config = internal_export_config._normalized_config
14931493
return export_config
@@ -2812,45 +2812,60 @@ class JanusConfigBehavior(str, enum.Enum):
28122812

28132813

28142814
class JanusDummyVisionGenInputGenerator(DummyInputGenerator):
2815-
SUPPORTED_INPUT_NAMES = (
2816-
"pixel_values",
2817-
"image_ids",
2818-
"code_b",
2819-
"image_shape",
2820-
"lm_hidden_state",
2821-
"hidden_state"
2822-
)
2815+
SUPPORTED_INPUT_NAMES = ("pixel_values", "image_ids", "code_b", "image_shape", "lm_hidden_state", "hidden_state")
28232816

28242817
def __init__(
2825-
self,
2826-
task: str,
2827-
normalized_config: NormalizedConfig,
2828-
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
2829-
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
2830-
**kwargs,
2831-
):
2832-
self.task = task
2833-
self.batch_size = batch_size
2834-
self.sequence_length = sequence_length
2835-
self.normalized_config = normalized_config
2836-
2818+
self,
2819+
task: str,
2820+
normalized_config: NormalizedConfig,
2821+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
2822+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
2823+
**kwargs,
2824+
):
2825+
self.task = task
2826+
self.batch_size = batch_size
2827+
self.sequence_length = sequence_length
2828+
self.normalized_config = normalized_config
2829+
28372830
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
28382831
if input_name == "pixel_values":
2839-
return self.random_float_tensor([self.batch_size, 1, 3, self.normalized_config.config.params.image_size, self.normalized_config.config.params.image_size])
2840-
2832+
return self.random_float_tensor(
2833+
[
2834+
self.batch_size,
2835+
1,
2836+
3,
2837+
self.normalized_config.config.params.image_size,
2838+
self.normalized_config.config.params.image_size,
2839+
]
2840+
)
2841+
28412842
if input_name == "image_ids":
2842-
return self.random_int_tensor([self.sequence_length], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
2843+
return self.random_int_tensor(
2844+
[self.sequence_length],
2845+
max_value=self.normalized_config.config.params.image_token_size,
2846+
framework=framework,
2847+
dtype=int_dtype,
2848+
)
28432849
if input_name == "code_b":
2844-
return self.random_int_tensor([self.batch_size, 576], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
2850+
return self.random_int_tensor(
2851+
[self.batch_size, 576],
2852+
max_value=self.normalized_config.config.params.image_token_size,
2853+
framework=framework,
2854+
dtype=int_dtype,
2855+
)
28452856
if input_name == "image_shape":
28462857
import torch
2847-
return torch.tensor([self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64)
2858+
2859+
return torch.tensor(
2860+
[self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64
2861+
)
28482862
if input_name == "hidden_state":
2849-
return self.random_float_tensor([self.batch_size, self.sequence_length, self.normalized_config.hidden_size])
2863+
return self.random_float_tensor(
2864+
[self.batch_size, self.sequence_length, self.normalized_config.hidden_size]
2865+
)
28502866
if input_name == "lm_hidden_state":
28512867
return self.random_float_tensor([self.sequence_length, self.normalized_config.hidden_size])
28522868
return super().generate(input_name, framework, int_dtype, float_dtype)
2853-
28542869

28552870

28562871
@register_in_tasks_manager("multi-modality", *["image-text-to-text", "any-to-any"], library_name="transformers")
@@ -2868,7 +2883,7 @@ def __init__(
28682883
float_dtype: str = "fp32",
28692884
behavior: JanusConfigBehavior = JanusConfigBehavior.VISION_EMBEDDINGS,
28702885
preprocessors: Optional[List[Any]] = None,
2871-
**kwargs
2886+
**kwargs,
28722887
):
28732888
super().__init__(
28742889
config=config,
@@ -2882,7 +2897,9 @@ def __init__(
28822897
if self._behavior == JanusConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
28832898
self._config = config.vision_config
28842899
self._normalized_config = NormalizedVisionConfig(self._config)
2885-
if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(config, "language_config"):
2900+
if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(
2901+
config, "language_config"
2902+
):
28862903
self._config = config.language_config
28872904
self._normalized_config = NormalizedTextConfig(self._config)
28882905
if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS and hasattr(config, "gen_head_config"):
@@ -2912,7 +2929,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
29122929
return {"last_hidden_state": {0: "batch_size"}}
29132930
if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
29142931
return {"last_hidden_state": {0: "num_tokens"}}
2915-
2932+
29162933
if self._behavior == JanusConfigBehavior.LM_HEAD:
29172934
return {"logits": {0: "batch_size", 1: "sequence_length"}}
29182935

@@ -2979,7 +2996,6 @@ def with_behavior(
29792996
preprocessors=self._preprocessors,
29802997
)
29812998

2982-
29832999
if behavior == JanusConfigBehavior.VISION_EMBEDDINGS:
29843000
return self.__class__(
29853001
self._orig_config,
@@ -2989,7 +3005,7 @@ def with_behavior(
29893005
behavior=behavior,
29903006
preprocessors=self._preprocessors,
29913007
)
2992-
3008+
29933009
if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
29943010
return self.__class__(
29953011
self._orig_config,
@@ -3000,7 +3016,6 @@ def with_behavior(
30003016
preprocessors=self._preprocessors,
30013017
)
30023018

3003-
30043019
def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior]):
30053020
if isinstance(behavior, str) and not isinstance(behavior, JanusConfigBehavior):
30063021
behavior = JanusConfigBehavior(behavior)
@@ -3023,7 +3038,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
30233038

30243039
if behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
30253040
return model
3026-
3041+
30273042
if behavior == JanusConfigBehavior.VISION_GEN_HEAD:
30283043
gen_head = model.gen_head
30293044
gen_head.config = model.language_model.config
@@ -3032,7 +3047,6 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
30323047
if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
30333048
return model.gen_vision_model
30343049

3035-
30363050
def patch_model_for_export(
30373051
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
30383052
):
@@ -3045,7 +3059,6 @@ def patch_model_for_export(
30453059
return JanusVisionGenDecoderModelPatcher(self, model, model_kwargs)
30463060
return super().patch_model_for_export(model, model_kwargs)
30473061

3048-
30493062
def rename_ambiguous_inputs(self, inputs):
30503063
if self._behavior == JanusConfigBehavior.VISION_GEN_HEAD:
30513064
data = inputs.pop("lm_hidden_state")
@@ -3056,4 +3069,4 @@ def rename_ambiguous_inputs(self, inputs):
30563069
if self._behavior == JanusConfigBehavior.VISION_GEN_DECODER:
30573070
data = inputs.pop("image_shape")
30583071
inputs["shape"] = data
3059-
return inputs
3072+
return inputs

optimum/exporters/openvino/model_patcher.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -3905,7 +3905,7 @@ def __exit__(self, exc_type, exc_value, traceback):
39053905

39063906
def janus_vision_embed_forward(self, pixel_values):
39073907
from einops import rearrange
3908-
3908+
39093909
bs, n = pixel_values.shape[0:2]
39103910
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
39113911
# [b x n, T2, D]
@@ -3968,23 +3968,25 @@ def __exit__(self, exc_type, exc_value, traceback):
39683968

39693969

39703970
class RemoveLMHeadPatcherHelper(DecoderModelPatcher):
3971-
def __init__(self,
3971+
def __init__(
3972+
self,
39723973
config: "OnnxConfig",
39733974
model: Union["PreTrainedModel", "TFPreTrainedModel"],
39743975
model_kwargs: Dict[str, Any],
3975-
internal_patcher = None
3976+
internal_patcher=None,
39763977
):
39773978
model.__orig_forward = model.forward
3979+
39783980
@functools.wraps(model.__orig_forward)
39793981
def patched_forward(*args, **kwargs):
39803982
return model.model.forward(*args, **kwargs)
3983+
39813984
model.forward = patched_forward
39823985
self._internal_patcher = internal_patcher
39833986
if self._internal_patcher is not None:
39843987
self._patched_forward = self._internal_patcher.patched_forward
39853988
super().__init__(config, model, model_kwargs)
39863989

3987-
39883990
def __enter__(self):
39893991
if self._internal_patcher is not None:
39903992
return self._internal_patcher.__enter__()
@@ -4007,4 +4009,4 @@ def patched_forward(self):
40074009
def patched_forward(self, fn):
40084010
self._patched_forward = fn
40094011
if self._internal_patcher is not None:
4010-
self._internal_patcher.patched_forward = fn
4012+
self._internal_patcher.patched_forward = fn

optimum/exporters/openvino/utils.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import importlib
1516
import inspect
1617
import logging
1718
from collections import namedtuple
@@ -28,7 +29,7 @@
2829
from optimum.intel.utils import is_transformers_version
2930
from optimum.intel.utils.import_utils import is_safetensors_available
3031
from optimum.utils import is_diffusers_available
31-
from optimum.utils.save_utils import maybe_save_preprocessors
32+
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
3233

3334

3435
logger = logging.getLogger(__name__)
@@ -225,7 +226,7 @@ def get_submodels(model):
225226
"minicpmv",
226227
"phi3-v",
227228
"qwen2-vl",
228-
"multi-modality"
229+
"multi-modality",
229230
]
230231

231232

@@ -303,3 +304,20 @@ def save_preprocessors(
303304
logger.error(f"Saving {type(processor)} failed with {ex}")
304305
else:
305306
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
307+
308+
309+
def load_preprocessors(src_name_or_path: Union[str, Path], subfolder: str = "", trust_remote_code: bool = False):
310+
preprocessors = maybe_load_preprocessors(
311+
src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
312+
)
313+
if importlib.util.find_spec("janus") is not None:
314+
from janus.models import VLChatProcessor
315+
316+
try:
317+
processor = VLChatProcessor.from_pretrained(
318+
src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
319+
)
320+
preprocessors.append(processor)
321+
except Exception:
322+
pass
323+
return preprocessors

0 commit comments

Comments (0)