
Commit c26a450

qwen2vl support
1 parent c454b00 commit c26a450

File tree

5 files changed (+781 -22 lines)

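For orientation, a hedged usage sketch of what "qwen2vl support" targets once the export configs below are in place. The OVModelForVisualCausalLM entry point and the checkpoint id are assumptions for illustration; neither appears in the two files shown in this commit.

from optimum.intel import OVModelForVisualCausalLM
from transformers import AutoProcessor

model_id = "Qwen/Qwen2-VL-2B-Instruct"  # example checkpoint, not taken from this diff
# export=True converts the PyTorch checkpoint to OpenVINO through the export configs added below
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)
processor = AutoProcessor.from_pretrained(model_id)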

notebooks/openvino/sentence_transformer_quantization.ipynb

+8 -4
@@ -170,9 +170,11 @@
 ],
 "source": [
 "from functools import partial\n",
-"import datasets\n",
+"\n",
 "from transformers import AutoTokenizer\n",
-"from optimum.intel import OVModelForFeatureExtraction, OVQuantizer, OVQuantizationConfig, OVConfig\n",
+"\n",
+"from optimum.intel import OVConfig, OVModelForFeatureExtraction, OVQuantizationConfig, OVQuantizer\n",
+"\n",
 "\n",
 "MODEL_ID = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
 "base_model_path = \"all-MiniLM-L6-v2\"\n",
@@ -187,6 +189,7 @@
 "\n",
 "quantizer = OVQuantizer.from_pretrained(model)\n",
 "\n",
+"\n",
 "def preprocess_function(examples, tokenizer):\n",
 " return tokenizer(examples[\"sentence\"], padding=\"max_length\", max_length=384, truncation=True)\n",
 "\n",
@@ -225,9 +228,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from transformers import Pipeline\n",
-"import torch.nn.functional as F\n",
 "import torch\n",
+"import torch.nn.functional as F\n",
+"from transformers import Pipeline\n",
 "\n",
 "\n",
 "# copied from the model card \"sentence-transformers/all-MiniLM-L6-v2\"\n",
@@ -296,6 +299,7 @@
 "from datasets import load_dataset\n",
 "from evaluate import load\n",
 "\n",
+"\n",
 "eval_dataset = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
 "metric = load(\"glue\", \"stsb\")"
 ]
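For context, the imports reordered above feed the quantization flow of this notebook. A condensed sketch, assuming a GLUE sst2 calibration split and sample count (both assumptions here, not shown in these hunks):

from functools import partial

from transformers import AutoTokenizer
from optimum.intel import OVConfig, OVModelForFeatureExtraction, OVQuantizationConfig, OVQuantizer

MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
model = OVModelForFeatureExtraction.from_pretrained(MODEL_ID, export=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
quantizer = OVQuantizer.from_pretrained(model)


def preprocess_function(examples, tokenizer):
    return tokenizer(examples["sentence"], padding="max_length", max_length=384, truncation=True)


# assumed calibration source: the GLUE sst2 "sentence" column, matching preprocess_function above
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_function, tokenizer=tokenizer),
    num_samples=300,
    dataset_split="train",
)
quantizer.quantize(
    ov_config=OVConfig(quantization_config=OVQuantizationConfig()),
    calibration_dataset=calibration_dataset,
    save_directory="all-MiniLM-L6-v2-int8",
)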

optimum/exporters/openvino/model_configs.py

+255 -9
@@ -89,6 +89,8 @@
     Phi3ModelPatcher,
     Phi3VisionImageEmbeddingsPatcher,
     QwenModelPatcher,
+    Qwen2VLLanguageModelPatcher,
+    Qwen2VLVisionEmbMergerPatcher,
     RotaryEmbPatcher,
     UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
@@ -106,9 +108,13 @@ def init_model_configs():
         "transformers",
         "LlavaNextForConditionalGeneration",
     )
-    TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
-        "image-text-to-text"
-    ] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
+    TasksManager._CUSTOM_CLASSES[("pt", "qwen2-vl", "image-text-to-text")] = (
+        "transformers",
+        "Qwen2VLForConditionalGeneration",
+    )
+    TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["image-text-to-text"] = (
+        TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
+    )
 
     supported_model_types = [
         "_SUPPORTED_MODEL_TYPE",
@@ -1288,18 +1294,26 @@ def patch_model_for_export(
 
 
 class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
-    def __init__(self, export_config):
+    def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None):
         self.orig_export_config = export_config
+        if dummy_input_generator is not None:
+            export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
+                dummy_input_generator,
+            ) + export_config.DUMMY_INPUT_GENERATOR_CLASSES
         self.DUMMY_INPUT_GENERATOR_CLASSES = export_config.DUMMY_INPUT_GENERATOR_CLASSES
         self.DEFAULT_ONNX_OPSET = export_config.DEFAULT_ONNX_OPSET
         self.DUMMY_PKV_GENERATOR_CLASS = export_config.DUMMY_PKV_GENERATOR_CLASS
         self._config = export_config._config
         self._normalized_config = export_config._normalized_config
         self.use_past = export_config.use_past
+        self.patcher_cls = patcher_cls
+        self.input_info_upd = inputs_update
 
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
+        if self.patcher_cls is not None:
+            return self.patcher_cls(self, model, model_kwargs=model_kwargs)
         # Refer to DecoderModelPatcher.
         return self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)
 
@@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         orig_inputs = self.orig_export_config.inputs
         input_ids_config = orig_inputs.pop("input_ids")
         orig_inputs["inputs_embeds"] = input_ids_config
+        if self.input_info_upd is not None:
+            orig_inputs.update(self.input_info_upd)
         return orig_inputs
 
     def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
@@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dtype)
     return export_config
 
 
-def get_vlm_text_generation_config(model_type, model_config, int_dtype, float_dtype):
+def get_vlm_text_generation_config(
+    model_type,
+    model_config,
+    int_dtype,
+    float_dtype,
+    model_patcher=None,
+    dummy_input_generator=None,
+    inputs_update=None,
+):
     internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
-    export_config = LMInputEmbedsConfigHelper(internal_export_config)
+    export_config = LMInputEmbedsConfigHelper(
+        internal_export_config,
+        patcher_cls=model_patcher,
+        dummy_input_generator=dummy_input_generator,
+        inputs_update=inputs_update,
+    )
     export_config._normalized_config = internal_export_config._normalized_config
     return export_config
 
@@ -1820,9 +1849,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
18201849
img_ids_height = self.height // 2
18211850
img_ids_width = self.width // 2
18221851
return self.random_int_tensor(
1823-
[self.batch_size, img_ids_height * img_ids_width, 3]
1824-
if is_diffusers_version("<", "0.31.0")
1825-
else [img_ids_height * img_ids_width, 3],
1852+
(
1853+
[self.batch_size, img_ids_height * img_ids_width, 3]
1854+
if is_diffusers_version("<", "0.31.0")
1855+
else [img_ids_height * img_ids_width, 3]
1856+
),
18261857
min_value=0,
18271858
max_value=min(img_ids_height, img_ids_width),
18281859
framework=framework,
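Before the qwen2-vl additions below, a sketch of how the extended get_vlm_text_generation_config above ends up being called for the language sub-model. The argument values are the ones passed by Qwen2VLOpenVINOConfig.with_behavior("language") further down; model.config stands in for an already loaded Qwen2-VL configuration.

export_config = get_vlm_text_generation_config(
    "qwen2",
    model.config,
    int_dtype="int64",
    float_dtype="fp32",
    model_patcher=Qwen2VLLanguageModelPatcher,  # patches the language model for inputs_embeds-based export
    dummy_input_generator=DummyQwen2VLLMInputGenerator,  # expands position_ids to a 3 x batch x sequence tensor
    inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
)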
@@ -2259,3 +2290,218 @@ def patch_model_for_export(
         if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
             return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
         return super().patch_model_for_export(model, model_kwargs)
+
+
+class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator):
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        generated_input = super().generate(input_name, framework, int_dtype, float_dtype)
+        if input_name == "position_ids":
+            return generated_input.unsqueeze(0).expand(3, -1, -1)
+        return generated_input
+
+
+class DummyQwen2VLVisionEMbedInputGenerator(DummyVisionInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("hidden_states",)
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = 1,
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = 420,
+        height: int = 420,
+        **kwargs,
+    ):
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.num_channels = num_channels
+        self.temporal_patch_size = normalized_config.config.temporal_patch_size
+        self.patch_size = normalized_config.config.patch_size
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
+        grid_t = self.batch_size
+        shape = [
+            grid_t * grid_h * grid_w,
+            self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size,
+        ]
+        return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
+
+
+class DummyQwen2VLVisionEmbedMergerInputGenerator(DummyVisionInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = 1,
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = 420,
+        height: int = 420,
+        **kwargs,
+    ):
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.num_channels = num_channels
+        self.temporal_patch_size = normalized_config.config.temporal_patch_size
+        self.patch_size = normalized_config.config.patch_size
+        self.embed_dim = normalized_config.config.embed_dim
+        self.num_heads = normalized_config.config.num_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
+        grid_t = self.batch_size
+
+        if input_name == "hidden_states":
+            return self.random_float_tensor(
+                [grid_t * grid_h * grid_w, self.embed_dim], framework=framework, dtype=float_dtype
+            )
+
+        if input_name == "attention_mask":
+            return self.random_mask_tensor(
+                [1, grid_t * grid_h * grid_w, grid_t * grid_h * grid_w], framework=framework, dtype=float_dtype
+            )
+
+        if input_name == "rotary_pos_emb":
+            dim = self.embed_dim // self.num_heads // 2
+            return self.random_float_tensor([grid_h * grid_t * grid_w, dim], framework=framework, dtype=float_dtype)
+
+
+class Qwen2VLConfigBehavior(str, enum.Enum):
+    LANGUAGE = "language"
+    VISION_EMBEDDINGS = "vision_embeddings"
+    VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger"
+    TEXT_EMBEDDINGS = "text_embeddings"
+
+
+@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers")
+class Qwen2VLOpenVINOConfig(OnnxConfig):
+    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
+    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
+    MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
+
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
+        preprocessors: Optional[List[Any]] = None,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+        self._behavior = behavior
+        self._orig_config = config
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
+            self._config = config.vision_config
+            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
+            self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
+            self._config = config.vision_config
+            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
+            self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedMergerInputGenerator,)
+
+    @staticmethod
+    def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
+        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
+            behavior = Qwen2VLConfigBehavior(behavior)
+
+        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
+            return model
+
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
+            vision_embeddings = model.visual.patch_embed
+            vision_embeddings.config = model.config.vision_config
+            return vision_embeddings
+
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            vision_emb_merger = model.visual
+            vision_emb_merger.config = model.config.vision_config
+            return vision_emb_merger
+
+        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
+            text_embedding = model.model.embed_tokens
+            text_embedding.config = model.config
+            return text_embedding
+
+    def with_behavior(
+        self,
+        behavior: Union[str, Qwen2VLConfigBehavior],
+    ):
+        """
+        Creates a config for different behaviour.
+        Args:
+            behavior ([`ConfigBehavior`]):
+                The behavior to use for the new instance.
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
+            behavior = Qwen2VLConfigBehavior(behavior)
+
+        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
+            return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)
+
+        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
+            return get_vlm_text_generation_config(
+                "qwen2",
+                self._orig_config,
+                self.int_dtype,
+                self.float_dtype,
+                model_patcher=Qwen2VLLanguageModelPatcher,
+                dummy_input_generator=DummyQwen2VLLMInputGenerator,
+                inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
+            )
+
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
+            return self.__class__(
+                self._orig_config,
+                task=self.task,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+                behavior=behavior,
+                preprocessors=self._preprocessors,
+            )
+        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            return self.__class__(
+                self._orig_config,
+                task=self.task,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+                behavior=behavior,
+                preprocessors=self._preprocessors,
+            )
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        model_kwargs = model_kwargs or {}
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
+            return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
+        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
+            return {
+                "hidden_states": {0: "sequence_length"},
+                "attention_mask": {1: "sequence_length", 2: "sequence_length"},
+                "rotary_pos_emb": {0: "sequence_length"},
+            }
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
+            return {"last_hidden_state": {0: "seq_len"}}
+        return {}
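Taken together, a hedged sketch of how the behaviors above could be used to split an already loaded Qwen2-VL model into exportable sub-models; the export pipeline that actually drives these calls lives outside this file and is not shown in this commit.

# base config built from the full model config; behavior defaults to VISION_EMBEDDINGS
export_config = Qwen2VLOpenVINOConfig(model.config, task="image-text-to-text")

# per-component export configs
language_config = export_config.with_behavior("language")  # LM taking inputs_embeds plus 3D position_ids
vision_config = export_config.with_behavior("vision_embeddings")  # patch embedding, hidden_states input only
merger_config = export_config.with_behavior("vision_embeddings_merger")  # visual tower with attention_mask and rotary_pos_emb
text_emb_config = export_config.with_behavior("text_embeddings")  # token embedding lookup

# matching sub-modules extracted from the loaded transformers model
vision_model = Qwen2VLOpenVINOConfig.get_model_for_behavior(model, "vision_embeddings")
merger_model = Qwen2VLOpenVINOConfig.get_model_for_behavior(model, "vision_embeddings_merger")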
