Commit 49441bc

fix gpt bigcode model loading with fp16 weights precision (#1098)
* fix gpt bigcode model loading with fp16 weights precision
* code style after rebase
1 parent 58aec63 commit 49441bc

2 files changed: 110 insertions(+), 0 deletions(-)

optimum/exporters/openvino/model_configs.py (+20)
@@ -29,6 +29,7 @@
     CodeGenOnnxConfig,
     FalconOnnxConfig,
     GemmaOnnxConfig,
+    GPTBigCodeOnnxConfig,
     GPTJOnnxConfig,
     GPTNeoOnnxConfig,
     GPTNeoXOnnxConfig,
@@ -73,6 +74,7 @@
     FalconModelPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
+    GptBigCodeModelPatcher,
     GptJModelPatcher,
     GptNeoModelPatcher,
     GptNeoxJapaneseModelPatcher,
@@ -2591,3 +2593,21 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> ModelPatcher:
         return GraniteMoEModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "gpt-bigcode",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GPTBigCodeOpenVINOConfig(GPTBigCodeOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
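
With this registration, exporting a gpt-bigcode model through optimum-intel resolves to GPTBigCodeOpenVINOConfig, which attaches GptBigCodeModelPatcher during conversion. A minimal usage sketch of that export path follows; the checkpoint id and prompt are illustrative assumptions, not part of this commit.

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

# Illustrative gpt-bigcode checkpoint; checkpoints stored with fp16 weights are the case this commit fixes.
model_id = "bigcode/santacoder"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch model to OpenVINO IR on the fly,
# going through the "gpt-bigcode" entry registered in the tasks manager above.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))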

optimum/exporters/openvino/model_patcher.py (+90)
@@ -3650,3 +3650,93 @@ def __exit__(self, exc_type, exc_value, traceback):
             block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
             block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
             block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
+def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    if head_mask is not None:
+        # The super dispatch is done in the forward.
+        raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")
+
+    scale = None
+    if not self.scale_attn_weights:
+        scale = 1
+
+    # MQA models: (batch_size, query_length, num_heads * head_dim)
+    # MHA models: (batch_size, num_heads, query_length, head_dim)
+    query_shape = query.shape
+    batch_size = query_shape[0]
+    key.shape[-2]
+
+    if self.multi_query:
+        query_length = query_shape[1]
+
+        # SDPA requires the dimension [..., sequence_length, head_dim].
+        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
+        key = key.unsqueeze(1)
+        value = value.unsqueeze(1)
+
+        # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
+        # and flash attention backend (No available kernel. Aborting execution.) from the shapes
+        # query = [batch_size, num_heads, query_length, head_dim]
+        # key = [batch_size, 1, past_length, head_dim]
+        # value = [batch_size, 1, past_length, head_dim]
+        #
+        # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+        if is_torch_version(">=", "2.2.0"):
+            key = key.expand(-1, self.num_heads, -1, -1)
+            value = value.expand(-1, self.num_heads, -1, -1)
+    else:
+        query_length = query_shape[-1]
+
+        # See the comment above.
+        if query.device.type == "cuda" and attention_mask is not None:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+    # create a causal mask in case query_length == 1.
+    is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+    # different from the original: with weights loaded in their original (e.g. fp16) precision, transformer.wte dtype may differ from the query dtype
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(query.dtype)
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=self.attn_pdrop if self.training else 0.0,
+        is_causal=is_causal,
+        scale=scale,
+    )
+
+    if self.multi_query:
+        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
+        sdpa_result = sdpa_result.transpose(1, 2)
+
+        # Reshape is kind of expensive here, as it does a memory copy,
+        # but I did not manage to make away without it (logits do not match when using view)
+        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
+        sdpa_result = sdpa_result.reshape(query_shape)
+
+    return sdpa_result, None
+
+
+class GptBigCodeModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._orig_attn = layer.attn._attn
+                layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._attn = layer.attn._orig_attn
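
The cast before the scaled_dot_product_attention call is the substance of the fix: with weights kept in their original fp16 precision, the query projections come out as fp16 while the additive attention mask built by the model can still be fp32, and torch's SDPA expects a floating-point mask to match the query dtype. A minimal stand-alone sketch of the cast, with arbitrary shapes not taken from the model:

import torch

# Query/key/value as produced when the checkpoint is loaded with fp16 weights.
query = torch.randn(1, 4, 8, 16, dtype=torch.float16)
key = torch.randn(1, 4, 8, 16, dtype=torch.float16)
value = torch.randn(1, 4, 8, 16, dtype=torch.float16)

# Additive attention mask kept in fp32 by the surrounding model code.
attention_mask = torch.zeros(1, 1, 8, 8, dtype=torch.float32)

# The cast added in this commit: align the mask dtype with the query dtype
# before dispatching to SDPA.
attention_mask = attention_mask.to(query.dtype)

out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.dtype)  # torch.float16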
