
Commit 3a86c40

fix raise import error
1 parent 809542c commit 3a86c40


3 files changed: +13 -12 lines changed

optimum/exporters/ipex/model_patcher.py

+7 -6

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import intel_extension_for_pytorch as ipex
-from packaging import version
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -22,6 +20,8 @@
     LlamaRMSNorm,
 )
 
+from optimum.intel.utils.import_utils import is_ipex_version
+
 from .modeling_utils import (
     _IPEXLlamaDecoderLayerRef,
     llama_attn_forward,
@@ -30,10 +30,6 @@
 )
 
 
-if version.parse(ipex.__version__) > version.parse("2.3.0"):
-    from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding, IndirectAccessKVCache
-
-
 _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
 _IPEX_EXPORTED_TASK = ("text-generation",)
 
@@ -66,6 +62,11 @@ def patch_op(m, target_m, new_op_name, new_op):
 
 
 def _patch_llama_model(model):
+    if is_ipex_version("<=", "2.3.0"):
+        raise ImportError("Only ipex version > 2.3.0 supports ApplyRotaryEmbedding and IndirectAccessKVCache")
+
+    from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding, IndirectAccessKVCache
+
     ipex_rope = ApplyRotaryEmbedding(
         model.config.max_position_embeddings,
         model.config.hidden_size // model.config.num_attention_heads,
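
The net effect of this file's change is that the IPEX-only symbols are no longer imported at module import time; the version gate moves into _patch_llama_model, which raises an ImportError before attempting the import. Below is a minimal, self-contained sketch of that guarded lazy-import pattern; it uses importlib.metadata and packaging inline instead of the repository's is_ipex_version helper, and patch_model / _require_min_version are illustrative stand-ins, not functions from this codebase.

import importlib.metadata

from packaging import version


def _require_min_version(package: str, minimum: str) -> None:
    # Raise ImportError when the optional package is missing or too old,
    # mirroring the check added to _patch_llama_model above.
    try:
        installed = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError as exc:
        raise ImportError(f"{package} is required but not installed") from exc
    if version.parse(installed) <= version.parse(minimum):
        raise ImportError(f"Only {package} > {minimum} supports this feature, found {installed}")


def patch_model(model):
    # The optional dependency is only touched when patching is actually requested,
    # so importing this module never fails on machines without the package.
    _require_min_version("intel_extension_for_pytorch", "2.3.0")
    from intel_extension_for_pytorch.llm.modules import ApplyRotaryEmbedding  # noqa: F401

    # ... build the patched modules with ApplyRotaryEmbedding here ...
    return model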

optimum/exporters/ipex/modeling_utils.py

+1 -1

@@ -230,7 +230,7 @@ def llama_model_forward(
 class _IPEXLlamaDecoderLayerRef(nn.Module):
     def __init__(self, module, config, distributed=False):
         if is_ipex_version("<=", "2.3.0"):
-            raise ValueError("Only ipex version > 2.3.0 supports linear2SiluMul and linearAdd")
+            raise ImportError("Only ipex version > 2.3.0 supports linear2SiluMul and linearAdd")
 
         from intel_extension_for_pytorch.llm.modules import linear2SiluMul, linearAdd
 
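
The only change here is the exception type: the version guard in _IPEXLlamaDecoderLayerRef.__init__ now raises ImportError instead of ValueError, matching the new check in _patch_llama_model. One practical consequence, sketched below with a hypothetical try_patch helper that is not part of the repository, is that callers can treat a missing or outdated ipex uniformly as an optional-dependency failure.

def try_patch(model, patch_fn):
    # patch_fn is any IPEX patch helper; fall back to the unpatched model
    # when the optional dependency is missing or too old.
    try:
        return patch_fn(model), True
    except ImportError as exc:
        print(f"IPEX patching skipped: {exc}")
        return model, False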

optimum/intel/ipex/modeling_base.py

+5 -5

@@ -65,7 +65,7 @@ def _is_patched_with_ipex(model, task):
     if isinstance(model, torch.jit.ScriptModule):
         for node in model.graph.nodes():
             # Jit will record the codes position so we can check if the node use ipex exporter.
-            if "optimum/exporters/ipex/modeling_utils.py" in node.__str__():
+            if "torch_ipex::rotary_position_embedding" in node.__str__():
                 return True
         return False
     else:
@@ -123,7 +123,7 @@ def __init__(
         self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model.to(self._device)
         self.model_save_dir = model_save_dir
-        self.is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
+        self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
 
         self.input_names = {
             inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
@@ -285,7 +285,7 @@ def _init_warmup(self):
         # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
         # the results of the compute are unpredictable
         # TODO : add warmup for IPEX exported model
-        if not self.is_ipex_exported:
+        if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
             dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
             for _ in range(2):
@@ -409,7 +409,7 @@ def __init__(
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
 
-        if self.is_ipex_exported:
+        if self._is_ipex_exported:
             self._reorder_cache = _ipex_reorder_cache
         else:
             # Check if _reorder_cache is a static method
@@ -442,7 +442,7 @@ def _prepare_past_key_values(self, input_ids):
         else:
             num_attention_heads = self.normalized_config.num_attention_heads
 
-        if self.is_ipex_exported:
+        if self._is_ipex_exported:
             # Indirect access kv cache has a different data layout compared with most transformers model,
             # see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache
             beam_idx_tmp = torch.zeros(
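
Detection of IPEX-exported models now keys off an operator name (torch_ipex::rotary_position_embedding) in the TorchScript graph rather than a source-file path recorded by JIT, and the flag is stored as the private attribute _is_ipex_exported. The sketch below shows the underlying graph-scanning mechanism with a stock aten op so it runs without IPEX installed; graph_contains_op and Tiny are illustrative stand-ins, not code from the repository.

import torch


def graph_contains_op(scripted: torch.jit.ScriptModule, op_substring: str) -> bool:
    # Each graph node's string form includes the op's qualified name, e.g. "aten::relu",
    # which is what the patched _is_patched_with_ipex check relies on.
    return any(op_substring in str(node) for node in scripted.graph.nodes())


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


scripted = torch.jit.script(Tiny())
print(graph_contains_op(scripted, "aten::relu"))                              # True
print(graph_contains_op(scripted, "torch_ipex::rotary_position_embedding"))   # False without IPEX patching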
