
Commit 93777ec

Add OpenVINO qwen2vl support (#1042)
* qwen2vl support
* fix code style
* add test case
* Added compression tests for qwen2-vl
* Remove trust_remote_code
* Apply suggestions from code review (Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>)
* revert changes in notebook
* apply review comments
* add comments for patching
* reuse original methods if possible
* Update optimum/intel/openvino/modeling_visual_language.py
* fix typings in patchers

Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>
1 parent b17d1e0 commit 93777ec
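The diff shown below covers the export configs and patchers; the runtime side lives in optimum/intel/openvino/modeling_visual_language.py, which this commit also updates but which is not shown on this page. As context only, here is a minimal sketch of how the new support is typically exercised once merged; OVModelForVisualCausalLM, the checkpoint name, and the prompt handling are assumptions based on the existing optimum-intel VLM API rather than part of this diff:

    from PIL import Image
    from transformers import AutoProcessor
    from optimum.intel import OVModelForVisualCausalLM  # assumed entry point for image-text-to-text models

    model_id = "Qwen/Qwen2-VL-2B-Instruct"  # any Qwen2-VL checkpoint, requires transformers >= 4.45
    processor = AutoProcessor.from_pretrained(model_id)
    # export=True converts the submodels registered by this commit (language model, vision embeddings,
    # vision embeddings merger, text embeddings) to OpenVINO IR on the fly
    model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(images=[Image.open("example.png")], text=[prompt], return_tensors="pt")
    generated = model.generate(**inputs, max_new_tokens=50)
    print(processor.batch_decode(generated, skip_special_tokens=True)[0])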

File tree

8 files changed (+621 -10 lines changed)


optimum/exporters/openvino/model_configs.py (+226 -6)
@@ -88,6 +88,8 @@
     PersimmonModelPatcher,
     Phi3ModelPatcher,
     Phi3VisionImageEmbeddingsPatcher,
+    Qwen2VLLanguageModelPatcher,
+    Qwen2VLVisionEmbMergerPatcher,
     QwenModelPatcher,
     RotaryEmbPatcher,
     UpdateCausalMaskModelPatcher,
@@ -106,6 +108,10 @@ def init_model_configs():
         "transformers",
         "LlavaNextForConditionalGeneration",
     )
+    TasksManager._CUSTOM_CLASSES[("pt", "qwen2-vl", "image-text-to-text")] = (
+        "transformers",
+        "Qwen2VLForConditionalGeneration",
+    )
     TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
         "image-text-to-text"
     ] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
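The _CUSTOM_CLASSES entry only tells the exporter which eager class to instantiate for the ("pt", "qwen2-vl", "image-text-to-text") combination; conceptually it resolves to something like the sketch below (an illustration, not the TasksManager internals):

    from transformers import Qwen2VLForConditionalGeneration  # requires transformers >= 4.45

    # what the registration resolves to when the exporter loads the PyTorch model for export
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")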
@@ -1288,18 +1294,26 @@ def patch_model_for_export(
 
 
 class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
-    def __init__(self, export_config):
+    def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None):
         self.orig_export_config = export_config
+        if dummy_input_generator is not None:
+            export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
+                dummy_input_generator,
+            ) + export_config.DUMMY_INPUT_GENERATOR_CLASSES
         self.DUMMY_INPUT_GENERATOR_CLASSES = export_config.DUMMY_INPUT_GENERATOR_CLASSES
         self.DEFAULT_ONNX_OPSET = export_config.DEFAULT_ONNX_OPSET
         self.DUMMY_PKV_GENERATOR_CLASS = export_config.DUMMY_PKV_GENERATOR_CLASS
         self._config = export_config._config
         self._normalized_config = export_config._normalized_config
         self.use_past = export_config.use_past
+        self.patcher_cls = patcher_cls
+        self.input_info_upd = inputs_update
 
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
+        if self.patcher_cls is not None:
+            return self.patcher_cls(self, model, model_kwargs=model_kwargs)
         # Refer to DecoderModelPatcher.
         return self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)
 
@@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         orig_inputs = self.orig_export_config.inputs
         input_ids_config = orig_inputs.pop("input_ids")
         orig_inputs["inputs_embeds"] = input_ids_config
+        if self.input_info_upd is not None:
+            orig_inputs.update(self.input_info_upd)
         return orig_inputs
 
     def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
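To make the effect of inputs_update concrete: for the qwen2-vl language model the exported inputs end up looking roughly like the sketch below; the attention_mask and past_key_values axes come from the wrapped text-generation config and are shown only for illustration.

    # rough shape of LMInputEmbedsConfigHelper.inputs for qwen2-vl with use_past=True (illustrative)
    inputs = {
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},  # replaces "input_ids"
        "attention_mask": {0: "batch_size", 1: "past_sequence_length + sequence_length"},
        # overridden via inputs_update: axis 0 is the 3 M-RoPE components (temporal, height, width)
        "position_ids": {1: "batch_size", 2: "sequence_length"},
        # ...plus the past_key_values.* entries added by the wrapped config
    }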
@@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dt
     return export_config
 
 
-def get_vlm_text_generation_config(model_type, model_config, int_dtype, float_dtype):
+def get_vlm_text_generation_config(
+    model_type,
+    model_config,
+    int_dtype,
+    float_dtype,
+    model_patcher=None,
+    dummy_input_generator=None,
+    inputs_update=None,
+):
     internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
-    export_config = LMInputEmbedsConfigHelper(internal_export_config)
+    export_config = LMInputEmbedsConfigHelper(
+        internal_export_config,
+        patcher_cls=model_patcher,
+        dummy_input_generator=dummy_input_generator,
+        inputs_update=inputs_update,
+    )
     export_config._normalized_config = internal_export_config._normalized_config
     return export_config
 
@@ -1821,9 +1850,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         img_ids_height = self.height // 2
         img_ids_width = self.width // 2
         return self.random_int_tensor(
-            [self.batch_size, img_ids_height * img_ids_width, 3]
-            if is_diffusers_version("<", "0.31.0")
-            else [img_ids_height * img_ids_width, 3],
+            (
+                [self.batch_size, img_ids_height * img_ids_width, 3]
+                if is_diffusers_version("<", "0.31.0")
+                else [img_ids_height * img_ids_width, 3]
+            ),
             min_value=0,
             max_value=min(img_ids_height, img_ids_width),
             framework=framework,
@@ -2260,3 +2291,192 @@ def patch_model_for_export(
         if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
             return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
         return super().patch_model_for_export(model, model_kwargs)
+
+
+class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator):
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        generated_input = super().generate(input_name, framework, int_dtype, float_dtype)
+        if input_name == "position_ids":
+            return generated_input.unsqueeze(0).expand(3, -1, -1)
+        return generated_input
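Qwen2-VL uses multimodal rotary position embeddings (M-RoPE), whose position ids carry three components (temporal, height, width); that is why the dummy generator adds a leading axis of size 3. A quick shape check of what it produces (standalone illustration, arbitrary sizes):

    import torch

    batch_size, seq_len = 2, 16
    base = torch.randint(0, seq_len, (batch_size, seq_len))  # roughly what DummyTextInputGenerator emits
    position_ids = base.unsqueeze(0).expand(3, -1, -1)       # -> (3, batch_size, seq_len), one plane per M-RoPE component
    assert position_ids.shape == (3, batch_size, seq_len)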
+
+
+class DummyQwen2VLVisionEmbedInputGenerator(DummyVisionInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = 1,
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = 420,
+        height: int = 420,
+        **kwargs,
+    ):
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.num_channels = num_channels
+        self.temporal_patch_size = normalized_config.config.temporal_patch_size
+        self.patch_size = normalized_config.config.patch_size
+        if normalized_config.use_embed_dim:
+            self.embed_dim = normalized_config.config.embed_dim
+        else:
+            self.embed_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
+        self.num_heads = normalized_config.config.num_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
+        grid_t = self.batch_size
+
+        if input_name == "hidden_states":
+            return self.random_float_tensor(
+                [grid_t * grid_h * grid_w, self.embed_dim], framework=framework, dtype=float_dtype
+            )
+
+        if input_name == "attention_mask":
+            return self.random_mask_tensor(
+                [1, grid_t * grid_h * grid_w, grid_t * grid_h * grid_w], framework=framework, dtype=float_dtype
+            )
+
+        if input_name == "rotary_pos_emb":
+            dim = self.embed_dim // self.num_heads // 2
+            return self.random_float_tensor([grid_h * grid_t * grid_w, dim], framework=framework, dtype=float_dtype)
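The generator's shapes follow Qwen2-VL's patchification. A worked example with the defaults above; the concrete config values (patch_size=14, temporal_patch_size=2, embed_dim=1280, num_heads=16) are assumptions about a typical Qwen2-VL vision config, read from the checkpoint at export time rather than from this diff:

    patch_size, temporal_patch_size, num_channels = 14, 2, 3                # assumed vision config values
    height = width = 420
    grid_h, grid_w, grid_t = height // patch_size, width // patch_size, 1   # 30, 30, 1

    # VISION_EMBEDDINGS (use_embed_dim=False): flattened patches fed to patch_embed
    patch_dim = num_channels * temporal_patch_size * patch_size * patch_size  # 1176
    # "hidden_states": (grid_t * grid_h * grid_w, patch_dim) = (900, 1176)

    # VISION_EMBEDDINGS_MERGER (use_embed_dim=True): tokens already carry the ViT width
    embed_dim, num_heads = 1280, 16                                          # assumed vision config values
    rotary_dim = embed_dim // num_heads // 2                                 # 40
    # "hidden_states": (900, 1280), "rotary_pos_emb": (900, 40), "attention_mask": (1, 900, 900)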
+
+
+class Qwen2VLConfigBehavior(str, enum.Enum):
+    LANGUAGE = "language"
+    VISION_EMBEDDINGS = "vision_embeddings"
+    VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger"
+    TEXT_EMBEDDINGS = "text_embeddings"
+
+
+@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers")
+class Qwen2VLOpenVINOConfig(OnnxConfig):
+    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
+    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedInputGenerator,)
+    MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
+
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
+        preprocessors: Optional[List[Any]] = None,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+        self._behavior = behavior
+        self._orig_config = config
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
+            self._config = config.vision_config
+            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
+            self._normalized_config.use_embed_dim = False
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
+            self._config = config.vision_config
+            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
+            self._normalized_config.use_embed_dim = True
+
+    @staticmethod
+    def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
+        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
+            behavior = Qwen2VLConfigBehavior(behavior)
+
+        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
+            return model
+
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
+            vision_embeddings = model.visual.patch_embed
+            vision_embeddings.config = model.config.vision_config
+            return vision_embeddings
+
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            vision_emb_merger = model.visual
+            vision_emb_merger.config = model.config.vision_config
+            return vision_emb_merger
+
+        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
+            text_embedding = model.model.embed_tokens
+            text_embedding.config = model.config
+            return text_embedding
+
+    def with_behavior(
+        self,
+        behavior: Union[str, Qwen2VLConfigBehavior],
+    ):
+        """
+        Creates a config for a different behavior.
+        Args:
+            behavior ([`Qwen2VLConfigBehavior`]):
+                The behavior to use for the new instance.
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
+            behavior = Qwen2VLConfigBehavior(behavior)
+
+        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
+            return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)
+
+        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
+            return get_vlm_text_generation_config(
+                "qwen2",
+                self._orig_config,
+                self.int_dtype,
+                self.float_dtype,
+                model_patcher=Qwen2VLLanguageModelPatcher,
+                dummy_input_generator=DummyQwen2VLLMInputGenerator,
+                inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
+            )
+
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
+            return self.__class__(
+                self._orig_config,
+                task=self.task,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+                behavior=behavior,
+                preprocessors=self._preprocessors,
+            )
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            return self.__class__(
+                self._orig_config,
+                task=self.task,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+                behavior=behavior,
+                preprocessors=self._preprocessors,
+            )
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        model_kwargs = model_kwargs or {}
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
+            return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            return {
+                "hidden_states": {0: "sequence_length"},
+                "attention_mask": {1: "sequence_length", 2: "sequence_length"},
+                "rotary_pos_emb": {0: "sequence_length"},
+            }
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
+            return {"last_hidden_state": {0: "seq_len"}}
+        return {}
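Putting the config together: during export each behavior is paired with the matching submodule and traced under its patcher. Conceptually the flow looks like the sketch below; this is not the actual export loop in optimum, only an illustration of how the pieces defined above relate:

    # illustrative pairing of sub-configs and submodules (not the real export loop)
    base_config = Qwen2VLOpenVINOConfig(model.config, task="image-text-to-text")
    for behavior in Qwen2VLConfigBehavior:
        sub_config = base_config.with_behavior(behavior)
        submodel = Qwen2VLOpenVINOConfig.get_model_for_behavior(model, behavior)
        patcher = sub_config.patch_model_for_export(submodel)
        # each submodel is then traced inside `with patcher:` using sub_config's dummy inputs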

optimum/exporters/openvino/model_patcher.py (+106)
@@ -3378,3 +3378,109 @@ def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for block in self._model.model.layers:
             block.self_attn.forward = block.self_attn._orig_forward
+
+
+class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        def forward_wrap(
+            self,
+            attention_mask,
+            position_ids=None,
+            past_key_values=None,
+            inputs_embeds=None,
+            input_ids=None,
+        ):
+            from transformers.cache_utils import DynamicCache
+
+            new_past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            result = self.__orig_forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=new_past_key_values,
+                inputs_embeds=inputs_embeds,
+            )
+            if past_key_values is not None:
+                result["past_key_values"] = result["past_key_values"].to_legacy_cache()
+            return result
+
+        model.forward = types.MethodType(forward_wrap, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
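The wrapper above exists because the traced language model exchanges its KV cache as a legacy tuple of per-layer (key, value) tensors, while recent transformers forwards expect a Cache object. The conversion it relies on is the standard round trip, shown standalone below (illustration only, arbitrary shapes):

    import torch
    from transformers.cache_utils import DynamicCache

    # legacy format: tuple of per-layer (key, value) tensors, shape (batch, num_kv_heads, seq_len, head_dim)
    legacy = tuple((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)) for _ in range(2))

    cache = DynamicCache.from_legacy_cache(legacy)  # what forward_wrap passes to the original forward
    roundtrip = cache.to_legacy_cache()             # what it converts back to before returning
    assert roundtrip[0][0].shape == (1, 2, 4, 8)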
+
+
+class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
+        # - takes an attention_mask input instead of cu_seqlens, whose in-model computation cannot be traced (loop with dynamic length)
+        # - patch_embed and rot_pos_emb calls are split out so they can run as part of a separate model
+        def image_embed_forward(
+            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
+        ) -> torch.Tensor:
+            for blk in self.blocks:
+                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
+            return self.merger(hidden_states)
+
+        model.forward = types.MethodType(image_embed_forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __enter__(self):
+        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
+        # - takes an attention_mask input instead of computing it internally (not traceable: loop with dynamic length)
+        def sdpa_attn_forward(
+            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        ) -> torch.Tensor:
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision
+
+            seq_length = hidden_states.shape[0]
+            q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+            q = q.transpose(0, 1)
+            k = k.transpose(0, 1)
+            v = v.transpose(0, 1)
+            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+            attn_output = attn_output.transpose(0, 1)
+            attn_output = attn_output.reshape(seq_length, -1)
+            attn_output = self.proj(attn_output)
+            return attn_output
+
+        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L430
+        # - propagates the attention_mask input to self.attn
+        def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor:
+            hidden_states = hidden_states + self.attn(
+                self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
+            )
+            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+            return hidden_states
+
+        for block in self._model.blocks:
+            block._orig_forward = block.forward
+            block.forward = types.MethodType(block_forward, block)
+            block.attn._orig_forward = block.attn.forward
+            block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
+        for block in self._model.blocks:
+            block.forward = block._orig_forward
+            block.attn.forward = block.attn._orig_forward
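For context on the attention_mask that the merger and its attention blocks now take as an input: it stands in for the block-diagonal mask that the original forward derived from cu_seqlens inside the model, which cannot be traced (a loop with dynamic length). The mask is expected to be precomputed on the runtime side (modeling_visual_language.py in this commit); a minimal sketch of that precomputation, assuming the conventional additive-mask formulation rather than copying the actual implementation:

    import torch

    def make_vision_attention_mask(cu_seqlens: torch.Tensor, seq_length: int) -> torch.Tensor:
        # additive float mask: 0 inside each image's span of patches, a large negative value elsewhere
        mask = torch.full((1, seq_length, seq_length), torch.finfo(torch.float32).min)
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            mask[..., start:end, start:end] = 0.0
        return mask

    # e.g. two images of 900 patches each -> cu_seqlens = [0, 900, 1800]
    mask = make_vision_attention_mask(torch.tensor([0, 900, 1800]), 1800)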
