Skip to content

Commit 1b5c3cb

Browse files
authored
IPEX decoder model fix (#539)
1 parent 20df723 commit 1b5c3cb

File tree

3 files changed

+6
-11
lines changed

3 files changed

+6
-11
lines changed

optimum/intel/ipex/inference.py

+2 -9
Original file line number | Diff line number | Diff line change
@@ -36,20 +36,13 @@
3636
IPEXOPTForCausalLM,
3737
IPEXGPTBigCodeForCausalLM,
3838
IPEXModelForQuestionAnswering,
39+
_MODEL_TYPE_TO_AUTOMODELS,
3940
)
4041

4142

4243
from .utils import _HEAD_TO_AUTOMODELS
4344

4445

45-
_MODEL_TYPE_TO_AUTOMODELS = {
46-
"bloom": IPEXBloomForCausalLM,
47-
"mpt": IPEXMPTForCausalLM,
48-
"opt": IPEXOPTForCausalLM,
49-
"big_code": IPEXGPTBigCodeForCausalLM,
50-
}
51-
52-
5346
logger = logging.getLogger(__name__)
5447

5548
IPEX_NOT_AVAILABLE_ERROR_MSG = (
@@ -149,7 +142,7 @@ def __enter__(self):
149142
model_type = getattr(self._original.config, "model_type", "").replace("_", "-")
150143

151144
if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys():
152-
auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[task]
145+
auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[model_type]
153146
else:
154147
auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
155148

optimum/intel/ipex/modeling_base.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -600,5 +600,5 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
600600
"bloom": IPEXBloomForCausalLM,
601601
"mpt": IPEXMPTForCausalLM,
602602
"opt": IPEXOPTForCausalLM,
603-
"big-code": IPEXGPTBigCodeForCausalLM,
603+
"gpt-bigcode": IPEXGPTBigCodeForCausalLM,
604604
}

tests/ipex/test_inference.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,8 @@
4242
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
4343
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
4444
"llama": "fxmarty/tiny-llama-fast-tokenizer",
45+
"opt": "hf-internal-testing/tiny-random-OPTModel",
46+
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
4547
}
4648

4749
_CLASSIFICATION_TASK_TO_AUTOMODELS = {
@@ -57,7 +59,7 @@ class IPEXIntegrationTest(unittest.TestCase):
5759
"roberta",
5860
)
5961

60-
TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("gptj", "gpt2", "gpt_neo", "gpt_bigcode", "llama")
62+
TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("gptj", "gpt2", "gpt_neo", "gpt_bigcode", "llama", "opt", "mpt")
6163

6264
QA_SUPPORTED_ARCHITECTURES = (
6365
"bert",

0 commit comments

Comments
 (0)