@@ -65,7 +65,7 @@ def _is_patched_with_ipex(model, task):
65
65
if isinstance (model , torch .jit .ScriptModule ):
66
66
for node in model .graph .nodes ():
67
67
# Jit will record the codes position so we can check if the node use ipex exporter.
68
- if "optimum/exporters/ipex/modeling_utils.py " in node .__str__ ():
68
+ if "torch_ipex::rotary_position_embedding " in node .__str__ ():
69
69
return True
70
70
return False
71
71
else :
@@ -123,7 +123,7 @@ def __init__(
123
123
self ._dtype = self .config .torch_dtype if self .config .torch_dtype is not None else torch .float32
124
124
self .model .to (self ._device )
125
125
self .model_save_dir = model_save_dir
126
- self .is_ipex_exported = _is_patched_with_ipex (model , self .export_feature )
126
+ self ._is_ipex_exported = _is_patched_with_ipex (model , self .export_feature )
127
127
128
128
self .input_names = {
129
129
inputs .debugName ().split ("." )[0 ] for inputs in model .graph .inputs () if inputs .debugName () != "self"
@@ -285,7 +285,7 @@ def _init_warmup(self):
285
285
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
286
286
# the results of the compute are unpredictable
287
287
# TODO : add warmup for IPEX exported model
288
- if not self .is_ipex_exported :
288
+ if not self ._is_ipex_exported :
289
289
use_cache = "past_key_values" in self .input_names
290
290
dummy_inputs = prepare_jit_inputs (self , self .export_feature , use_cache )
291
291
for _ in range (2 ):
@@ -409,7 +409,7 @@ def __init__(
409
409
except AttributeError :
410
410
self .model_cls = get_model_class (self .config , AutoModelForCausalLM ._model_mapping )
411
411
412
- if self .is_ipex_exported :
412
+ if self ._is_ipex_exported :
413
413
self ._reorder_cache = _ipex_reorder_cache
414
414
else :
415
415
# Check if _reorder_cache is a static method
@@ -442,7 +442,7 @@ def _prepare_past_key_values(self, input_ids):
442
442
else :
443
443
num_attention_heads = self .normalized_config .num_attention_heads
444
444
445
- if self .is_ipex_exported :
445
+ if self ._is_ipex_exported :
446
446
# Indirect access kv cache has a different data layout compared with most transformers model,
447
447
# see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache
448
448
beam_idx_tmp = torch .zeros (
0 commit comments