from huggingface_hub import hf_hub_download
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
+ from packaging import version
from transformers import (
    AutoConfig,
    AutoModel,
    ...
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

- from ...exporters.ipex import export_model
+ from ...exporters.ipex.model_patcher import IPEX_EXPORTED_ARCH, IPEX_EXPORTED_TASK, _patch_model
from ..generation.modeling import jit_trace, prepare_jit_inputs
from ..utils.import_utils import is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
logger = logging.getLogger(__name__)


- IPEX_EXPORTED_LIST = ("LlamaForCausalLM",)
+ IPEX_SUPPORT_MODEL_TYPES = ("llama",)


- def is_ipex_exported_model(model_name):
-     for name in IPEX_EXPORTED_LIST:
-         if model_name == name:
-             return True
-     return False
+ def is_model_support_ipex_export(model, task):
+     if isinstance(model, torch.jit.ScriptModule):
+         is_ipex_exported = model.original_name in IPEX_EXPORTED_ARCH
+     else:
+         is_ipex_exported = model.config.model_type in IPEX_SUPPORT_MODEL_TYPES and task in IPEX_EXPORTED_TASK
+
+     return is_ipex_exported
+
+
+ def ipex_jit_trace(model, task, use_cache):
+     if version.parse(ipex.__version__) <= version.parse("2.3.0") or not is_model_support_ipex_export(model, task):
+         model = patch_decoder_attention_mask(model)
+         model = ipex.optimize(model, dtype=model.dtype, level="O1", auto_kernel_selection=True)
+         return jit_trace(model, task, use_cache)

+     if is_torch_version("<", "2.1.0"):
+         raise ImportError("`torch>=2.1.0` is needed to trace your model")

- def ipex_jit_trace(model):
+     model = _patch_model(model)
    sample_inputs = get_dummy_input(model, return_dict=True)
    model.config.return_dict = False
    _enable_tpp()
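
To make the new dispatch concrete, here is a small, hedged sketch of the support check defined above. The import path, checkpoint id, and task name are placeholder assumptions for illustration only; they are not taken from this diff.

```python
# Hedged sketch: exercising is_model_support_ipex_export() from the hunk above.
# The module path, checkpoint id, and task name are placeholder assumptions.
from transformers import AutoConfig, AutoModelForCausalLM

from optimum.intel.ipex.modeling_base import is_model_support_ipex_export  # assumed module path

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model = AutoModelForCausalLM.from_config(config)

# Truthy only when both the model type ("llama") and the task are in the
# exporter's supported lists; otherwise ipex_jit_trace() falls back to
# ipex.optimize() + jit_trace().
print(is_model_support_ipex_export(model, task="text-generation"))
```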
@@ -104,7 +116,7 @@ def __init__(
        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
        self.model.to(self._device)
        self.model_save_dir = model_save_dir
-         self.is_ipex_exported = kwargs.get("is_ipex_exported", None)
+         self.is_ipex_exported = is_model_support_ipex_export(model, self.export_feature)

        self.input_names = {
            inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
@@ -148,14 +160,7 @@ def _from_transformers(
        }

        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-         is_ipex_exported = is_ipex_exported_model(model.__class__.__name__)
-         if is_ipex_exported:
-             model = export_model(model)
-             traced_model = ipex_jit_trace(model)
-         else:
-             model = patch_decoder_attention_mask(model)
-             model = ipex.optimize(model, dtype=torch_dtype, level="O1", auto_kernel_selection=True)
-             traced_model = jit_trace(model, task, use_cache)
+         traced_model = ipex_jit_trace(model, task, use_cache)

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)
@@ -173,7 +178,6 @@ def _from_transformers(
            local_files_only=local_files_only,
            use_cache=use_cache,
            model_dtype=torch_dtype,
-             is_ipex_exported=is_ipex_exported,
        )

    @classmethod
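
As a hedged end-to-end sketch of the loading path the hunks above simplify (the checkpoint id is a placeholder, and routing `export=True` through `_from_transformers` follows optimum's usual loading convention):

```python
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder: any llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True goes through _from_transformers, which now always delegates the
# tracing decision to ipex_jit_trace(model, task, use_cache).
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```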
@@ -210,8 +214,6 @@ def _from_pretrained(

        model = torch.jit.load(model_cache_path)
        torch.jit.freeze(model.eval())
-         is_ipex_exported = is_ipex_exported_model(model.original_name)
-         kwargs["is_ipex_exported"] = is_ipex_exported

        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
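
For completeness, a hedged sketch of the reload path this hunk touches: once exported and saved, the TorchScript model can be loaded back without re-export, and the support flag is now recomputed from the loaded module rather than passed through kwargs. The save directory is a placeholder, and `model` refers to the export sketch above.

```python
from optimum.intel import IPEXModelForCausalLM

save_dir = "./ipex_llama_exported"  # placeholder path
model.save_pretrained(save_dir)     # `model` from the export sketch above

# Reloading goes through _from_pretrained -> torch.jit.load; __init__ then derives
# is_ipex_exported from the ScriptModule's original_name via is_model_support_ipex_export.
reloaded = IPEXModelForCausalLM.from_pretrained(save_dir)
```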
@@ -372,7 +374,6 @@ def __init__(
        model_type = config.model_type.replace("_", "-")
        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
        self.use_cache = "past_key_values" in self.input_names
-         self.is_ipex_exported = kwargs.get("is_ipex_exported", None)

        if use_cache ^ self.use_cache:
            raise ValueError(
@@ -422,7 +423,11 @@ def _prepare_past_key_values(self, input_ids):
        num_attention_heads = self.normalized_config.num_attention_heads

        if self.is_ipex_exported:
-             beam_idx_tmp = torch.zeros((2048, input_ids.shape[0]), dtype=torch.long).contiguous()
+             # The indirect-access KV cache has a different data layout than most transformers models,
+             # see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache
+             beam_idx_tmp = torch.zeros(
+                 (self.config.max_position_embeddings, input_ids.shape[0]), dtype=torch.long
+             ).contiguous()
            past_key_values = tuple(
                [
                    (
@@ -562,8 +567,8 @@ def _prepare_inputs_for_generation_for_llama(
def _ipex_reorder_cache(
    past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
-
-     if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:  # discrete kv_cache
+     # The IPEX-patched model uses an indirect-access KV cache, whose shape differs from that of other transformers models
+     if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:
        for layer_past in past_key_values:
            layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
        return past_key_values
@@ -577,8 +582,3 @@ def _ipex_reorder_cache(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
-
-     return tuple(
-         tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-         for layer_past in past_key_values
-     )
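
For context on the two branches above, a hedged illustration only; the tensor shapes are assumptions and are not taken from IPEX internals. An indirect-access KV cache entry is a 4-tuple whose first tensor has a trailing dimension of 1, which is what `_ipex_reorder_cache` keys on before rewriting the beam-index buffer in place; standard caches take the `index_select` path instead.

```python
import torch

# Assumed, illustrative shapes for one layer of an indirect-access KV cache entry:
max_positions, batch_size = 2048, 1
dummy_iakv_layer = (
    torch.zeros(1, 0, 0, 1),                                   # sequence-tracking tensor, trailing dim == 1
    torch.zeros(1, 1, 1, 1),                                   # key cache placeholder
    torch.zeros(1, 1, 1, 1),                                   # value cache placeholder
    torch.zeros(max_positions, batch_size, dtype=torch.long),  # beam-index buffer (see _prepare_past_key_values)
)
past_key_values = (dummy_iakv_layer,)

# Mirrors the check in _ipex_reorder_cache: a 4-tuple whose first tensor ends in a
# dimension of size 1 marks the IPEX layout, so only the beam-index buffer is updated.
is_iakv = len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1
print(is_iakv)  # True
```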