Skip to content

Commit 64a7f6b

Browse files
fix and enable gpt_bigcode
1 parent a06522c commit 64a7f6b

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

optimum/intel/generation/modeling.py

+9
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,22 @@ def _reorder_cache(
179179
"""
180180
if self.config.model_type == "bloom":
181181
return self._reorder_cache_bloom(past_key_values, beam_idx)
182+
elif self.config.model_type == "gpt_bigcode":
183+
return self._reorder_cache_gpt_bigcode(past_key_values, beam_idx)
182184

183185
# from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
184186
return tuple(
185187
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
186188
for layer_past in past_key_values
187189
)
188190

191+
# Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
192+
@staticmethod
193+
def _reorder_cache_gpt_bigcode(
194+
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
195+
) -> Tuple[Tuple[torch.Tensor]]:
196+
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
197+
189198
# Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
190199
def _reorder_cache_bloom(
191200
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor

tests/generation/test_modeling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class ModelingIntegrationTest(unittest.TestCase):
5858
"mistral",
5959
"llama",
6060
"llama2",
61-
# "gpt_bigcode",
61+
"gpt_bigcode",
6262
)
6363

6464
GENERATION_LENGTH = 100

tests/ipex/test_inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class IPEXIntegrationTest(unittest.TestCase):
6565
"gptj",
6666
"gpt2",
6767
"gpt_neo",
68-
# "gpt_bigcode",
68+
"gpt_bigcode",
6969
"llama",
7070
"llama2",
7171
"opt",

0 commit comments

Comments
 (0)