
Commit e05557a

fix jit model

Committed Feb 18, 2024 · 1 parent 8f7d016

1 file changed: +35 −3 lines

optimum/intel/ipex/modeling_base.py (+35 −3)
```diff
@@ -90,13 +90,13 @@ def _from_transformers(
         cls,
         model_id: str,
         config: PretrainedConfig,
+        use_cache: bool = True,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         subfolder: str = "",
         local_files_only: bool = False,
-        use_cache: bool = True,
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
         trust_remote_code: bool = False,
     ):
@@ -134,6 +134,7 @@ def _from_transformers(
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             use_cache=use_cache,
+            model_dtype=torch_dtype,
         )
 
     @classmethod
```
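With this change, the `torch_dtype` given to `from_pretrained` is forwarded into the model constructor as `model_dtype` instead of being dropped. A minimal sketch of the load path this patch affects (assuming optimum-intel's `IPEXModelForCausalLM` entry point; the checkpoint name and flags below are illustrative, not from the commit):

```python
# Minimal sketch: load a transformers checkpoint through the patched
# _from_transformers path (reached via export=True).
import torch
from optimum.intel import IPEXModelForCausalLM

model = IPEXModelForCausalLM.from_pretrained(
    "gpt2",                      # illustrative checkpoint
    export=True,                 # trace/convert the model with IPEX
    use_cache=True,              # keyword moved earlier in the patched signature
    torch_dtype=torch.bfloat16,  # now forwarded to __init__ as model_dtype
)
```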
```diff
@@ -325,9 +326,11 @@ def __init__(
     ):
         # Perform the initial warmup at the end of __init__
         super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
+        GenerationMixin.__init__(self)
 
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", self.dtype)
+        self._dtype = self.model_dtype
         self.use_cache = "past_key_values" in self.input_names
 
         if use_cache ^ self.use_cache:
```
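The `^` (XOR) in the context line above is true exactly when the `use_cache` value requested at load time disagrees with what the traced graph actually exposes. A standalone illustration of the idiom:

```python
# XOR flags a mismatch between the requested and traced cache settings:
# True only when the two booleans differ.
for requested in (True, False):
    for traced in (True, False):
        status = "mismatch" if requested ^ traced else "ok"
        print(f"requested={requested}, traced={traced}: {status}")
```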
```diff
@@ -346,15 +349,44 @@ def __init__(
             )
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
-        self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
-        self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
         if warmup:
             self._init_warmup()
 
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        past_key_values = past_key_values or kwargs.get("past", None)
+
+        if self.use_cache and past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        # `past_key_values` may be in the standard format (e.g. in contrastive search); convert to Bloom's format if needed
+        if past_key_values is not None and self.config.model_type == "bloom":
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+        position_ids = kwargs.get("position_ids", None)
+
+        attention_mask = kwargs.get("attention_mask", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": self.use_cache,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": None,
+        }
+
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
         nb_pkv = 2
```
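The new `prepare_inputs_for_generation` override builds `position_ids` from the attention mask so that left-padded batched generation stays aligned. A minimal sketch of that arithmetic in isolation (tensor values invented for illustration):

```python
import torch

# Batch of 2 left-padded sequences; 0 marks padding in the attention mask.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum assigns each real token its 0-based position;
# padding slots are then filled with a harmless placeholder (1).
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# With a cache present, only the last position is fed to the next step:
print(position_ids[:, -1].unsqueeze(-1))
# tensor([[2],
#         [4]])
```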