diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index 3d9c657626..9aaa942877 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -179,6 +179,8 @@ def _reorder_cache(
         """
         if self.config.model_type == "bloom":
             return self._reorder_cache_bloom(past_key_values, beam_idx)
+        elif self.config.model_type == "gpt_bigcode":
+            return self._reorder_cache_gpt_bigcode(past_key_values, beam_idx)
 
         # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
         return tuple(
@@ -186,6 +188,13 @@ def _reorder_cache(
             for layer_past in past_key_values
         )
 
+    # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
+    @staticmethod
+    def _reorder_cache_gpt_bigcode(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
+
     # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache_bloom(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
diff --git a/tests/generation/test_modeling.py b/tests/generation/test_modeling.py
index 22a9cac661..20381aa92b 100644
--- a/tests/generation/test_modeling.py
+++ b/tests/generation/test_modeling.py
@@ -58,7 +58,7 @@ class ModelingIntegrationTest(unittest.TestCase):
         "mistral",
         "llama",
         "llama2",
-        # "gpt_bigcode",
+        "gpt_bigcode",
     )
 
     GENERATION_LENGTH = 100
diff --git a/tests/ipex/test_inference.py b/tests/ipex/test_inference.py
index e120514506..b65d3c9b8e 100644
--- a/tests/ipex/test_inference.py
+++ b/tests/ipex/test_inference.py
@@ -65,7 +65,7 @@ class IPEXIntegrationTest(unittest.TestCase):
         "gptj",
         "gpt2",
         "gpt_neo",
-        # "gpt_bigcode",
+        "gpt_bigcode",
         "llama",
         "llama2",
         "opt",
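
For context, a minimal sketch of why `gpt_bigcode` needs its own reordering branch instead of falling through to the generic GPT-2 path: with multi-query attention (the default for gpt_bigcode checkpoints), each layer's past is a single fused key/value tensor rather than a `(key, value)` tuple, so beam-search reordering is one `index_select` on the batch dimension. All shapes and values below are illustrative assumptions, not taken from the patch.

```python
import torch

# Illustrative shapes only (hypothetical values, not from the diff).
batch_size, num_heads, seq_len, head_dim, num_layers = 4, 12, 7, 64, 2

# GPT-2-style cache: a (key, value) pair per layer; batch is dim 0 of each tensor.
gpt2_cache = tuple(
    (
        torch.randn(batch_size, num_heads, seq_len, head_dim),
        torch.randn(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

# gpt_bigcode-style cache (multi-query attention): one fused key/value tensor
# per layer, so there is no inner (key, value) tuple to unpack.
bigcode_cache = tuple(torch.randn(batch_size, seq_len, 2 * head_dim) for _ in range(num_layers))

# Beam search decides which batch rows survive each step.
beam_idx = torch.tensor([2, 2, 0, 1])

# Generic GPT-2 path (the existing fall-through branch): reorder key and value
# separately inside each layer's tuple.
reordered_gpt2 = tuple(
    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
    for layer_past in gpt2_cache
)

# gpt_bigcode path (what the new _reorder_cache_gpt_bigcode branch does):
# index_select directly on each layer tensor.
reordered_bigcode = tuple(
    layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in bigcode_cache
)

assert reordered_bigcode[0].shape == (batch_size, seq_len, 2 * head_dim)
assert torch.equal(reordered_bigcode[0][0], bigcode_cache[0][2])  # row 0 now holds old row 2
```

Applying the GPT-2-style nested reordering to the fused tensors would iterate over the sequence dimension instead of unpacking keys and values, which is why the dispatch in `_reorder_cache` checks `model_type` before falling through, and why the previously skipped `gpt_bigcode` test entries can be re-enabled.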