Commit e596cc7

enable attention mask and fix accuracy issue for chatglm
1 parent fae7802 commit e596cc7

File tree: 4 files changed (+83, -6 lines)

optimum/exporters/openvino/dummy_input_generators.py (+12)
@@ -14,6 +14,8 @@
 
 from typing import Optional, Tuple
 
+import torch
+
 from optimum.utils import (
     DEFAULT_DUMMY_SHAPES,
     DummyPastKeyValuesGenerator,
@@ -30,6 +32,16 @@ class ChatGLN2DummyTextInputGenerator(DummyTextInputGenerator):
         "position_ids",
     }
 
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        input = super().generate(input_name, framework, int_dtype, float_dtype)
+        if input_name == "attention_mask":
+            input = torch.ones((input.shape[0], input.shape[1] + 1), dtype=input.dtype)
+            # input[0] = 0
+        if input_name == "position_ids":
+            input = torch.range(0, input.shape[1] + 1, dtype=input.dtype).repeat(1, 1)
+            # input[0] = 0
+        return input
+
 
 class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
     def __init__(
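
With attention_mask kept as an export input (see the model_configs.py change below), the generate() override above produces a dummy attention_mask and position_ids that are one step longer than the dummy input_ids, mirroring a decoding step where the mask already covers the token being generated. A minimal, self-contained sketch of the resulting shapes, using hypothetical dummy dimensions:

import torch

batch_size, sequence_length = 2, 16  # hypothetical dummy shapes

# attention_mask: all ones, one position longer than the dummy sequence
attention_mask = torch.ones((batch_size, sequence_length + 1), dtype=torch.int64)

# position_ids: torch.range is end-inclusive; repeat(1, 1) adds a leading batch axis
# (the real override passes the dummy tensor's own dtype)
position_ids = torch.range(0, sequence_length + 1).repeat(1, 1)

print(attention_mask.shape)  # torch.Size([2, 17])
print(position_ids.shape)    # torch.Size([1, 18])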

optimum/exporters/openvino/model_configs.py (-1)
@@ -59,7 +59,6 @@ class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = super().inputs
-        common_inputs.pop("attention_mask")
         if not self.no_position_ids and self.task == "text-generation":
             common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
 
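
With the pop() removed, attention_mask stays in the exported ChatGLM2 model's input signature alongside input_ids and position_ids. Roughly, and omitting the past_key_values entries, the inputs property now resolves to something like the following (the dynamic-axis names here are assumed from the usual text-decoder convention, not taken from this diff):

expected_inputs = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
    "position_ids": {0: "batch_size", 1: "sequence_length"},
}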

optimum/intel/openvino/modeling_decoder.py (+33, -5)
@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -25,7 +25,7 @@
 from openvino.runtime import Core, Tensor, Type
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
 from optimum.utils import NormalizedConfigManager
 
@@ -401,9 +401,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-
+        if past_key_values:
+            position_ids = position_ids[:, -1].unsqueeze(-1)
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
@@ -413,6 +412,35 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "token_type_ids": None,
         }
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+
+        # update attention mask
+        if "attention_mask" in model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+
+        # update position ids
+        if "position_ids" in model_kwargs:
+            position_ids = model_kwargs["position_ids"]
+            new_position_id = position_ids[..., -1:].clone()
+            new_position_id += 1
+            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
+
+        model_kwargs["is_first_forward"] = False
+        return model_kwargs
+
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
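
The new _update_model_kwargs_for_generation keeps attention_mask and position_ids in sync with the growing KV-cache: each generation step appends a 1 to the mask and the next position index to position_ids. A standalone sketch of one such step (tensor values are hypothetical):

import torch

model_kwargs = {
    "attention_mask": torch.ones((1, 4), dtype=torch.int64),
    "position_ids": torch.arange(4).unsqueeze(0),
}

# extend the attention mask by one position
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

# append the next position index
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone() + 1
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)

print(model_kwargs["attention_mask"])  # tensor([[1, 1, 1, 1, 1]])
print(model_kwargs["position_ids"])    # tensor([[0, 1, 2, 3, 4]])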

optimum/intel/utils/modeling_utils.py (+38)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import types
 from typing import Tuple
 
 import torch
@@ -92,6 +93,40 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
     return combined_attention_mask
 
 
+@torch.jit.script_if_tracing
+def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
+    mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
+    if query_layer.shape[2] == key_layer.shape[2]:
+        tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
+        mask.masked_fill_(tmp_mask, float("-inf"))
+
+    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=mask)
+    return context_layer
+
+
+def _core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+    query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+    if attention_mask is None:
+        context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
+    else:
+        attention_mask = ~attention_mask
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer, key_layer, value_layer, attention_mask
+        )
+    context_layer = context_layer.permute(2, 0, 1, 3)
+    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer
+
+
+def _patch_chatglm_core_attention_forward(model: "PreTrainedModel"):
+    for block in model.transformer.encoder.layers:
+        block.self_attention.core_attention.forward = types.MethodType(
+            _core_attention_forward, block.self_attention.core_attention
+        )
+
+
 def patch_decoder_attention_mask(model: "PreTrainedModel"):
     """
     Apply patch on decoder with past model forward to resolve first inference based on model architecture
@@ -108,4 +143,7 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
         model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
     elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
         model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+    elif model.config.model_type == "chatglm":
+        _patch_chatglm_core_attention_forward(model)
+
     return model
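
For tracing, _chatglm2_get_context_layer builds an additive causal mask and feeds it to scaled_dot_product_attention; the mask is only filled when query and key lengths match (the prefill step). When an explicit boolean attention_mask is given, the patched forward inverts it first, since SDPA's boolean-mask convention (True means "attend") is the opposite of the mask the model supplies. A small sketch of the additive mask for a hypothetical prefill length of 4:

import torch

q_len = k_len = 4  # hypothetical prefill length
mask = torch.zeros((q_len, k_len), dtype=torch.float32)
tmp_mask = torch.ones((q_len, k_len), dtype=torch.bool).triu(diagonal=1)
mask.masked_fill_(tmp_mask, float("-inf"))

print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])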

0 commit comments