88
88
PersimmonModelPatcher ,
89
89
Phi3ModelPatcher ,
90
90
Phi3VisionImageEmbeddingsPatcher ,
91
+ Qwen2VLLanguageModelPatcher ,
92
+ Qwen2VLVisionEmbMergerPatcher ,
91
93
QwenModelPatcher ,
92
94
RotaryEmbPatcher ,
93
95
UpdateCausalMaskModelPatcher ,
@@ -106,6 +108,10 @@ def init_model_configs():
106
108
"transformers" ,
107
109
"LlavaNextForConditionalGeneration" ,
108
110
)
111
+ TasksManager ._CUSTOM_CLASSES [("pt" , "qwen2-vl" , "image-text-to-text" )] = (
112
+ "transformers" ,
113
+ "Qwen2VLForConditionalGeneration" ,
114
+ )
109
115
TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS [
110
116
"image-text-to-text"
111
117
] = TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["text-generation" ]
@@ -1288,18 +1294,26 @@ def patch_model_for_export(
1288
1294
1289
1295
1290
1296
class LMInputEmbedsConfigHelper (TextDecoderWithPositionIdsOnnxConfig ):
1291
- def __init__ (self , export_config ):
1297
+ def __init__ (self , export_config , patcher_cls = None , dummy_input_generator = None , inputs_update = None ):
1292
1298
self .orig_export_config = export_config
1299
+ if dummy_input_generator is not None :
1300
+ export_config .DUMMY_INPUT_GENERATOR_CLASSES = (
1301
+ dummy_input_generator ,
1302
+ ) + export_config .DUMMY_INPUT_GENERATOR_CLASSES
1293
1303
self .DUMMY_INPUT_GENERATOR_CLASSES = export_config .DUMMY_INPUT_GENERATOR_CLASSES
1294
1304
self .DEFAULT_ONNX_OPSET = export_config .DEFAULT_ONNX_OPSET
1295
1305
self .DUMMY_PKV_GENERATOR_CLASS = export_config .DUMMY_PKV_GENERATOR_CLASS
1296
1306
self ._config = export_config ._config
1297
1307
self ._normalized_config = export_config ._normalized_config
1298
1308
self .use_past = export_config .use_past
1309
+ self .patcher_cls = patcher_cls
1310
+ self .input_info_upd = inputs_update
1299
1311
1300
1312
def patch_model_for_export (
1301
1313
self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1302
1314
) -> "ModelPatcher" :
1315
+ if self .patcher_cls is not None :
1316
+ return self .patcher_cls (self , model , model_kwargs = model_kwargs )
1303
1317
# Refer to DecoderModelPatcher.
1304
1318
return self .orig_export_config .patch_model_for_export (model , model_kwargs = model_kwargs )
1305
1319
@@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
1312
1326
orig_inputs = self .orig_export_config .inputs
1313
1327
input_ids_config = orig_inputs .pop ("input_ids" )
1314
1328
orig_inputs ["inputs_embeds" ] = input_ids_config
1329
+ if self .input_info_upd is not None :
1330
+ orig_inputs .update (self .input_info_upd )
1315
1331
return orig_inputs
1316
1332
1317
1333
def generate_dummy_inputs (self , framework : str = "pt" , ** kwargs ):
@@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dt
1383
1399
return export_config
1384
1400
1385
1401
1386
def get_vlm_text_generation_config(
    model_type,
    model_config,
    int_dtype,
    float_dtype,
    model_patcher=None,
    dummy_input_generator=None,
    inputs_update=None,
):
    """Build the text-generation (language model) export config for a VLM.

    Wraps the internal text-generation config in an `LMInputEmbedsConfigHelper`
    so the language model is exported with `inputs_embeds` instead of
    `input_ids`, optionally overriding the patcher class, prepending an extra
    dummy-input generator, and merging extra dynamic-axis info into `inputs`.
    """
    base_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
    helper = LMInputEmbedsConfigHelper(
        base_config,
        patcher_cls=model_patcher,
        dummy_input_generator=dummy_input_generator,
        inputs_update=inputs_update,
    )
    # Propagate the normalized config so downstream shape inference sees the
    # same normalization as the wrapped internal config.
    helper._normalized_config = base_config._normalized_config
    return helper
1391
1420
@@ -1821,9 +1850,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
1821
1850
img_ids_height = self .height // 2
1822
1851
img_ids_width = self .width // 2
1823
1852
return self .random_int_tensor (
1824
- [self .batch_size , img_ids_height * img_ids_width , 3 ]
1825
- if is_diffusers_version ("<" , "0.31.0" )
1826
- else [img_ids_height * img_ids_width , 3 ],
1853
+ (
1854
+ [self .batch_size , img_ids_height * img_ids_width , 3 ]
1855
+ if is_diffusers_version ("<" , "0.31.0" )
1856
+ else [img_ids_height * img_ids_width , 3 ]
1857
+ ),
1827
1858
min_value = 0 ,
1828
1859
max_value = min (img_ids_height , img_ids_width ),
1829
1860
framework = framework ,
@@ -2260,3 +2291,192 @@ def patch_model_for_export(
2260
2291
if self ._behavior == Phi3VisionConfigBehavior .VISION_EMBEDDINGS :
2261
2292
return Phi3VisionImageEmbeddingsPatcher (self , model , model_kwargs )
2262
2293
return super ().patch_model_for_export (model , model_kwargs )
2294
+
2295
+
2296
class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator):
    """Dummy text-input generator for the Qwen2-VL language model.

    Identical to the base generator except that `position_ids` are expanded to
    the 3-axis layout Qwen2-VL expects (multimodal rotary position ids).
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        base_input = super().generate(input_name, framework, int_dtype, float_dtype)
        if input_name != "position_ids":
            return base_input
        # Qwen2-VL position ids carry a leading axis of size 3 on top of the
        # usual (batch, sequence) layout.
        return base_input.unsqueeze(0).expand(3, -1, -1)
2302
+
2303
+
2304
class DummyQwen2VLVisionEmbedInputGenerator(DummyVisionInputGenerator):
    """Dummy input generator for the Qwen2-VL vision embedding submodels.

    Produces flattened patch sequences (`hidden_states`), a square
    `attention_mask` over the patch sequence, and `rotary_pos_emb` tables
    sized from the vision config's patch geometry.
    """

    SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = 1,
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = 420,
        height: int = 420,
        **kwargs,
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_channels = num_channels
        vision_cfg = normalized_config.config
        self.temporal_patch_size = vision_cfg.temporal_patch_size
        self.patch_size = vision_cfg.patch_size
        # Patch-embed stage consumes raw flattened patches; the merger stage
        # consumes already-projected features of size `embed_dim`.
        if normalized_config.use_embed_dim:
            self.embed_dim = vision_cfg.embed_dim
        else:
            self.embed_dim = self.num_channels * self.temporal_patch_size * self.patch_size**2
        self.num_heads = vision_cfg.num_heads

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        grid_h = self.height // self.patch_size
        grid_w = self.width // self.patch_size
        grid_t = self.batch_size
        seq_len = grid_t * grid_h * grid_w

        if input_name == "hidden_states":
            return self.random_float_tensor([seq_len, self.embed_dim], framework=framework, dtype=float_dtype)

        if input_name == "attention_mask":
            return self.random_mask_tensor([1, seq_len, seq_len], framework=framework, dtype=float_dtype)

        if input_name == "rotary_pos_emb":
            # Half of the per-head dimension, per rotary-embedding convention.
            half_head_dim = self.embed_dim // self.num_heads // 2
            return self.random_float_tensor([seq_len, half_head_dim], framework=framework, dtype=float_dtype)
2347
+
2348
class Qwen2VLConfigBehavior(str, enum.Enum):
    # Selects which Qwen2-VL submodel an export config targets.
    # NOTE: member order matters — Qwen2VLOpenVINOConfig.SUPPORTED_BEHAVIORS
    # is built by iterating this enum.
    LANGUAGE = "language"
    VISION_EMBEDDINGS = "vision_embeddings"
    VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger"
    TEXT_EMBEDDINGS = "text_embeddings"
2353
+
2354
+
2355
@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers")
class Qwen2VLOpenVINOConfig(OnnxConfig):
    """Export config for Qwen2-VL, split into per-submodel behaviors.

    The model is exported as four parts (see `Qwen2VLConfigBehavior`):
    language model, vision patch embeddings, vision embedding merger, and
    text token embeddings. This class directly describes the two vision
    behaviors; `with_behavior` delegates the language/text-embedding
    behaviors to the shared VLM helper configs.
    """

    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedInputGenerator,)
    MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
    ):
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )
        self._behavior = behavior
        # Keep the full multimodal config around so with_behavior() can build
        # sibling configs from it.
        self._orig_config = config
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            # Patch-embed stage consumes raw flattened patches, not projected
            # features — see DummyQwen2VLVisionEmbedInputGenerator.
            self._normalized_config.use_embed_dim = False
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            self._normalized_config.use_embed_dim = True

    @staticmethod
    def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
        """Return the submodule of `model` corresponding to `behavior`."""
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            return model

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            vision_embeddings = model.visual.patch_embed
            vision_embeddings.config = model.config.vision_config
            return vision_embeddings

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            vision_emb_merger = model.visual
            vision_emb_merger.config = model.config.vision_config
            return vision_emb_merger

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            text_embedding = model.model.embed_tokens
            text_embedding.config = model.config
            return text_embedding

    def with_behavior(
        self,
        behavior: Union[str, Qwen2VLConfigBehavior],
    ):
        """
        Creates a config for different behaviour.
        Args:
            behavior ([`ConfigBehavior`]):
                The behavior to use for the new instance.
        """
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            return get_vlm_text_generation_config(
                "qwen2",
                self._orig_config,
                self.int_dtype,
                self.float_dtype,
                model_patcher=Qwen2VLLanguageModelPatcher,
                dummy_input_generator=DummyQwen2VLLMInputGenerator,
                # Qwen2-VL uses 3D position ids: axis 0 is the fixed rotary
                # axis, axes 1/2 are dynamic.
                inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
            )

        # Both vision behaviors are served by this class itself.
        if behavior in (Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER):
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        model_kwargs = model_kwargs or {}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Fixed: compared against Phi3VisionConfigBehavior here before; it only
        # worked because both are str enums with the same member value.
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return {
                "hidden_states": {0: "sequence_length"},
                "attention_mask": {1: "sequence_length", 2: "sequence_length"},
                "rotary_pos_emb": {0: "sequence_length"},
            }
        # LANGUAGE/TEXT_EMBEDDINGS inputs come from the configs built in
        # with_behavior(); report none here instead of implicitly None.
        return {}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
            return {"last_hidden_state": {0: "seq_len"}}
        return {}
0 commit comments