# See the License for the specific language governing permissions and
# limitations under the License.
+import types
from typing import Tuple
import torch
@@ -92,6 +93,40 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
    return combined_attention_mask


+@torch.jit.script_if_tracing
+def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
+    mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
+    if query_layer.shape[2] == key_layer.shape[2]:
+        tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
+        mask.masked_fill_(tmp_mask, float("-inf"))
+
+    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=mask)
+    return context_layer
+
+
+def _core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+    query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+    if attention_mask is None:
+        context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
+    else:
+        attention_mask = ~attention_mask
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer, key_layer, value_layer, attention_mask
+        )
+    context_layer = context_layer.permute(2, 0, 1, 3)
+    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer
+
+
+def _patch_chatglm_core_attention_forward(model: "PreTrainedModel"):
+    for block in model.transformer.encoder.layers:
+        block.self_attention.core_attention.forward = types.MethodType(
+            _core_attention_forward, block.self_attention.core_attention
+        )
+
+
def patch_decoder_attention_mask(model: "PreTrainedModel"):
    """
    Apply patch on decoder with past model forward to resolve first inference based on model architecture
@@ -108,4 +143,7 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
        model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
        model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+    elif model.config.model_type == "chatglm":
+        _patch_chatglm_core_attention_forward(model)
+
    return model
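
For context, a minimal usage sketch (illustration only, not part of the diff): with the patched helper above, a ChatGLM2 checkpoint could be loaded through transformers and patched before tracing or export, so that each block's core_attention.forward is replaced by the SDPA-based implementation added in this change. The model id and loading flags below are assumptions for the example.

    # Illustrative sketch only: model id and flags are assumptions, not taken from the diff.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = patch_decoder_attention_mask(model)  # swaps in the SDPA-based core attention for chatglm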