Commit 00e6bf3

rebase
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2 parents: 06798e2 + c6d2d0f

File tree: 4 files changed, +18 −7 lines

optimum/exporters/ipex/cache_utils.py  (+5 −1)

@@ -5,6 +5,10 @@
 from transformers import Cache, PretrainedConfig
 
 
+# May need to tune based on sequence length and different models but default to 16 currently.
+BLOCK_SIZE = 16
+
+
 class IPEXPagedCache(Cache):
     """
     A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
@@ -44,7 +48,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 64
+        self.block_size = BLOCK_SIZE
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
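
For context on the arithmetic in the second hunk: num_blocks is a ceiling division of max_cache_len by the block size, multiplied by the batch size, so shrinking BLOCK_SIZE from 64 to 16 allocates more, smaller pages per sequence. A minimal, self-contained sketch of that math (the helper name num_cache_blocks and the example values are illustrative, not part of the source):

BLOCK_SIZE = 16  # new default introduced above; the previous hard-coded value was 64

def num_cache_blocks(max_cache_len: int, batch_size: int, block_size: int = BLOCK_SIZE) -> int:
    # Ceil-divide the cache length into blocks, then allocate that many blocks per sequence.
    blocks_per_seq = max_cache_len // block_size + (max_cache_len % block_size != 0)
    return blocks_per_seq * batch_size

# e.g. max_cache_len=100, batch_size=2: ceil(100 / 16) = 7 blocks per sequence, 14 blocks total
assert num_cache_blocks(100, 2) == 14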

optimum/exporters/ipex/model_patcher.py  (+2)

@@ -15,6 +15,7 @@
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
 from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -33,6 +34,7 @@
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
     _IPEXGPT2MLP,
+    _IPEXGPT2MLP,
     _falcon_model_forward,
     _gpt2_block_forward,
     _gpt2_model_forward,

optimum/exporters/ipex/modeling_utils.py  (+3 −5)

@@ -769,19 +769,17 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn(query) and past_len == 0:
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
-            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
-                query_len_tensor,
                 seq_len_tensor,
-                query_max_len,
+                seq_len_tensor,
+                input_lens.max(),
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
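
The change above restricts the PagedAttention.flash_attn_varlen_func path to prefill (past_len == 0), where the query and key/value token counts coincide, so the same cumulative-length tensor and the same maximum length can be passed for both; steps with a non-empty past no longer take this branch. A small, self-contained sketch of how that cumulative-length tensor is built (the batch lengths are illustrative):

import torch

input_lens = torch.tensor([5, 3, 8], dtype=torch.int32)  # per-sequence prompt lengths in a prefill batch

# Cumulative sequence lengths with a leading zero, as in the hunk above.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

print(seq_len_tensor)    # tensor([ 0,  5,  8, 16], dtype=torch.int32)
print(input_lens.max())  # tensor(8, dtype=torch.int32); used for both query and key/value max length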

optimum/exporters/openvino/convert.py  (+8 −1)

@@ -45,6 +45,7 @@
     compare_versions,
     is_diffusers_version,
     is_openvino_tokenizers_version,
+    is_openvino_version,
     is_tokenizers_version,
     is_transformers_version,
 )
@@ -366,6 +367,7 @@ def export_pytorch(
     import torch
     from torch.utils._pytree import tree_map
 
+    from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
     from optimum.exporters.utils import check_dummy_inputs_are_allowed
 
     logger.info(f"Using framework PyTorch: {torch.__version__}")
@@ -428,15 +430,20 @@ def ts_patched_forward(*args, **kwargs):
 
             patcher.patched_forward = ts_patched_forward
 
+        ts_decoder_kwargs = {}
+        if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
+            ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
+
         with patcher:
             if patch_16bit_model:
                 from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
 
                 __make_16bit_traceable(model)
             check_dummy_inputs_are_allowed(model, dummy_inputs)
             input_info = _get_input_info(model, config, dummy_inputs)
+            ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
             ov_model = convert_model(
-                model,
+                ts_decoder,
                 example_input=dummy_inputs,
                 input=[(item.shape, item.type) for item in input_info],
             )
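
The last hunk routes the model through an explicit TorchScriptPythonDecoder and, for diffusers pipelines on OpenVINO >= 2025.0, passes trace_kwargs={"check_trace": False} to skip torch.jit's trace re-checking. A minimal stand-alone sketch of just the version-gated kwargs (it substitutes a packaging-based check for optimum's is_openvino_version helper, and the function name is illustrative):

from packaging import version

def ts_decoder_kwargs_for(library_name: str, openvino_version: str) -> dict:
    # Mirror of the gate added above: only diffusers on OpenVINO >= 2025.0 skips the trace check.
    kwargs = {}
    if library_name == "diffusers" and version.parse(openvino_version) >= version.parse("2025.0"):
        kwargs["trace_kwargs"] = {"check_trace": False}
    return kwargs

print(ts_decoder_kwargs_for("diffusers", "2025.1"))     # {'trace_kwargs': {'check_trace': False}}
print(ts_decoder_kwargs_for("transformers", "2025.1"))  # {}

In the actual hunk the resulting dict is splatted into TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs), and that decoder, rather than the raw model, is handed to convert_model.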
