From 529381e3f210f419f658106a4b33cf39082ca0c6 Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 22 Apr 2024 17:17:47 +0400 Subject: [PATCH 1/8] apply sdpa for mpt and internlm --- optimum/exporters/openvino/model_configs.py | 19 ++- optimum/exporters/openvino/model_patcher.py | 136 +++++++++++++++++++- 2 files changed, 151 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 90297c8fb3..9ecf6bcaee 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -19,7 +19,7 @@ from transformers.utils import is_tf_available from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig -from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig +from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig, MPTOnnxConfig from optimum.exporters.tasks import TasksManager from optimum.utils import DEFAULT_DUMMY_SHAPES from optimum.utils.input_generators import ( @@ -37,6 +37,8 @@ LlamaModelPatcher, MixtralModelPatcher, QwenModelPatcher, + MPTModelPatcher, + InternLMPatcher, ) @@ -429,6 +431,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return InternLMPatcher(self, model, model_kwargs=model_kwargs) + @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers") class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): @@ -437,3 +444,13 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager( + "mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers" +) +class MPTOpenVINOConfig(MPTOnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MPTModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 35983975ed..e598458cda 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging as log +import math import types from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -339,9 +340,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -623,3 +624,132 @@ def __init__( # model has first inference buffers initialization if hasattr(self._model.lm_head, "first_flag"): self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64)) + + +def _mpt_attention_forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, +): + batch_size, seq_length = hidden_states.shape[:2] + + mixed_qkv = self.Wqkv(hidden_states) + query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2) + query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if len(past_key_value) != 0: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = (key_states, value_states) + else: + past_key_value = (key_states, value_states) + + attention_mask_sdpa = torch.ones(attention_mask.shape, dtype=query_states.dtype) + attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min) + context_states = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask_sdpa, + dropout_p=self.attn_dropout_p, + scale=self.softmax_scale, + ) + context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) + attn_output = self.out_proj(context_states) + + return attn_output, None, past_key_value + + +class MPTModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + if is_torch_version(">=", "2.1.0"): + for block in self._model.transformer.blocks: + block.attn._orig_forward = block.attn.forward + block.attn.forward = types.MethodType(_mpt_attention_forward, block.attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.transformer.blocks: + if hasattr(block.attn, "_orig_forward"): + block.attn.forward = block.attn._orig_forward + + +def _internlm_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = 
qkv_states.reshape(
+        qkv_states.shape[0], qkv_states.shape[1], -1, 2 + self.num_key_values_groups, self.head_dim
+    )
+    query_states = qkv_states[..., : self.num_key_value_groups, :]
+    query_states = query_states.reshape(query_states.shape[0], query_states.shape[1], -1, query_states.shape[-1])
+    key_states = qkv_states[..., -2, :]
+    value_states = qkv_states[..., -1, :]
+
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.wo(attn_output)
+
+    attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+class InternLMPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        if is_torch_version(">=", "2.1.0"):
+            for block in self._model.model.layers:
+                block.attention._orig_forward = block.attention.forward
+                block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            if hasattr(block.attention, "_orig_forward"):
+                block.attention.forward = block.attention._orig_forward
From 7a2bdf33a2178ffe56035b0478e20ffe3ebb6004 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Mon, 22 Apr 2024 20:13:40 +0400
Subject: [PATCH 2/8] fix baichuan-13b

---
 optimum/exporters/openvino/convert.py       | 21 ++---
 optimum/exporters/openvino/model_configs.py |  4 +-
 optimum/exporters/openvino/model_patcher.py | 94 +++++++++++++++++++--
 optimum/intel/openvino/quantization.py      |  6 +-
 4 files changed, 105 insertions(+), 20 deletions(-)

diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 55e3318017..5c90dc7b71 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -347,6 +347,7 @@ def ts_patched_forward(*args, **kwargs):
 
             with patcher:
                 check_dummy_inputs_are_allowed(model, dummy_inputs)
+                sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
                 inputs = config.ordered_inputs(model)
                 input_names = list(inputs.keys())
                 output_names = list(config.outputs.keys())
@@ -376,7 +377,6 @@ def ts_patched_forward(*args, **kwargs):
                 ov_config=ov_config,
             )
 
-        sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
         ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
         if not ordered_dummy_inputs:
             ordered_dummy_inputs = dummy_inputs
@@ -388,15 +388,16 @@ 
def ts_patched_forward(*args, **kwargs): out_tensor.get_tensor().set_names({output_names[idx]}) for idx, inp_tensor in enumerate(ov_model.inputs): - input_name = ordered_input_names[idx] - inp_tensor.get_tensor().set_names({input_name}) - inp_data = flatten_inputs[idx] - static_shape = PartialShape(inp_data.shape) - dims = inputs[input_name] - for dim in dims: - static_shape[dim] = -1 - inp_tensor.get_node().set_partial_shape(static_shape) - inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) + if idx < len(ordered_input_names): + input_name = ordered_input_names[idx] + inp_tensor.get_tensor().set_names({input_name}) + inp_data = flatten_inputs[idx] + static_shape = PartialShape(inp_data.shape) + dims = inputs.get(input_name, []) + for dim in dims: + static_shape[dim] = -1 + inp_tensor.get_node().set_partial_shape(static_shape) + inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() if stateful: diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 9ecf6bcaee..b10e744801 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -34,11 +34,11 @@ BaichuanModelPatcher, ChatGLMModelPatcher, GemmaModelPatcher, + InternLMPatcher, LlamaModelPatcher, MixtralModelPatcher, - QwenModelPatcher, MPTModelPatcher, - InternLMPatcher, + QwenModelPatcher, ) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index e598458cda..ff08c2092e 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import logging as log import math import types @@ -340,9 +341,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" @@ -613,6 +614,46 @@ def __exit__(self, exc_type, exc_value, traceback): self._model.config.fp16 = self.original_fp16 +def _baichuan13b_atten_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + if attention_mask is not None: + attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :] + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) + attn_output = attn_output.transpose(1, 2) + attn_weights = None + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class BaichuanModelPatcher(DecoderModelPatcher): def __init__( self, @@ -625,6 +666,50 @@ def __init__( if hasattr(self._model.lm_head, "first_flag"): self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64)) + def __enter__(self): + super().__enter__() + # override signature to have position_ids + if "position_ids" not in inspect.signature(self._model.forward).parameters: + self._model._orig_forward = self._model.forward + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + ): + return self._orig_forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=past_key_values is not None, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=self.config.return_dict, + ) + + self._model.forward = types.MethodType(forward, self._model) + for layer in self._model.model.layers: + layer.self_attn._orig_forward = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if hasattr(self._model, "_orig_forward"): + self._model.forward = self._model._orig_forward + + for layer in self._model.model.layers: + layer.self_attn.forward = layer.self_attn._orig_forward + def _mpt_attention_forward( self, @@ -691,8 +776,7 @@ def _internlm_attention_forward( use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv bsz, q_len, _ = hidden_states.size() diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 124b0366c1..2ebb3ed0c2 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -477,9 +477,9 @@ def _quantize_torchmodel( subset_size=quantization_config.num_samples, ignored_scope=quantization_config.get_ignored_scope_instance(), model_type=nncf.ModelType(quantization_config.model_type), - preset=nncf.QuantizationPreset.PERFORMANCE - if quantization_config.sym - else nncf.QuantizationPreset.MIXED, + preset=( + nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED + ), fast_bias_correction=quantization_config.fast_bias_correction, advanced_parameters=nncf.AdvancedQuantizationParameters( overflow_fix=OverflowFix(quantization_config.overflow_fix) From 67130593d25f6b33dbd12f4d01d5d75cdf5d39f0 Mon Sep 17 00:00:00 2001 From: eaidova Date: Tue, 23 Apr 2024 09:39:06 +0400 Subject: [PATCH 3/8] fix accuracy --- optimum/exporters/openvino/model_patcher.py | 52 +++++++++++++++++++-- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index ff08c2092e..19bdad26d3 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -734,7 +734,18 @@ def _mpt_attention_forward( else: past_key_value = (key_states, value_states) - attention_mask_sdpa = torch.ones(attention_mask.shape, dtype=query_states.dtype) + key_length = key_states.shape[-2] + query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2] + attention_mask_sdpa = torch.ones( + (query_states.shape[0], query_states.shape[1], query_states.shape[2], key_states.shape[2]), + dtype=query_states.dtype, + ) + if position_bias is not None: + position_bias_query_index = max(0, position_bias.size(1) - query_length) + position_bias_key_index = max(0, position_bias.size(2) - key_length) + + position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:] + attention_mask_sdpa += position_bias attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min) context_states = torch.nn.functional.scaled_dot_product_attention( query_states, @@ -744,6 +755,7 @@ def _mpt_attention_forward( dropout_p=self.attn_dropout_p, scale=self.softmax_scale, ) + context_states = context_states.permute(0, 2, 1, 
3).contiguous().view(batch_size, seq_length, -1) attn_output = self.out_proj(context_states) @@ -776,17 +788,47 @@ def _internlm_attention_forward( use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + # from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + from einops import rearrange + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) bsz, q_len, _ = hidden_states.size() qkv_states = self.wqkv(hidden_states) - qkv_states = qkv_states.reshape( - qkv_states.shape[0], qkv_states.shape[1], -1, 2 + self.num_key_values_groups, self.head_dim + qkv_states = rearrange( + qkv_states, + "b q (h gs d) -> b q h gs d", + gs=2 + self.num_key_value_groups, + d=self.head_dim, ) + query_states = qkv_states[..., : self.num_key_value_groups, :] - query_states = query_states.reshape(query_states.shape[0], query_states.shape[1], -1, query_states.shape[-1]) + query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") key_states = qkv_states[..., -2, :] value_states = qkv_states[..., -1, :] From 93e77a1941e9655f1a8258799891678c635ddd96 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 24 Apr 2024 19:24:06 +0400 Subject: [PATCH 4/8] small refactoring --- optimum/exporters/openvino/model_patcher.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 19bdad26d3..848fc2caa2 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -341,9 +341,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -648,9 +648,6 @@ def _baichuan13b_atten_forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value From 
5aa30ede14a56f8cbde90d2379ef9513fdb31a6a Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 24 Apr 2024 19:51:04 +0400 Subject: [PATCH 5/8] add test for baichuan 13b --- optimum/exporters/openvino/convert.py | 19 +++++++++---------- optimum/exporters/openvino/model_patcher.py | 6 +++--- tests/openvino/test_modeling.py | 13 ++++++++++++- tests/openvino/utils_tests.py | 3 ++- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 5c90dc7b71..1ab7a550f6 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -388,16 +388,15 @@ def ts_patched_forward(*args, **kwargs): out_tensor.get_tensor().set_names({output_names[idx]}) for idx, inp_tensor in enumerate(ov_model.inputs): - if idx < len(ordered_input_names): - input_name = ordered_input_names[idx] - inp_tensor.get_tensor().set_names({input_name}) - inp_data = flatten_inputs[idx] - static_shape = PartialShape(inp_data.shape) - dims = inputs.get(input_name, []) - for dim in dims: - static_shape[dim] = -1 - inp_tensor.get_node().set_partial_shape(static_shape) - inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) + input_name = ordered_input_names[idx] + inp_tensor.get_tensor().set_names({input_name}) + inp_data = flatten_inputs[idx] + static_shape = PartialShape(inp_data.shape) + dims = inputs.get(input_name, []) + for dim in dims: + static_shape[dim] = -1 + inp_tensor.get_node().set_partial_shape(static_shape) + inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() if stateful: diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 848fc2caa2..850762977b 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -341,9 +341,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index f84cac8161..06d0e85bef 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -510,6 +510,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", "baichuan2", + "baichuan2-13b", "gpt_bigcode", "blenderbot", "blenderbot-small", @@ -540,7 +541,17 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "falcon", ) GENERATION_LENGTH = 100 - REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion") + REMOTE_CODE_MODELS = ( + "chatglm", + "minicpm", + "baichuan2", + "baichuan2-13b", + "jais", + "qwen", + "internlm2", + "olmo", + "orion", + ) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index ca56f6d552..bbd30b1023 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -24,6 +24,7 @@ "bert": "hf-internal-testing/tiny-random-bert", "bart": 
"hf-internal-testing/tiny-random-bart", "baichuan2": "katuni4ka/tiny-random-baichuan2", + "baichuan2-13b": "katuni4ka/tiny-random-baichuan2-13b", "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus", "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", @@ -80,7 +81,7 @@ "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen": "katuni4ka/tiny-random-qwen", "qwen2": "Qwen/Qwen1.5-0.5B", - "resnet": "hf-internal-testing/tiny-random-resnet", + "resnet": "hf-internal-testing/tiny-rantiny-random-baichuan2-13bdom-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", "segformer": "hf-internal-testing/tiny-random-SegformerModel", From e60872cf36f2b65d26d7bbdeb8142b036515578b Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 25 Apr 2024 13:54:11 +0400 Subject: [PATCH 6/8] add support output_attentions --- optimum/exporters/openvino/model_patcher.py | 93 ++++++++++++++++++--- 1 file changed, 83 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 850762977b..877d2368f4 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -640,11 +640,25 @@ def _baichuan13b_atten_forward( attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :] key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) + if not output_attentions: + past_key_value = (key_states, value_states) if use_cache else None + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + if q_len == 1: # inference with cache + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value_states) - past_key_value = (key_states, value_states) if use_cache else None - attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask) attn_output = attn_output.transpose(1, 2) - attn_weights = None attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -708,7 +722,7 @@ def __exit__(self, exc_type, exc_value, traceback): layer.self_attn.forward = layer.self_attn._orig_forward -def _mpt_attention_forward( +def _mpt_sdpa_attention_forward( self, hidden_states: torch.Tensor, position_bias: torch.Tensor, @@ -759,18 +773,73 @@ def _mpt_attention_forward( return attn_output, None, past_key_value +def _mpt_block_forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, +): + # hidden_states: [batch_size, seq_length, hidden_size] + # Layer norm at the beginning of the transformer layer. 
+ layernorm_output = self.norm_1(hidden_states) + + residual = hidden_states + + if not output_attentions: + # Self attention. + attn_outputs, attn_weights, past_key_value = self.attn( + layernorm_output, + position_bias=position_bias, + attention_mask=attention_mask, + past_key_value=layer_past, + ) + else: + attn_outputs, attn_weights, past_key_value = self.attn._orig_forward( + layernorm_output, + position_bias=position_bias, + attention_mask=attention_mask, + past_key_value=layer_past, + ) + + hidden_states = self.resid_attn_dropout(attn_outputs) + residual + + layernorm_output = self.norm_2(hidden_states) + + # Get residual + residual = hidden_states + + # MLP. + output = self.ffn(layernorm_output, residual) + outputs = (output,) + + if use_cache: + outputs += (past_key_value,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + class MPTModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_torch_version(">=", "2.1.0"): for block in self._model.transformer.blocks: + block._orig_forward = block.forward + block.forward = types.MethodType(_mpt_block_forward, block) block.attn._orig_forward = block.attn.forward - block.attn.forward = types.MethodType(_mpt_attention_forward, block.attn) + block.attn.forward = types.MethodType(_mpt_sdpa_attention_forward, block.attn) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) for block in self._model.transformer.blocks: + if hasattr(block, "_orig_forward"): + block.forward = block._orig_forward if hasattr(block.attn, "_orig_forward"): block.attn.forward = block.attn._orig_forward @@ -848,17 +917,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + if not output_attentions: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) - ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.wo(attn_output) - attn_weights = None - return attn_output, attn_weights, past_key_value From 4f36b6f6e5489a0472fc6c9b21619bf0f83535a6 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 25 Apr 2024 19:01:19 +0400 Subject: [PATCH 7/8] code style --- optimum/exporters/openvino/model_configs.py | 10 ++++++++-- tests/openvino/test_modeling.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 677d0423e9..a12a953824 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -19,7 +19,13 @@ from transformers.utils import is_tf_available from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig -from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, 
LlamaOnnxConfig, MPTOnnxConfig, PhiOnnxConfig +from optimum.exporters.onnx.model_configs import ( + FalconOnnxConfig, + GemmaOnnxConfig, + LlamaOnnxConfig, + MPTOnnxConfig, + PhiOnnxConfig, +) from optimum.exporters.tasks import TasksManager from optimum.utils import DEFAULT_DUMMY_SHAPES from optimum.utils.input_generators import ( @@ -458,7 +464,7 @@ def patch_model_for_export( return MPTModelPatcher(self, model, model_kwargs=model_kwargs) -@register_in_tasks_manager( +@register_in_tasks_manager( "phi3", *[ "feature-extraction", diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 9f1f26869f..3bfe0a2e8f 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -553,7 +553,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "internlm2", "olmo", "orion", - "phi3" + "phi3", ) @parameterized.expand(SUPPORTED_ARCHITECTURES) From 9bde68653b9ade46bbd964c61ace2d56f8ccb99b Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:05:48 +0200 Subject: [PATCH 8/8] Update tests/openvino/utils_tests.py --- tests/openvino/utils_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 4c02df8883..9f28e40a4b 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -84,7 +84,7 @@ "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen": "katuni4ka/tiny-random-qwen", "qwen2": "Qwen/Qwen1.5-0.5B", - "resnet": "hf-internal-testing/tiny-rantiny-random-baichuan2-13bdom-resnet", + "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", "segformer": "hf-internal-testing/tiny-random-SegformerModel",