Skip to content

Commit a502f65

Browse files
committed
fix accuracy
1 parent 4912f99 commit a502f65

File tree

1 file changed

+53
-31
lines changed

1 file changed

+53
-31
lines changed

optimum/exporters/openvino/model_patcher.py

+53-31
Original file line numberDiff line numberDiff line change
@@ -2681,32 +2681,6 @@ def __exit__(self, exc_type, exc_value, traceback):
26812681
unpatch_update_causal_mask(self._model, "gpt_neox_japanese")
26822682

26832683

2684-
def _gpt_neo_sdpa_attn(self, query, key, value, attention_mask=None, head_mask=None):
2685-
# Keep the attention weights computation in fp32 to avoid overflow issues
2686-
query = query.to(torch.float32)
2687-
key = key.to(torch.float32)
2688-
2689-
# Apply sliding window masking for local attention layers
2690-
query_length, key_length = query.size(-2), key.size(-2)
2691-
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
2692-
# different from original for prevent overflow, apply to mask instead of directly to weights
2693-
mask_value = torch.finfo(torch.float16).min
2694-
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
2695-
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
2696-
mask_value = torch.tensor(mask_value, dtype=query.dtype).to(query.device)
2697-
if attention_mask is None:
2698-
attention_mask = torch.ones_like(causal_mask)
2699-
attention_mask = torch.where(causal_mask, attention_mask[:, :, :, : key.shape[-2]], mask_value)
2700-
2701-
# Mask heads if we want to
2702-
if head_mask is not None:
2703-
attention_mask = attention_mask * head_mask
2704-
2705-
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
2706-
2707-
return attn_output, None
2708-
2709-
27102684
def _gpt_neo_attn_forward(
27112685
self,
27122686
hidden_states,
@@ -2731,22 +2705,70 @@ def _gpt_neo_attn_forward(
27312705
)
27322706

27332707

2708+
# Adapted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
2709+
def _gpt_neo_attn_sdpa(
2710+
self,
2711+
query: torch.Tensor,
2712+
key: torch.Tensor,
2713+
value: torch.Tensor,
2714+
attention_mask: Optional[torch.Tensor] = None,
2715+
head_mask: Optional[torch.Tensor] = None,
2716+
):
2717+
batch_size = query.shape[0]
2718+
2719+
mask_value = torch.finfo(torch.float16).min
2720+
mask_value = torch.full([], mask_value, dtype=value.dtype)
2721+
2722+
dropout_p = float(self.config.attention_dropout) if self.training else 0.0
2723+
if (batch_size == 1 or self.training) and self.attention_type == "global":
2724+
if query.shape[2] > 1:
2725+
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
2726+
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
2727+
)
2728+
else:
2729+
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
2730+
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
2731+
)
2732+
else:
2733+
query_length, key_length = query.size(-2), key.size(-2)
2734+
2735+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
2736+
2737+
causal_mask = torch.where(causal_mask, 0, mask_value)
2738+
if batch_size > 1:
2739+
# torch.Tensor.expand does no memory copy
2740+
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
2741+
2742+
if attention_mask is not None:
2743+
attention_mask = causal_mask + attention_mask
2744+
2745+
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
2746+
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
2747+
)
2748+
2749+
return sdpa_result, None
2750+
2751+
27342752
class GptNeoModelPatcher(DecoderModelPatcher):
    """Patcher that switches GPT-Neo attention layers to the SDPA implementation.

    On entry (only for transformers >= 4.45.0 and torch >= 2.1.0) the config's
    ``_attn_implementation`` is set to ``"sdpa"`` and, for every layer in
    ``model.transformer.h``, ``forward`` and ``_attn`` of the attention module
    are replaced by ``_gpt_neo_attn_forward`` and ``_gpt_neo_attn_sdpa``. On exit
    everything that was replaced is put back.
    """

    def __enter__(self):
        super().__enter__()
        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
            self._model.config._attn_implementation = "sdpa"
            for layer in self._model.transformer.h:
                self_attn = layer.attn.attention
                self_attn._orig_attn = self_attn._attn
                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
                # Keep the real original forward so that __exit__ restores it,
                # and actually install the patched forward for the export.
                self_attn._orig_forward = self_attn.forward
                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if hasattr(self._model.config, "_orig_attn_implementation"):
            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
        for layer in self._model.transformer.h:
            self_attn = layer.attn.attention
            # The attributes only exist when __enter__ passed the version gate.
            if hasattr(self_attn, "_orig_forward"):
                self_attn.forward = self_attn._orig_forward
            if hasattr(self_attn, "_orig_attn"):
                self_attn._attn = self_attn._orig_attn
27502772

27512773

27522774
class Gemma2ModelPatcher(LlamaModelPatcher):

0 commit comments

Comments
 (0)