Commit 37f2094

fix test models
1 parent d1782d0 commit 37f2094

3 files changed: +171 -11 lines changed

optimum/exporters/openvino/model_configs.py (+3 -9)

@@ -786,22 +786,16 @@ def __init__(
         self.num_key_value_heads = normalized_config.num_key_value_heads
 
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        v_shape = (
+        shape = (
             self.batch_size,
             self.num_key_value_heads,
             self.sequence_length,
             self.hidden_size // self.num_attention_heads,
         )
-        k_shape = (
-            self.batch_size,
-            self.num_key_value_heads,
-            self.sequence_length,
-            self.hidden_size // self.num_attention_heads * 2,
-        )
         return [
             (
-                self.random_float_tensor(k_shape, framework=framework, dtype=float_dtype),
-                self.random_float_tensor(v_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
             )
             for _ in range(self.num_layers)
         ]
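
The fix collapses the separate k_shape/v_shape into a single shape, so the generated dummy cache now holds key and value tensors of identical dimensions, (batch_size, num_key_value_heads, sequence_length, hidden_size // num_attention_heads), one pair per layer. A minimal sketch of the resulting structure in plain torch (the sizes below are made up for illustration and are not taken from the commit):

    import torch

    # Illustrative sizes only; in the exporter these come from the normalized model config.
    batch_size, num_key_value_heads, sequence_length = 2, 4, 8
    hidden_size, num_attention_heads, num_layers = 64, 8, 3

    head_dim = hidden_size // num_attention_heads
    shape = (batch_size, num_key_value_heads, sequence_length, head_dim)

    # One (key, value) pair per layer; after the fix both tensors share the same shape.
    past_key_values = [(torch.rand(shape), torch.rand(shape)) for _ in range(num_layers)]
    assert all(k.shape == v.shape == shape for k, v in past_key_values)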

optimum/exporters/openvino/model_patcher.py (+166 -0)

@@ -20,6 +20,7 @@
 
 import torch
 import torch.nn.functional as F
+from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import is_tf_available
 
@@ -1397,9 +1398,173 @@ def _dbrx_experts_forward(
     return out
 
 
+def _dbrx_update_causal_mask_legacy(
+    self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
+) -> Optional[torch.Tensor]:
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    min_dtype = torch.finfo(torch.float16).min
+    sequence_length = input_tensor.shape[1]
+    if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"):  # static cache
+        target_length = self.config.max_position_embeddings
+    else:  # dynamic cache
+        target_length = (
+            attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+        )
+
+    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        if attention_mask.dim() == 2:
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+        elif attention_mask.dim() == 4:
+            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+            # cache. In that case, the 4D attention mask attends to the newest tokens only.
+            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                offset = cache_position[0]
+            else:
+                offset = 0
+            mask_shape = attention_mask.shape
+            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+            causal_mask[
+                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+            ] = mask_slice
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+    ):
+        # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+        is_tracing = (
+            torch.jit.is_tracing()
+            or isinstance(input_tensor, torch.fx.Proxy)
+            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        )
+        if not is_tracing and torch.any(attention_mask != 1):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+def _dbrx_update_causal_mask_latest(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: Cache,
+    output_attentions: bool,
+):
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+    sequence_length = input_tensor.shape[1]
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        # difference with original modeling
+        causal_mask = (
+            torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+        )
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+if is_transformers_version(">", "4.40.2"):
+    _dbrx_update_causal_mask = _dbrx_update_causal_mask_latest
+else:
+    _dbrx_update_causal_mask = _dbrx_update_causal_mask_legacy
+
+
 class DBRXModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
+        self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask
+        self._model.transformer._update_causal_mask = types.MethodType(
+            _dbrx_update_causal_mask, self._model.transformer
+        )
 
         for block in self._model.transformer.blocks:
             rotary_emb = block.norm_attn_norm.attn.rotary_emb
@@ -1413,5 +1578,6 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
+        self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask
        for block in self._model.transformer.blocks:
            block.ffn.experts.forward = block.ffn.experts._orig_forward
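
The patched mask builders above differ from the upstream DBRX modeling code mainly in the fill value: they use the minimum of float16 rather than of a wider dtype, so the mask stays representable when the model runs in reduced precision. A minimal standalone sketch of both points follows; it is not part of the commit, and the sizes are illustrative only:

    import torch

    # Fill value: float32's minimum overflows to -inf once cast to float16,
    # which is why the patch pins the fill value to torch.finfo(torch.float16).min.
    f32_min = torch.finfo(torch.float32).min
    f16_min = torch.finfo(torch.float16).min
    assert torch.tensor(f32_min).to(torch.float16).isinf()    # overflows to -inf
    assert torch.tensor(f16_min).to(torch.float16).isfinite()  # stays representable

    # Mask construction, mirroring the patched code with tiny illustrative sizes:
    # 2 new tokens decoded against a target length of 5 positions.
    sequence_length, target_length = 2, 5
    cache_position = torch.tensor([3, 4])  # absolute positions of the new tokens

    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=torch.float16) * f16_min
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    # Each row now carries min_dtype at positions later than that token's cache
    # position and 0.0 everywhere the token is allowed to attend.
    print(causal_mask)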

tests/openvino/utils_tests.py (+2 -2)

@@ -41,7 +41,7 @@
     "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
     "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",
     "data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
-    "dbrx": "yujiepan/dbrx-tiny-random",
+    "dbrx": "katuni4ka/tiny-random-dbrx",
     "deberta": "hf-internal-testing/tiny-random-deberta",
     "deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
     "deit": "hf-internal-testing/tiny-random-deit",
@@ -93,7 +93,7 @@
     "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
     "qwen": "katuni4ka/tiny-random-qwen",
     "qwen2": "Qwen/Qwen1.5-0.5B",
-    "qwen2-moe": "yujiepan/qwen1.5-moe-tiny-random",
+    "qwen2-moe": "katuni4ka/tiny-random-qwen1.5-moe",
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-roberta",
     "roformer": "hf-internal-testing/tiny-random-roformer",
