3 files changed, +11 −2 lines changed

@@ -179,13 +179,22 @@ def _reorder_cache(
         """
         if self.config.model_type == "bloom":
             return self._reorder_cache_bloom(past_key_values, beam_idx)
+        elif self.config.model_type == "gpt_bigcode":
+            return self._reorder_cache_gpt_bigcode(past_key_values, beam_idx)
 
         # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
         return tuple(
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past_key_values
         )
 
+    # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
+    @staticmethod
+    def _reorder_cache_gpt_bigcode(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
+
     # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache_bloom(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
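The new branch is needed because GPT BigCode's KV cache has a different layout from the GPT-2-style cache handled by the generic path: with multi-query attention, each layer stores a single fused key/value tensor rather than a (key, value) pair, so the generic path's inner loop over `layer_past` does not apply and each layer tensor is indexed directly. A minimal sketch of the two layouts (the tensor shapes are illustrative assumptions, not the models' real dimensions):

import torch

# Illustrative shapes only (assumed): batch=4, heads=2, seq=3, head_dim=8.
beam_idx = torch.tensor([2, 0, 1, 3])  # permutation chosen by beam search

# GPT-2-style cache: one (key, value) pair per layer, each [batch, heads, seq, head_dim].
gpt2_cache = tuple((torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8)) for _ in range(2))
reordered_gpt2 = tuple(
    tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
    for layer_past in gpt2_cache
)

# GPT BigCode cache: multi-query attention fuses key and value into a single
# tensor per layer (assumed here as [batch, seq, 2 * head_dim]), so each
# layer_past is indexed directly instead of being unpacked into key and value.
bigcode_cache = tuple(torch.randn(4, 3, 16) for _ in range(2))
reordered_bigcode = tuple(layer_past.index_select(0, beam_idx) for layer_past in bigcode_cache)

# Row 0 of the reordered cache is row 2 of the original, as beam_idx dictates.
assert torch.equal(reordered_bigcode[0][0], bigcode_cache[0][2])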
@@ -58,7 +58,7 @@ class ModelingIntegrationTest(unittest.TestCase):
         "mistral",
         "llama",
         "llama2",
-        # "gpt_bigcode",
+        "gpt_bigcode",
     )
 
     GENERATION_LENGTH = 100

@@ -65,7 +65,7 @@ class IPEXIntegrationTest(unittest.TestCase):
         "gptj",
         "gpt2",
         "gpt_neo",
-        # "gpt_bigcode",
+        "gpt_bigcode",
         "llama",
         "llama2",
         "opt",