@@ -145,11 +145,10 @@ def init_model_configs():
145
145
# for model registration in auto transformers classses
146
146
if importlib .util .find_spec ("janus" ) is not None :
147
147
try :
148
- from janus .models import MultiModalityCausalLM
148
+ from janus .models import MultiModalityCausalLM , VLChatProcessor
149
149
except ImportError :
150
150
pass
151
151
152
-
153
152
if is_diffusers_available () and "fill" not in TasksManager ._DIFFUSERS_TASKS_TO_MODEL_LOADERS :
154
153
TasksManager ._DIFFUSERS_TASKS_TO_MODEL_LOADERS ["fill" ] = "FluxFillPipeline"
155
154
TasksManager ._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS ["fill" ] = {"flux" : "FluxFillPipeline" }
@@ -1353,7 +1352,9 @@ def patch_model_for_export(
1353
1352
1354
1353
1355
1354
class LMInputEmbedsConfigHelper (TextDecoderWithPositionIdsOnnxConfig ):
1356
- def __init__ (self , export_config , patcher_cls = None , dummy_input_generator = None , inputs_update = None , remove_lm_head = False ):
1355
+ def __init__ (
1356
+ self , export_config , patcher_cls = None , dummy_input_generator = None , inputs_update = None , remove_lm_head = False
1357
+ ):
1357
1358
self .orig_export_config = export_config
1358
1359
if dummy_input_generator is not None :
1359
1360
export_config .DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -1372,16 +1373,15 @@ def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None,
1372
1373
def patch_model_for_export (
1373
1374
self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1374
1375
) -> "ModelPatcher" :
1375
-
1376
1376
if self .patcher_cls is not None :
1377
1377
patcher = self .patcher_cls (self , model , model_kwargs = model_kwargs )
1378
1378
# Refer to DecoderModelPatcher.
1379
- else :
1379
+ else :
1380
1380
patcher = self .orig_export_config .patch_model_for_export (model , model_kwargs = model_kwargs )
1381
-
1381
+
1382
1382
if self .remove_lm_head :
1383
1383
patcher = RemoveLMHeadPatcherHelper (self , model , model_kwargs , patcher )
1384
-
1384
+
1385
1385
return patcher
1386
1386
1387
1387
@property
@@ -1390,7 +1390,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
1390
1390
if self .remove_lm_head :
1391
1391
logits_info = outputs .pop ("logits" )
1392
1392
updated_outputs = {"last_hidden_state" : logits_info }
1393
- return {** updated_outputs , ** outputs }
1393
+ return {** updated_outputs , ** outputs }
1394
1394
return outputs
1395
1395
1396
1396
@property
@@ -1479,15 +1479,15 @@ def get_vlm_text_generation_config(
1479
1479
model_patcher = None ,
1480
1480
dummy_input_generator = None ,
1481
1481
inputs_update = None ,
1482
- remove_lm_head = False
1482
+ remove_lm_head = False ,
1483
1483
):
1484
1484
internal_export_config = get_vlm_internal_text_generation_config (model_type , model_config , int_dtype , float_dtype )
1485
1485
export_config = LMInputEmbedsConfigHelper (
1486
1486
internal_export_config ,
1487
1487
patcher_cls = model_patcher ,
1488
1488
dummy_input_generator = dummy_input_generator ,
1489
1489
inputs_update = inputs_update ,
1490
- remove_lm_head = remove_lm_head
1490
+ remove_lm_head = remove_lm_head ,
1491
1491
)
1492
1492
export_config ._normalized_config = internal_export_config ._normalized_config
1493
1493
return export_config
@@ -2812,45 +2812,60 @@ class JanusConfigBehavior(str, enum.Enum):
2812
2812
2813
2813
2814
2814
class JanusDummyVisionGenInputGenerator (DummyInputGenerator ):
2815
- SUPPORTED_INPUT_NAMES = (
2816
- "pixel_values" ,
2817
- "image_ids" ,
2818
- "code_b" ,
2819
- "image_shape" ,
2820
- "lm_hidden_state" ,
2821
- "hidden_state"
2822
- )
2815
+ SUPPORTED_INPUT_NAMES = ("pixel_values" , "image_ids" , "code_b" , "image_shape" , "lm_hidden_state" , "hidden_state" )
2823
2816
2824
2817
def __init__ (
2825
- self ,
2826
- task : str ,
2827
- normalized_config : NormalizedConfig ,
2828
- batch_size : int = DEFAULT_DUMMY_SHAPES ["batch_size" ],
2829
- sequence_length : int = DEFAULT_DUMMY_SHAPES ["sequence_length" ],
2830
- ** kwargs ,
2831
- ):
2832
- self .task = task
2833
- self .batch_size = batch_size
2834
- self .sequence_length = sequence_length
2835
- self .normalized_config = normalized_config
2836
-
2818
+ self ,
2819
+ task : str ,
2820
+ normalized_config : NormalizedConfig ,
2821
+ batch_size : int = DEFAULT_DUMMY_SHAPES ["batch_size" ],
2822
+ sequence_length : int = DEFAULT_DUMMY_SHAPES ["sequence_length" ],
2823
+ ** kwargs ,
2824
+ ):
2825
+ self .task = task
2826
+ self .batch_size = batch_size
2827
+ self .sequence_length = sequence_length
2828
+ self .normalized_config = normalized_config
2829
+
2837
2830
def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
2838
2831
if input_name == "pixel_values" :
2839
- return self .random_float_tensor ([self .batch_size , 1 , 3 , self .normalized_config .config .params .image_size , self .normalized_config .config .params .image_size ])
2840
-
2832
+ return self .random_float_tensor (
2833
+ [
2834
+ self .batch_size ,
2835
+ 1 ,
2836
+ 3 ,
2837
+ self .normalized_config .config .params .image_size ,
2838
+ self .normalized_config .config .params .image_size ,
2839
+ ]
2840
+ )
2841
+
2841
2842
if input_name == "image_ids" :
2842
- return self .random_int_tensor ([self .sequence_length ], max_value = self .normalized_config .config .params .image_token_size , framework = framework , dtype = int_dtype )
2843
+ return self .random_int_tensor (
2844
+ [self .sequence_length ],
2845
+ max_value = self .normalized_config .config .params .image_token_size ,
2846
+ framework = framework ,
2847
+ dtype = int_dtype ,
2848
+ )
2843
2849
if input_name == "code_b" :
2844
- return self .random_int_tensor ([self .batch_size , 576 ], max_value = self .normalized_config .config .params .image_token_size , framework = framework , dtype = int_dtype )
2850
+ return self .random_int_tensor (
2851
+ [self .batch_size , 576 ],
2852
+ max_value = self .normalized_config .config .params .image_token_size ,
2853
+ framework = framework ,
2854
+ dtype = int_dtype ,
2855
+ )
2845
2856
if input_name == "image_shape" :
2846
2857
import torch
2847
- return torch .tensor ([self .batch_size , self .normalized_config .config .params .n_embed , 24 , 24 ], dtype = torch .int64 )
2858
+
2859
+ return torch .tensor (
2860
+ [self .batch_size , self .normalized_config .config .params .n_embed , 24 , 24 ], dtype = torch .int64
2861
+ )
2848
2862
if input_name == "hidden_state" :
2849
- return self .random_float_tensor ([self .batch_size , self .sequence_length , self .normalized_config .hidden_size ])
2863
+ return self .random_float_tensor (
2864
+ [self .batch_size , self .sequence_length , self .normalized_config .hidden_size ]
2865
+ )
2850
2866
if input_name == "lm_hidden_state" :
2851
2867
return self .random_float_tensor ([self .sequence_length , self .normalized_config .hidden_size ])
2852
2868
return super ().generate (input_name , framework , int_dtype , float_dtype )
2853
-
2854
2869
2855
2870
2856
2871
@register_in_tasks_manager ("multi-modality" , * ["image-text-to-text" , "any-to-any" ], library_name = "transformers" )
@@ -2868,7 +2883,7 @@ def __init__(
2868
2883
float_dtype : str = "fp32" ,
2869
2884
behavior : JanusConfigBehavior = JanusConfigBehavior .VISION_EMBEDDINGS ,
2870
2885
preprocessors : Optional [List [Any ]] = None ,
2871
- ** kwargs
2886
+ ** kwargs ,
2872
2887
):
2873
2888
super ().__init__ (
2874
2889
config = config ,
@@ -2882,7 +2897,9 @@ def __init__(
2882
2897
if self ._behavior == JanusConfigBehavior .VISION_EMBEDDINGS and hasattr (config , "vision_config" ):
2883
2898
self ._config = config .vision_config
2884
2899
self ._normalized_config = NormalizedVisionConfig (self ._config )
2885
- if self ._behavior in [JanusConfigBehavior .LM_HEAD , JanusConfigBehavior .VISION_GEN_HEAD ] and hasattr (config , "language_config" ):
2900
+ if self ._behavior in [JanusConfigBehavior .LM_HEAD , JanusConfigBehavior .VISION_GEN_HEAD ] and hasattr (
2901
+ config , "language_config"
2902
+ ):
2886
2903
self ._config = config .language_config
2887
2904
self ._normalized_config = NormalizedTextConfig (self ._config )
2888
2905
if self ._behavior == JanusConfigBehavior .VISION_GEN_EMBEDDINGS and hasattr (config , "gen_head_config" ):
@@ -2912,7 +2929,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
2912
2929
return {"last_hidden_state" : {0 : "batch_size" }}
2913
2930
if self ._behavior == JanusConfigBehavior .VISION_GEN_EMBEDDINGS :
2914
2931
return {"last_hidden_state" : {0 : "num_tokens" }}
2915
-
2932
+
2916
2933
if self ._behavior == JanusConfigBehavior .LM_HEAD :
2917
2934
return {"logits" : {0 : "batch_size" , 1 : "sequence_length" }}
2918
2935
@@ -2979,7 +2996,6 @@ def with_behavior(
2979
2996
preprocessors = self ._preprocessors ,
2980
2997
)
2981
2998
2982
-
2983
2999
if behavior == JanusConfigBehavior .VISION_EMBEDDINGS :
2984
3000
return self .__class__ (
2985
3001
self ._orig_config ,
@@ -2989,7 +3005,7 @@ def with_behavior(
2989
3005
behavior = behavior ,
2990
3006
preprocessors = self ._preprocessors ,
2991
3007
)
2992
-
3008
+
2993
3009
if behavior == JanusConfigBehavior .VISION_GEN_DECODER :
2994
3010
return self .__class__ (
2995
3011
self ._orig_config ,
@@ -3000,7 +3016,6 @@ def with_behavior(
3000
3016
preprocessors = self ._preprocessors ,
3001
3017
)
3002
3018
3003
-
3004
3019
def get_model_for_behavior (self , model , behavior : Union [str , JanusConfigBehavior ]):
3005
3020
if isinstance (behavior , str ) and not isinstance (behavior , JanusConfigBehavior ):
3006
3021
behavior = JanusConfigBehavior (behavior )
@@ -3023,7 +3038,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
3023
3038
3024
3039
if behavior == JanusConfigBehavior .VISION_GEN_EMBEDDINGS :
3025
3040
return model
3026
-
3041
+
3027
3042
if behavior == JanusConfigBehavior .VISION_GEN_HEAD :
3028
3043
gen_head = model .gen_head
3029
3044
gen_head .config = model .language_model .config
@@ -3032,7 +3047,6 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
3032
3047
if behavior == JanusConfigBehavior .VISION_GEN_DECODER :
3033
3048
return model .gen_vision_model
3034
3049
3035
-
3036
3050
def patch_model_for_export (
3037
3051
self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
3038
3052
):
@@ -3045,7 +3059,6 @@ def patch_model_for_export(
3045
3059
return JanusVisionGenDecoderModelPatcher (self , model , model_kwargs )
3046
3060
return super ().patch_model_for_export (model , model_kwargs )
3047
3061
3048
-
3049
3062
def rename_ambiguous_inputs (self , inputs ):
3050
3063
if self ._behavior == JanusConfigBehavior .VISION_GEN_HEAD :
3051
3064
data = inputs .pop ("lm_hidden_state" )
@@ -3056,4 +3069,4 @@ def rename_ambiguous_inputs(self, inputs):
3056
3069
if self ._behavior == JanusConfigBehavior .VISION_GEN_DECODER :
3057
3070
data = inputs .pop ("image_shape" )
3058
3071
inputs ["shape" ] = data
3059
- return inputs
3072
+ return inputs
0 commit comments