Commit eb044a9

mixtral and model patcher
1 parent 9068f37 commit eb044a9

5 files changed: +103 -10 lines changed

optimum/exporters/openvino/__main__.py (+9 -9)

@@ -59,7 +59,7 @@ def main_export(
     local_files_only: bool = False,
     use_auth_token: Optional[Union[bool, str]] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
-    custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
+    custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None,
     fn_get_submodels: Optional[Callable] = None,
     compression_option: Optional[str] = None,
     compression_ratio: Optional[float] = None,
@@ -112,11 +112,11 @@ def main_export(
             when running `transformers-cli login` (stored in `~/.huggingface`).
         model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
             Experimental usage: keyword arguments to pass to the model during
-            the export. This argument should be used along the `custom_onnx_configs` argument
+            the export. This argument should be used along the `custom_export_configs` argument
             in case, for example, the model inputs/outputs are changed (for example, if
             `model_kwargs={"output_attentions": True}` is passed).
-        custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`):
-            Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model).
+        custom_export_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`):
+            Experimental usage: override the default export config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model).
         fn_get_submodels (`Optional[Callable]`, defaults to `None`):
             Experimental usage: Override the default submodels that are used at the export. This is
             especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
@@ -134,7 +134,7 @@ def main_export(
     ```python
     >>> from optimum.exporters.openvino import main_export

-    >>> main_export("gpt2", output="gpt2_onnx/")
+    >>> main_export("gpt2", output="gpt2_ov/")
     ```
     """
     original_task = task
@@ -183,14 +183,14 @@ def main_export(
     if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
         custom_architecture = True
     elif task not in TasksManager.get_supported_tasks_for_model_type(
-        model_type, exporter="onnx", library_name=library_name
+        model_type, exporter="openvino", library_name=library_name
     ):
         if original_task == "auto":
             autodetected_message = " (auto-detected)"
         else:
             autodetected_message = ""
         model_tasks = TasksManager.get_supported_tasks_for_model_type(
-            model_type, exporter="onnx", library_name=library_name
+            model_type, exporter="openvino", library_name=library_name
         )
         raise ValueError(
             f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
@@ -265,7 +265,7 @@ class StoreAttr(object):
         not custom_architecture
         and library_name != "diffusers"
         and task + "-with-past"
-        in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx", library_name=library_name)
+        in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name)
     ):
         # Make -with-past the default if --task was not explicitely specified
         if original_task == "auto":
@@ -297,7 +297,7 @@ class StoreAttr(object):
         compression_ratio=compression_ratio,
         stateful=stateful,
         model_kwargs=model_kwargs,
-        custom_onnx_configs=custom_onnx_configs,
+        custom_export_configs=custom_export_configs,
         fn_get_submodels=fn_get_submodels,
         preprocessors=preprocessors,
         device=device,
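
The rename from `custom_onnx_configs` to `custom_export_configs` keeps the call pattern unchanged. A minimal sketch of driving the updated exporter follows; the checkpoint is the tiny Mixtral model registered in the tests below, and the output directory is purely illustrative:

from optimum.exporters.openvino import main_export

# Export a Mixtral checkpoint through the OpenVINO exporter; a custom config
# could also be supplied via the renamed `custom_export_configs` argument.
main_export(
    "TitanML/tiny-mixtral",
    output="tiny_mixtral_ov/",  # illustrative output directory
    task="text-generation-with-past",
)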

optimum/exporters/openvino/model_configs.py (+33)

@@ -12,12 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from packaging import version
+from transformers.utils import is_tf_available

 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.tasks import TasksManager
 from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
 from optimum.utils.normalized_config import NormalizedTextConfig

+from .model_patcher import MixtralModelPatcher
+
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+
+    from optimum.exporters.onnx.model_patcher import ModelPatcher
+
+    if is_tf_available():
+        from transformers.modeling_tf_utils import TFPreTrainedModel

 register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)

@@ -54,3 +68,22 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
+class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
+    MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")
+
+    # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
+    DEFAULT_ONNX_OPSET = 14
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        MistralDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MixtralModelPatcher(self, model, model_kwargs=model_kwargs)
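
Because the config is registered under the "openvino" exporter name, it matches the `exporter="openvino"` lookups added in `__main__.py`. A small sketch of checking the registration, assuming the module has been imported so the decorator has run:

from optimum.exporters.tasks import TasksManager

import optimum.exporters.openvino.model_configs  # noqa: F401  # importing runs the @register_in_tasks_manager decorators

# The "openvino" registry should now expose the tasks declared for Mixtral above.
tasks = TasksManager.get_supported_tasks_for_model_type(
    "mixtral", exporter="openvino", library_name="transformers"
)
print(sorted(tasks))  # expected to include 'text-generation' and 'text-generation-with-past'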

optimum/exporters/openvino/model_patcher.py (+55)

@@ -13,7 +13,12 @@
 # limitations under the License.

 import logging as log
+import types

+import torch
+import torch.nn.functional as F
+
+from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
 from optimum.intel.utils.import_utils import (
     _torch_version,
     _transformers_version,
@@ -52,3 +57,53 @@ def patch_model_with_bettertransformer(model):
         return model

     return model
+
+
+def mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    """ """
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    hidden_states = hidden_states.view(-1, hidden_dim)
+    # router_logits: (batch * sequence_length, n_experts)
+    router_logits = self.gate(hidden_states)
+
+    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    # we cast back to the input dtype
+    routing_weights = routing_weights.to(hidden_states.dtype)
+
+    final_hidden_states = torch.zeros(
+        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+
+    # One hot encode the selected experts to create an expert mask
+    # this will be used to easily index which expert is going to be sollicitated
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+    # Loop over all available experts in the model and perform the computation on each expert
+    for expert_idx in range(self.num_experts):
+        expert_layer = self.experts[expert_idx]
+        idx, top_x = torch.where(expert_mask[expert_idx])
+
+        # Index the correct hidden states and compute the expert hidden state for
+        # the current expert. We need to make sure to multiply the output hidden
+        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+    return final_hidden_states, router_logits
+
+
+class MixtralModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            layer.block_sparse_moe._unpatched_forward = layer.block_sparse_moe.forward
+            layer.block_sparse_moe.forward = types.MethodType(mixtral_sparse_moe_block_forward, layer.block_sparse_moe)

+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward
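
The patcher is a context manager: `__enter__` swaps every `MixtralSparseMoeBlock.forward` for the version above, which avoids the tensor-to-Python `.tolist()` indexing and the data-dependent `continue` of the upstream implementation that do not trace cleanly, and `__exit__` restores the originals. A rough usage sketch, assuming the tiny Mixtral checkpoint from the tests and the export config defined in `model_configs.py`:

import torch
from transformers import AutoModelForCausalLM

from optimum.exporters.openvino.model_configs import MixtralOpenVINOConfig
from optimum.exporters.openvino.model_patcher import MixtralModelPatcher, mixtral_sparse_moe_block_forward

model = AutoModelForCausalLM.from_pretrained("TitanML/tiny-mixtral")
export_config = MixtralOpenVINOConfig(model.config, task="text-generation")

with MixtralModelPatcher(export_config, model):
    # While the context is active, each MoE block runs the patched forward.
    moe = model.model.layers[0].block_sparse_moe
    assert moe.forward.__func__ is mixtral_sparse_moe_block_forward
    model(input_ids=torch.ones((1, 8), dtype=torch.long))  # tracing/export would happen here
# On exit the original MixtralSparseMoeBlock.forward methods are restored.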

tests/openvino/test_modeling.py (+5 -1)

@@ -486,6 +486,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "llama_gptq",
         "marian",
         "mistral",
+        "mixtral",
         "mpt",
         "opt",
         "pegasus",
@@ -520,7 +521,10 @@ def test_compare_to_transformers(self, model_arch):
         self.assertIsInstance(ov_outputs.logits, torch.Tensor)
         self.assertTrue("past_key_values" in ov_outputs)
         self.assertIsInstance(ov_outputs.past_key_values, tuple)
-        if self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode":
+        not_stateful = ["gpt_bigcode"]
+        if is_openvino_version("<", "2024.0"):
+            not_stateful.append("mixtral")
+        if self.IS_SUPPORT_STATEFUL and model_arch not in not_stateful:
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)

tests/openvino/utils_tests.py (+1)

@@ -56,6 +56,7 @@
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "mistral": "echarlaix/tiny-random-mistral",
+    "mixtral": "TitanML/tiny-mixtral",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
