89
89
Phi3ModelPatcher ,
90
90
Phi3VisionImageEmbeddingsPatcher ,
91
91
QwenModelPatcher ,
92
+ Qwen2VLLanguageModelPatcher ,
93
+ Qwen2VLVisionEmbMergerPatcher ,
92
94
RotaryEmbPatcher ,
93
95
UpdateCausalMaskModelPatcher ,
94
96
XverseModelPatcher ,
@@ -106,9 +108,13 @@ def init_model_configs():
106
108
"transformers" ,
107
109
"LlavaNextForConditionalGeneration" ,
108
110
)
109
- TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS [
110
- "image-text-to-text"
111
- ] = TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["text-generation" ]
111
+ TasksManager ._CUSTOM_CLASSES [("pt" , "qwen2-vl" , "image-text-to-text" )] = (
112
+ "transformers" ,
113
+ "Qwen2VLForConditionalGeneration" ,
114
+ )
115
+ TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["image-text-to-text" ] = (
116
+ TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["text-generation" ]
117
+ )
112
118
113
119
supported_model_types = [
114
120
"_SUPPORTED_MODEL_TYPE" ,
@@ -1288,18 +1294,26 @@ def patch_model_for_export(
1288
1294
1289
1295
1290
1296
class LMInputEmbedsConfigHelper (TextDecoderWithPositionIdsOnnxConfig ):
1291
- def __init__ (self , export_config ):
1297
+ def __init__ (self , export_config , patcher_cls = None , dummy_input_generator = None , inputs_update = None ):
1292
1298
self .orig_export_config = export_config
1299
+ if dummy_input_generator is not None :
1300
+ export_config .DUMMY_INPUT_GENERATOR_CLASSES = (
1301
+ dummy_input_generator ,
1302
+ ) + export_config .DUMMY_INPUT_GENERATOR_CLASSES
1293
1303
self .DUMMY_INPUT_GENERATOR_CLASSES = export_config .DUMMY_INPUT_GENERATOR_CLASSES
1294
1304
self .DEFAULT_ONNX_OPSET = export_config .DEFAULT_ONNX_OPSET
1295
1305
self .DUMMY_PKV_GENERATOR_CLASS = export_config .DUMMY_PKV_GENERATOR_CLASS
1296
1306
self ._config = export_config ._config
1297
1307
self ._normalized_config = export_config ._normalized_config
1298
1308
self .use_past = export_config .use_past
1309
+ self .patcher_cls = patcher_cls
1310
+ self .input_info_upd = inputs_update
1299
1311
1300
1312
def patch_model_for_export (
1301
1313
self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1302
1314
) -> "ModelPatcher" :
1315
+ if self .patcher_cls is not None :
1316
+ return self .patcher_cls (self , model , model_kwargs = model_kwargs )
1303
1317
# Refer to DecoderModelPatcher.
1304
1318
return self .orig_export_config .patch_model_for_export (model , model_kwargs = model_kwargs )
1305
1319
@@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
1312
1326
orig_inputs = self .orig_export_config .inputs
1313
1327
input_ids_config = orig_inputs .pop ("input_ids" )
1314
1328
orig_inputs ["inputs_embeds" ] = input_ids_config
1329
+ if self .input_info_upd is not None :
1330
+ orig_inputs .update (self .input_info_upd )
1315
1331
return orig_inputs
1316
1332
1317
1333
def generate_dummy_inputs (self , framework : str = "pt" , ** kwargs ):
@@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dt
1383
1399
return export_config
1384
1400
1385
1401
1386
def get_vlm_text_generation_config(
    model_type,
    model_config,
    int_dtype,
    float_dtype,
    model_patcher=None,
    dummy_input_generator=None,
    inputs_update=None,
):
    """Build a text-generation export config for a VLM's language model.

    The internal text-generation config is wrapped in LMInputEmbedsConfigHelper
    so the language model is exported taking ``inputs_embeds`` instead of
    ``input_ids``, optionally with a custom patcher, an extra dummy-input
    generator, and input-axis overrides.
    """
    inner_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
    wrapped_config = LMInputEmbedsConfigHelper(
        inner_config,
        patcher_cls=model_patcher,
        dummy_input_generator=dummy_input_generator,
        inputs_update=inputs_update,
    )
    # Keep the helper's normalized config in sync with the inner config.
    wrapped_config._normalized_config = inner_config._normalized_config
    return wrapped_config
1391
1420
@@ -1820,9 +1849,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1820
1849
img_ids_height = self .height // 2
1821
1850
img_ids_width = self .width // 2
1822
1851
return self .random_int_tensor (
1823
- [self .batch_size , img_ids_height * img_ids_width , 3 ]
1824
- if is_diffusers_version ("<" , "0.31.0" )
1825
- else [img_ids_height * img_ids_width , 3 ],
1852
+ (
1853
+ [self .batch_size , img_ids_height * img_ids_width , 3 ]
1854
+ if is_diffusers_version ("<" , "0.31.0" )
1855
+ else [img_ids_height * img_ids_width , 3 ]
1856
+ ),
1826
1857
min_value = 0 ,
1827
1858
max_value = min (img_ids_height , img_ids_width ),
1828
1859
framework = framework ,
@@ -2259,3 +2290,218 @@ def patch_model_for_export(
2259
2290
if self ._behavior == Phi3VisionConfigBehavior .VISION_EMBEDDINGS :
2260
2291
return Phi3VisionImageEmbeddingsPatcher (self , model , model_kwargs )
2261
2292
return super ().patch_model_for_export (model , model_kwargs )
2293
+
2294
+
2295
class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator):
    """Dummy text inputs for the Qwen2-VL language model.

    Identical to DummyTextInputGenerator except that ``position_ids`` is
    expanded to a leading dimension of 3 — presumably one position grid per
    (temporal, height, width) rotary axis; confirm against the Qwen2-VL
    modeling code.
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        dummy = super().generate(input_name, framework, int_dtype, float_dtype)
        if input_name != "position_ids":
            return dummy
        # Broadcast the (batch, seq) positions to shape (3, batch, seq).
        return dummy.unsqueeze(0).expand(3, -1, -1)
2301
+
2302
+
2303
class DummyQwen2VLVisionEMbedInputGenerator(DummyVisionInputGenerator):
    """Dummy ``hidden_states`` input for the Qwen2-VL vision patch-embedding module.

    Produces a flattened patch tensor of shape
    (grid_t * grid_h * grid_w, channels * temporal_patch_size * patch_size**2).
    """

    SUPPORTED_INPUT_NAMES = ("hidden_states",)

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = 1,
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = 420,
        height: int = 420,
        **kwargs,
    ):
        # NOTE(review): the parent __init__ is deliberately not called; only
        # the attributes set here are read by generate().
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_channels = num_channels
        self.temporal_patch_size = normalized_config.config.temporal_patch_size
        self.patch_size = normalized_config.config.patch_size

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        # Spatial grid of patches; batch_size plays the role of the temporal grid.
        patches_h = self.height // self.patch_size
        patches_w = self.width // self.patch_size
        patches_t = self.batch_size
        total_patches = patches_t * patches_h * patches_w
        patch_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
        return self.random_float_tensor([total_patches, patch_dim], framework=framework, dtype=float_dtype)
2331
+
2332
+
2333
class DummyQwen2VLVisionEmbedMergerInputGenerator(DummyVisionInputGenerator):
    """Dummy inputs for the Qwen2-VL vision encoder + patch-merger module.

    Supplies ``hidden_states`` (flattened patch embeddings), a square
    ``attention_mask`` over all patches, and ``rotary_pos_emb`` for the
    vision attention blocks.
    """

    SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = 1,
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = 420,
        height: int = 420,
        **kwargs,
    ):
        # NOTE(review): the parent __init__ is deliberately not called; only
        # the attributes set here are read by generate().
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_channels = num_channels
        self.temporal_patch_size = normalized_config.config.temporal_patch_size
        self.patch_size = normalized_config.config.patch_size
        self.embed_dim = normalized_config.config.embed_dim
        self.num_heads = normalized_config.config.num_heads

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        patches_h = self.height // self.patch_size
        patches_w = self.width // self.patch_size
        patches_t = self.batch_size
        total_patches = patches_t * patches_h * patches_w

        if input_name == "hidden_states":
            return self.random_float_tensor([total_patches, self.embed_dim], framework=framework, dtype=float_dtype)

        if input_name == "attention_mask":
            # Square mask over every patch position, with a leading batch dim of 1.
            return self.random_mask_tensor(
                [1, total_patches, total_patches], framework=framework, dtype=float_dtype
            )

        if input_name == "rotary_pos_emb":
            # Rotary dim is half the per-head dimension.
            rotary_dim = self.embed_dim // self.num_heads // 2
            return self.random_float_tensor([total_patches, rotary_dim], framework=framework, dtype=float_dtype)
2372
+
2373
+
2374
class Qwen2VLConfigBehavior(str, enum.Enum):
    """Submodel selector for the Qwen2-VL export config.

    Each member names one of the independently exported parts of the model:
    the language decoder, the vision patch embedding, the vision
    encoder/merger, and the text token embedding.
    """

    LANGUAGE = "language"
    VISION_EMBEDDINGS = "vision_embeddings"
    VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger"
    TEXT_EMBEDDINGS = "text_embeddings"
2379
+
2380
+
2381
@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers")
class Qwen2VLOpenVINOConfig(OnnxConfig):
    """OpenVINO export config for Qwen2-VL.

    The model is exported as several submodels selected by
    ``Qwen2VLConfigBehavior``: language decoder, vision patch embedding,
    vision encoder/merger, and text token embedding. This config instance
    handles the two vision behaviors directly; ``with_behavior`` builds
    dedicated configs for the language and text-embedding parts.
    """

    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
    MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
    ):
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )
        self._behavior = behavior
        # Keep the full model config around: vision behaviors narrow self._config
        # to config.vision_config below, but with_behavior needs the original.
        self._orig_config = config
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedMergerInputGenerator,)

    @staticmethod
    def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
        """Return the submodule of `model` that corresponds to `behavior`.

        Attaches the matching sub-config to the returned module so export
        code can read ``module.config``. Returns None for an unknown behavior.
        """
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            return model

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            vision_embeddings = model.visual.patch_embed
            vision_embeddings.config = model.config.vision_config
            return vision_embeddings

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            vision_emb_merger = model.visual
            vision_emb_merger.config = model.config.vision_config
            return vision_emb_merger

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            text_embedding = model.model.embed_tokens
            text_embedding.config = model.config
            return text_embedding

    def with_behavior(
        self,
        behavior: Union[str, Qwen2VLConfigBehavior],
    ):
        """
        Creates a config for different behaviour.
        Args:
            behavior ([`ConfigBehavior`]):
                The behavior to use for the new instance.
        """
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            # The token embedding layer behaves like a plain qwen2 text embedding.
            return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            # Language model takes inputs_embeds and the 3D mrope position_ids,
            # hence the custom patcher, generator, and input-axis override.
            return get_vlm_text_generation_config(
                "qwen2",
                self._orig_config,
                self.int_dtype,
                self.float_dtype,
                model_patcher=Qwen2VLLanguageModelPatcher,
                dummy_input_generator=DummyQwen2VLLMInputGenerator,
                inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
            )

        if behavior in (Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER):
            # Both vision behaviors are handled by this class itself; __init__
            # selects the matching sub-config and dummy-input generators.
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Use the merger-specific patcher for the vision encoder/merger; default otherwise."""
        model_kwargs = model_kwargs or {}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # BUG FIX: this previously compared against Phi3VisionConfigBehavior.VISION_EMBEDDINGS,
        # which only matched by accident because both are str enums with the
        # same value; compare against this model's own behavior enum.
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return {
                "hidden_states": {0: "sequence_length"},
                "attention_mask": {1: "sequence_length", 2: "sequence_length"},
                "rotary_pos_emb": {0: "sequence_length"},
            }
        # Consistent with `outputs`: no declared inputs for other behaviors.
        return {}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
            return {"last_hidden_state": {0: "seq_len"}}
        return {}
0 commit comments