31 | 31 |     IPEXModelForMaskedLM,
32 | 32 |     IPEXModelForSequenceClassification,
33 | 33 |     IPEXModelForTokenClassification,
34 |    | -    IPEXBloomForCausalLM,
35 |    | -    IPEXMPTForCausalLM,
36 |    | -    IPEXOPTForCausalLM,
37 |    | -    IPEXGPTBigCodeForCausalLM,
38 | 34 |     IPEXModelForQuestionAnswering,
39 | 35 | )
40 | 36 |
41 | 37 |
42 | 38 | from .utils import _HEAD_TO_AUTOMODELS
43 | 39 |
44 | 40 |
45 |    | -_MODEL_TYPE_TO_AUTOMODELS = {
46 |    | -    "bloom": IPEXBloomForCausalLM,
47 |    | -    "mpt": IPEXMPTForCausalLM,
48 |    | -    "opt": IPEXOPTForCausalLM,
49 |    | -    "big_code": IPEXGPTBigCodeForCausalLM,
50 |    | -}
51 |    | -
52 |    | -
53 | 41 | logger = logging.getLogger(__name__)
54 | 42 |
55 | 43 | IPEX_NOT_AVAILABLE_ERROR_MSG = (
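A note on the removed dispatch: as the deleted lines in the next hunk show, `model_type` had its underscores replaced with hyphens before the membership test, so `"big_code"` could never match its key, and when a type such as `"bloom"` did match, the dict was indexed with `task` rather than `model_type`. A small self-contained illustration of both issues (string values stand in for the deleted IPEX*ForCausalLM classes; this snippet is not code from the repository):

# Stand-in for the removed mapping; strings replace the deleted
# classes so this snippet runs on its own.
_MODEL_TYPE_TO_AUTOMODELS = {
    "bloom": "IPEXBloomForCausalLM",
    "mpt": "IPEXMPTForCausalLM",
    "opt": "IPEXOPTForCausalLM",
    "big_code": "IPEXGPTBigCodeForCausalLM",
}

task = "text-generation"

# The removed code hyphenated the model type before the membership test:
model_type = "big_code".replace("_", "-")
print(model_type in _MODEL_TYPE_TO_AUTOMODELS)  # False: "big-code" is not a key

# And when a type like "bloom" did pass the test, the lookup used `task`:
print(_MODEL_TYPE_TO_AUTOMODELS.get(task))  # None: "text-generation" is not a key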
@@ -146,13 +134,7 @@ def __enter__(self):

146 | 134 |         )
147 | 135 |         if task in _HEAD_TO_AUTOMODELS:
148 | 136 |             model = jit_trace(model, task, use_cache)
149 |     | -            model_type = getattr(self._original.config, "model_type", "").replace("_", "-")
150 |     | -
151 |     | -            if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys():
152 |     | -                auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[task]
153 |     | -            else:
154 |     | -                auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
155 |     | -
    | 137 | +            auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
156 | 138 |             model = auto_model_class(model, self._original.config, use_cache=use_cache)
157 | 139 |
158 | 140 |         # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
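After the change, wrapper selection depends on the task alone: `_HEAD_TO_AUTOMODELS` (imported from `.utils`) maps task names to class names stored as strings, and `eval` resolves the name in the module namespace before the traced model is wrapped. A minimal runnable sketch of that dispatch under those assumptions, with a stub class and a one-entry mapping standing in for the real ones:

# Stub standing in for the real IPEX wrapper class.
class IPEXModelForQuestionAnswering:
    def __init__(self, model, config, use_cache=True):
        self.model, self.config, self.use_cache = model, config, use_cache

# One illustrative entry; the real mapping covers every supported task.
_HEAD_TO_AUTOMODELS = {"question-answering": "IPEXModelForQuestionAnswering"}

def wrap_traced(model, task, config, use_cache=True):
    # Resolve the wrapper class from the task alone; eval turns the stored
    # class name back into the class object defined in this module.
    if task not in _HEAD_TO_AUTOMODELS:
        return model  # leave unsupported tasks untouched
    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
    return auto_model_class(model, config, use_cache=use_cache)

wrapped = wrap_traced(model=object(), task="question-answering", config={})
assert isinstance(wrapped, IPEXModelForQuestionAnswering)

Keeping the mapping values as strings means `.utils` only needs class names, not imports of the model classes themselves, which is presumably why the name is resolved with `eval` at the call site rather than stored as a direct class reference.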