3 files changed +11 -2 lines changed

@@ -180,13 +180,22 @@ def _reorder_cache(
         """
         if self.config.model_type == "bloom":
             return self._reorder_cache_bloom(past_key_values, beam_idx)
+        elif self.config.model_type == "gpt_bigcode":
+            return self._reorder_cache_gpt_bigcode(past_key_values, beam_idx)

         # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past_key_values
         )

+    # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
+    @staticmethod
+    def _reorder_cache_gpt_bigcode(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
+
     # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache_bloom(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
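For readers wondering why gpt_bigcode needs its own branch: GPTBigCode uses multi-query attention and stores one fused key/value tensor per layer, so the GPT-2-style reorder, which iterates one level deeper into (key, value) pairs, would not apply to it. A minimal standalone sketch of the two cache layouts (all tensor shapes are invented for illustration):

import torch

# Made-up sizes for illustration only: 4 = batch * num_beams, 2 layers,
# 12 heads, sequence length 7, head dim 64.
gpt2_style_cache = tuple(
    (torch.randn(4, 12, 7, 64), torch.randn(4, 12, 7, 64)) for _ in range(2)
)
# GPTBigCode-style cache: one fused key/value tensor per layer, so each
# element of past_key_values is a tensor, not a (key, value) tuple.
bigcode_style_cache = tuple(torch.randn(4, 7, 128) for _ in range(2))

beam_idx = torch.tensor([0, 0, 2, 3])  # e.g. two beams both descend from beam 0

# GPT-2 path (the default branch above): iterate into each (key, value) pair.
reordered = tuple(
    tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
    for layer_past in gpt2_style_cache
)

# GPTBigCode path (_reorder_cache_gpt_bigcode): each layer_past is already
# a tensor, so select along the beam dimension on it directly.
reordered_bigcode = tuple(
    layer_past.index_select(0, beam_idx) for layer_past in bigcode_style_cache
)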
@@ -58,7 +58,7 @@ class ModelingIntegrationTest(unittest.TestCase):
         "mistral",
         "llama",
         "llama2",
-        # "gpt_bigcode",
+        "gpt_bigcode",
     )

     GENERATION_LENGTH = 100
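How this tuple is consumed is not shown in the diff; a hedged sketch, assuming the parameterized-test pattern common in optimum test suites (the test name and body here are hypothetical):

import unittest
from parameterized import parameterized

SUPPORTED_ARCHITECTURES = ("mistral", "llama", "llama2", "gpt_bigcode")

class ModelingIntegrationTest(unittest.TestCase):
    # One test case is generated per architecture name, so un-commenting
    # "gpt_bigcode" above is what re-enables the StarCoder-family tests.
    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_beam_search(self, model_arch):
        self.assertIn(model_arch, SUPPORTED_ARCHITECTURES)  # placeholder body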
@@ -65,7 +65,7 @@ class IPEXIntegrationTest(unittest.TestCase):
         "gptj",
         "gpt2",
         "gpt_neo",
-        # "gpt_bigcode",
+        "gpt_bigcode",
         "llama",
         "llama2",
         "opt",