Skip to content

Commit f932c3a

Browse files
committed
nanollava support
1 parent 24414cd commit f932c3a

File tree

5 files changed

+374
-15
lines changed

5 files changed

+374
-15
lines changed

optimum/exporters/openvino/model_configs.py

+165-4
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1818

1919
from packaging import version
20-
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
20+
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
2121
from transformers.utils import is_tf_available
2222

2323
from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
@@ -69,6 +69,7 @@
6969
JaisModelPatcher,
7070
LlamaModelPatcher,
7171
LlavaImageEmbeddingModelPatcher,
72+
LlavaQwen2ImageEmbeddingsModelPatcher,
7273
MiniCPMVImageEmbeddingsModelPatcher,
7374
MiniCPMVResamplerModelPatcher,
7475
MistralModelPatcher,
@@ -1218,8 +1219,8 @@ def patch_model_for_export(
12181219

12191220

12201221
class LlavaConfigBehavior(str, enum.Enum):
1221-
LANGUAGE = "language"
12221222
VISION_EMBEDDINGS = "vision_embeddings"
1223+
LANGUAGE = "language"
12231224
TEXT_EMBEDDINGS = "text_embeddings"
12241225

12251226

@@ -1380,6 +1381,166 @@ class LlavaNextOpenVINOConfig(LlavaOpenVINOConfig):
13801381
MIN_TRANSFORMERS_VERSION = version.parse("4.40.0")
13811382

13821383

1384+
@register_in_tasks_manager(
    "llava-qwen2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers"
)
class LlavaQwen2OpenVINOConfig(OnnxConfig):
    """Export config for `llava-qwen2` models, split into per-behavior sub-configs.

    The behavior (see `LlavaConfigBehavior`) selects which part of the model this
    config describes: the vision embeddings, the text embeddings, or the language
    model. This class itself only declares inputs/outputs for the vision
    embeddings; `with_behavior` returns other config objects for the remaining
    behaviors.
    """

    SUPPORTS_PAST = True
    MIN_TRANSFORMERS_VERSION = version.parse("4.40.0")
    SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior]
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
        use_past: bool = False,
    ):
        """Store the behavior and original config, then initialise the base config.

        Args:
            config: Config of the full model. It is kept as `self._orig_config`;
                for the vision-embeddings behavior its `mm_vision_tower`
                attribute is used to load the vision tower's own config.
            task: Forwarded to the base class.
            int_dtype: Forwarded to the base class.
            float_dtype: Forwarded to the base class.
            behavior: Which sub-model this config describes.
            preprocessors: Forwarded to the base class.
            use_past: Currently not used by this constructor.
        """
        self._behavior = behavior
        self._orig_config = config
        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            # NOTE(review): trust_remote_code=True lets the vision tower repo run
            # its own code while loading the config -- confirm that only trusted
            # checkpoints reach this path.
            config = AutoConfig.from_pretrained(config.mm_vision_tower, trust_remote_code=True)
            # Composite configs carry the vision settings in a nested sub-config.
            # NOTE(review): nesting of this check under the branch above is
            # reconstructed from a flattened diff -- confirm against the repository.
            if hasattr(config, "vision_config"):
                config = config.vision_config
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic axes of the export inputs; empty unless exporting vision embeddings."""
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {}
        return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic axes of the export outputs; empty unless exporting vision embeddings."""
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {}
        return {"last_hidden_state": {0: "batch_size"}}

    def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]):
        """Return the (sub)model object to export for `behavior`.

        Args:
            model: The full loaded model.
            behavior: A `LlavaConfigBehavior` member or its string value.

        Returns:
            `model` itself for the language and vision-embeddings behaviors (for
            the language behavior its `forward` is rebound first), or
            `model.model.embed_tokens` for the text-embeddings behavior. Falls
            through and returns `None` if `behavior` matches none of these.
        """
        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
            behavior = LlavaConfigBehavior(behavior)

        if behavior == LlavaConfigBehavior.LANGUAGE:
            # Rebind forward to the implementation found after type(model) in the
            # MRO, i.e. skip the override defined on the model's own class.
            model.forward = super(type(model), model).forward
            return model

        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return model

        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
            text_embedding = model.model.embed_tokens
            # Attach the inner model's config to the embedding module.
            text_embedding.config = model.model.config
            return text_embedding

    def _get_language_export_config(self):
        """Instantiate the registered text-generation export config for the language model.

        The language model type is derived from the original config's
        `model_type` by dropping the "llava-" prefix and replacing "_" with "-".

        Raises:
            ValueError: If that model type is not registered in `TasksManager`, or
                has no "text-generation-with-past" OpenVINO export config.
        """
        model_type = self._orig_config.model_type.replace("llava-", "").replace("_", "-")

        if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
            raise ValueError(
                f"Unsupported language model type provided `{model_type}`. Please define custom export config"
            )

        openvino_export_configs = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]
        if "text-generation-with-past" not in openvino_export_configs:
            raise ValueError(
                f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
            )

        internal_export_config_class = openvino_export_configs["text-generation-with-past"]
        return internal_export_config_class(
            self._orig_config,
            use_past=True,
            use_past_in_inputs=True,
            int_dtype=self.int_dtype,
            float_dtype=self.float_dtype,
        )

    def with_behavior(
        self,
        behavior: Union[str, LlavaConfigBehavior],
    ):
        """
        Creates an export config for the requested behavior.

        Args:
            behavior ([`LlavaConfigBehavior`]):
                The behavior to use for the new instance, as an enum member or its string value.

        Returns:
            An `InputEmbedOpenvVINOConfig` for text embeddings, an `LMInputEmbedsConfigHelper`
            wrapping the language model's export config for the language behavior, or a new
            instance of this class for vision embeddings. Falls through and returns `None`
            if `behavior` matches none of these.

        Raises:
            ValueError: For the text-embeddings and language behaviors, if no suitable
                export config is registered for the language model type.
        """
        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
            behavior = LlavaConfigBehavior(behavior)

        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
            internal_export_config = self._get_language_export_config()
            # NOTE(review): this assigns on the InputEmbedOpenvVINOConfig class
            # itself, so the value is shared by every later use of that class.
            InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS
            export_config = InputEmbedOpenvVINOConfig(
                self._orig_config,
                task="feature-extraction",
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
            )
            return export_config

        if behavior == LlavaConfigBehavior.LANGUAGE:
            internal_export_config = self._get_language_export_config()
            export_config = LMInputEmbedsConfigHelper(internal_export_config)
            export_config._normalized_config = internal_export_config._normalized_config
            return export_config

        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Return the patcher to use while exporting `model`.

        The vision-embeddings behavior uses `LlavaQwen2ImageEmbeddingsModelPatcher`;
        every other behavior defers to the base class.
        """
        model_kwargs = model_kwargs or {}
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return super().patch_model_for_export(model, model_kwargs)
        return LlavaQwen2ImageEmbeddingsModelPatcher(self, model, model_kwargs)

    def rename_ambiguous_inputs(self, inputs):
        """Map `pixel_values` to `images` for vision embeddings; otherwise defer to the base class."""
        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {"images": inputs["pixel_values"]}
        return super().rename_ambiguous_inputs(inputs)
1542+
1543+
13831544
class InternVLChatConfigBehavior(str, enum.Enum):
13841545
LANGUAGE = "language"
13851546
VISION_EMBEDDINGS = "vision_embeddings"
@@ -1508,8 +1669,8 @@ def with_behavior(
15081669
preprocessors=self._preprocessors,
15091670
)
15101671

1511-
def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]):
1512-
if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
1672+
def get_model_for_behavior(self, model, behavior: Union[str, InternVLChatConfigBehavior]):
1673+
if isinstance(behavior, str) and not isinstance(behavior, InternVLChatConfigBehavior):
15131674
behavior = InternVLChatConfigBehavior(behavior)
15141675

15151676
if behavior == InternVLChatConfigBehavior.LANGUAGE:

optimum/exporters/openvino/model_patcher.py

+18
Original file line number | Diff line number | Diff line change
@@ -2936,3 +2936,21 @@ def forward(self, input):
29362936
def __exit__(self, exc_type, exc_value, traceback):
29372937
super().__exit__(exc_type, exc_value, traceback)
29382938
self._model.forward = self._model.__orig_forward
2939+
2940+
2941+
class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher):
    """Patcher that makes `encode_images` the model's `forward` for the duration of export.

    The previous `forward` is stashed on the model in `__init__` and put back in
    `__exit__`.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Dict[str, Any],
    ):
        # Swap forward *before* calling the base initialiser, so the base class
        # is handed a model whose forward is already `encode_images`.
        # `__orig_forward` is written with attribute syntax on purpose: inside a
        # class body the name is mangled, and `__exit__` relies on the same
        # mangled name.
        model.__orig_forward = model.forward
        model.forward = model.encode_images
        super().__init__(config, model, model_kwargs)

        # Load the vision tower if it has not been loaded yet.
        vision_tower = self._model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Put back the forward that was stashed in __init__.
        patched_model = self._model
        patched_model.forward = patched_model.__orig_forward

optimum/exporters/openvino/utils.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -208,4 +208,4 @@ def get_submodels(model):
208208
return custom_export, fn_get_submodels
209209

210210

211-
MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "internvl-chat", "minicpmv"]
211+
MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"]

optimum/intel/openvino/modeling_decoder.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -504,8 +504,8 @@ def prepare_inputs(
504504
else:
505505
position_ids = np.cumsum(attention_mask, axis=1) - 1
506506
position_ids[attention_mask == 0] = 1
507-
if past_key_values:
508-
position_ids = position_ids[:, -input_ids.shape[1] :]
507+
if past_key_values:
508+
position_ids = position_ids[:, -input_ids.shape[1] :]
509509

510510
inputs["position_ids"] = position_ids
511511

0 commit comments

Comments
 (0)