@@ -145,7 +145,7 @@ def init_model_configs():
     # for model registration in auto transformers classses
     if importlib.util.find_spec("janus") is not None:
         try:
-            from janus.models import MultiModalityCausalLM, VLChatProcessor
+            from janus.models import MultiModalityCausalLM  # noqa: F401
         except ImportError:
             pass

@@ -1352,9 +1352,7 @@ def patch_model_for_export(


 class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
-    def __init__(
-        self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False
-    ):
+    def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False):
         self.orig_export_config = export_config
         if dummy_input_generator is not None:
             export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -1373,15 +1371,16 @@ def __init__(
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
+
         if self.patcher_cls is not None:
             patcher = self.patcher_cls(self, model, model_kwargs=model_kwargs)
         # Refer to DecoderModelPatcher.
-        else:
+        else:
             patcher = self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)
-
+
         if self.remove_lm_head:
             patcher = RemoveLMHeadPatcherHelper(self, model, model_kwargs, patcher)
-
+
         return patcher

     @property
@@ -1390,7 +1389,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         if self.remove_lm_head:
             logits_info = outputs.pop("logits")
             updated_outputs = {"last_hidden_state": logits_info}
-            return {**updated_outputs, **outputs}
+            return {**updated_outputs, **outputs}
         return outputs

     @property
@@ -1479,15 +1478,15 @@ def get_vlm_text_generation_config(
     model_patcher=None,
     dummy_input_generator=None,
     inputs_update=None,
-    remove_lm_head=False,
+    remove_lm_head=False
 ):
     internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
     export_config = LMInputEmbedsConfigHelper(
         internal_export_config,
         patcher_cls=model_patcher,
         dummy_input_generator=dummy_input_generator,
         inputs_update=inputs_update,
-        remove_lm_head=remove_lm_head,
+        remove_lm_head=remove_lm_head
     )
     export_config._normalized_config = internal_export_config._normalized_config
     return export_config
@@ -2812,60 +2811,45 @@ class JanusConfigBehavior(str, enum.Enum):


 class JanusDummyVisionGenInputGenerator(DummyInputGenerator):
-    SUPPORTED_INPUT_NAMES = ("pixel_values", "image_ids", "code_b", "image_shape", "lm_hidden_state", "hidden_state")
+    SUPPORTED_INPUT_NAMES = (
+        "pixel_values",
+        "image_ids",
+        "code_b",
+        "image_shape",
+        "lm_hidden_state",
+        "hidden_state"
+    )

     def __init__(
-        self,
-        task: str,
-        normalized_config: NormalizedConfig,
-        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
-        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
-        **kwargs,
-    ):
-        self.task = task
-        self.batch_size = batch_size
-        self.sequence_length = sequence_length
-        self.normalized_config = normalized_config
-
+        self,
+        task: str,
+        normalized_config: NormalizedConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        self.task = task
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+        self.normalized_config = normalized_config
+
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         if input_name == "pixel_values":
-            return self.random_float_tensor(
-                [
-                    self.batch_size,
-                    1,
-                    3,
-                    self.normalized_config.config.params.image_size,
-                    self.normalized_config.config.params.image_size,
-                ]
-            )
-
+            return self.random_float_tensor([self.batch_size, 1, 3, self.normalized_config.config.params.image_size, self.normalized_config.config.params.image_size])
+
         if input_name == "image_ids":
-            return self.random_int_tensor(
-                [self.sequence_length],
-                max_value=self.normalized_config.config.params.image_token_size,
-                framework=framework,
-                dtype=int_dtype,
-            )
+            return self.random_int_tensor([self.sequence_length], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
         if input_name == "code_b":
-            return self.random_int_tensor(
-                [self.batch_size, 576],
-                max_value=self.normalized_config.config.params.image_token_size,
-                framework=framework,
-                dtype=int_dtype,
-            )
+            return self.random_int_tensor([self.batch_size, 576], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
         if input_name == "image_shape":
             import torch
-
-            return torch.tensor(
-                [self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64
-            )
+            return torch.tensor([self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64)
         if input_name == "hidden_state":
-            return self.random_float_tensor(
-                [self.batch_size, self.sequence_length, self.normalized_config.hidden_size]
-            )
+            return self.random_float_tensor([self.batch_size, self.sequence_length, self.normalized_config.hidden_size])
         if input_name == "lm_hidden_state":
             return self.random_float_tensor([self.sequence_length, self.normalized_config.hidden_size])
         return super().generate(input_name, framework, int_dtype, float_dtype)
+


 @register_in_tasks_manager("multi-modality", *["image-text-to-text", "any-to-any"], library_name="transformers")
@@ -2883,7 +2867,7 @@ def __init__(
         float_dtype: str = "fp32",
         behavior: JanusConfigBehavior = JanusConfigBehavior.VISION_EMBEDDINGS,
         preprocessors: Optional[List[Any]] = None,
-        **kwargs,
+        **kwargs
     ):
         super().__init__(
             config=config,
@@ -2897,9 +2881,7 @@ def __init__(
         if self._behavior == JanusConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
             self._config = config.vision_config
             self._normalized_config = NormalizedVisionConfig(self._config)
-        if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(
-            config, "language_config"
-        ):
+        if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(config, "language_config"):
             self._config = config.language_config
             self._normalized_config = NormalizedTextConfig(self._config)
         if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS and hasattr(config, "gen_head_config"):
@@ -2929,7 +2911,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             return {"last_hidden_state": {0: "batch_size"}}
         if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
             return {"last_hidden_state": {0: "num_tokens"}}
-
+
         if self._behavior == JanusConfigBehavior.LM_HEAD:
             return {"logits": {0: "batch_size", 1: "sequence_length"}}

@@ -2996,6 +2978,7 @@ def with_behavior(
                 preprocessors=self._preprocessors,
             )

+
         if behavior == JanusConfigBehavior.VISION_EMBEDDINGS:
             return self.__class__(
                 self._orig_config,
@@ -3005,7 +2988,7 @@ def with_behavior(
                 behavior=behavior,
                 preprocessors=self._preprocessors,
             )
-
+
         if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
             return self.__class__(
                 self._orig_config,
@@ -3016,6 +2999,7 @@ def with_behavior(
                 preprocessors=self._preprocessors,
             )

+
     def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior]):
         if isinstance(behavior, str) and not isinstance(behavior, JanusConfigBehavior):
             behavior = JanusConfigBehavior(behavior)
@@ -3038,7 +3022,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior

         if behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
             return model
-
+
         if behavior == JanusConfigBehavior.VISION_GEN_HEAD:
             gen_head = model.gen_head
             gen_head.config = model.language_model.config
@@ -3047,6 +3031,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
         if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
             return model.gen_vision_model

+
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ):
@@ -3059,6 +3044,7 @@ def patch_model_for_export(
             return JanusVisionGenDecoderModelPatcher(self, model, model_kwargs)
         return super().patch_model_for_export(model, model_kwargs)

+
     def rename_ambiguous_inputs(self, inputs):
         if self._behavior == JanusConfigBehavior.VISION_GEN_HEAD:
             data = inputs.pop("lm_hidden_state")
@@ -3069,4 +3055,4 @@ def rename_ambiguous_inputs(self, inputs):
         if self._behavior == JanusConfigBehavior.VISION_GEN_DECODER:
             data = inputs.pop("image_shape")
             inputs["shape"] = data
-        return inputs
+        return inputs