Commit de0777e

fix gpt bigcode model loading with fp16 weights precision
1 parent 753f84d commit de0777e

2 files changed, +111 -0 lines changed

optimum/exporters/openvino/model_configs.py (+21 lines)
@@ -29,6 +29,7 @@
     CodeGenOnnxConfig,
     FalconOnnxConfig,
     GemmaOnnxConfig,
+    GPTBigCodeOnnxConfig,
     GPTJOnnxConfig,
     GPTNeoOnnxConfig,
     GPTNeoXOnnxConfig,
@@ -68,6 +69,7 @@
     FalconModelPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
+    GptBigCodeModelPatcher,
     GptJModelPatcher,
     GptNeoModelPatcher,
     GptNeoxJapaneseModelPatcher,
@@ -2554,3 +2556,22 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 )
 class GLMOpenVINOConfig(LlamaOpenVINOConfig):
     MIN_TRANSFORMERS_VERSION = "4.46.0"
+
+
+@register_in_tasks_manager(
+    "gpt-bigcode",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+        "token-classification",
+    ],
+    library_name="transformers",
+)
+class GPTBigCodeOpenVINOConfig(GPTBigCodeOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
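
For context (not part of the diff): once registered, the new config is resolved automatically whenever a gpt-bigcode checkpoint is exported to OpenVINO. A minimal, hedged usage sketch ("my-org/gpt-bigcode-fp16" is a placeholder checkpoint id, not a real model):

from optimum.intel import OVModelForCausalLM

# Hedged sketch, not part of the commit. export=True loads the PyTorch checkpoint in its stored
# (here fp16) precision and converts it to OpenVINO IR; during conversion the tasks manager maps
# "gpt-bigcode" / "text-generation-with-past" to GPTBigCodeOpenVINOConfig, whose
# patch_model_for_export wraps the model in the GptBigCodeModelPatcher defined below.
ov_model = OVModelForCausalLM.from_pretrained("my-org/gpt-bigcode-fp16", export=True)
ov_model.save_pretrained("gpt_bigcode_openvino")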

optimum/exporters/openvino/model_patcher.py (+90 lines)
@@ -3601,3 +3601,93 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.blocks:
             block.forward = block._orig_forward
             block.attn.forward = block.attn._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
+def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    if head_mask is not None:
+        # The super dispatch is done in the forward.
+        raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")
+
+    scale = None
+    if not self.scale_attn_weights:
+        scale = 1
+
+    # MQA models: (batch_size, query_length, num_heads * head_dim)
+    # MHA models: (batch_size, num_heads, query_length, head_dim)
+    query_shape = query.shape
+    batch_size = query_shape[0]
+    key.shape[-2]
+
+    if self.multi_query:
+        query_length = query_shape[1]
+
+        # SDPA requires the dimension [..., sequence_length, head_dim].
+        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
+        key = key.unsqueeze(1)
+        value = value.unsqueeze(1)
+
+        # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
+        # and flash attention backend (No available kernel. Aborting execution.) from the shapes
+        # query = [batch_size, num_heads, query_length, head_dim]
+        # key = [batch_size, 1, past_length, head_dim]
+        # value = [batch_size, 1, past_length, head_dim]
+        #
+        # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+        if is_torch_version(">=", "2.2.0"):
+            key = key.expand(-1, self.num_heads, -1, -1)
+            value = value.expand(-1, self.num_heads, -1, -1)
+    else:
+        query_length = query_shape[-1]
+
+        # See the comment above.
+        if query.device.type == "cuda" and attention_mask is not None:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+    # create a causal mask in case query_length == 1.
+    is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+    # different from original, due to loading model weights in original format transformer.wte dtype may be different from query dtype
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(query.dtype)
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=self.attn_pdrop if self.training else 0.0,
+        is_causal=is_causal,
+        scale=scale,
+    )
+
+    if self.multi_query:
+        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
+        sdpa_result = sdpa_result.transpose(1, 2)
+
+        # Reshape is kind of expensive here, as it does a memory copy,
+        # but I did not manage to make away without it (logits do not match when using view)
+        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
+        sdpa_result = sdpa_result.reshape(query_shape)
+
+    return sdpa_result, None
+
+
+class GptBigCodeModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._orig_attn = layer.attn._attn
+                layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._attn = layer.attn._orig_attn
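
Why the added cast matters (an illustration, not part of the diff): when the checkpoint is stored in fp16, query/key/value arrive as torch.float16 while the additive attention mask built by the model may still be torch.float32, and scaled_dot_product_attention rejects a floating-point mask whose dtype differs from the query's. A minimal sketch of the failure and of the fix applied in gpt_bigcode_attn, assuming a recent PyTorch build with fp16 SDPA support on the current device:

import torch

# Hedged repro sketch: fp16 tensors with an fp32 additive mask, as can happen when a
# GPT-BigCode checkpoint is loaded in its original fp16 precision.
query = torch.randn(1, 4, 8, 16, dtype=torch.float16)
key = torch.randn(1, 4, 8, 16, dtype=torch.float16)
value = torch.randn(1, 4, 8, 16, dtype=torch.float16)
mask = torch.zeros(1, 1, 8, 8, dtype=torch.float32)

try:
    torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=mask)
except RuntimeError as err:
    print(err)  # SDPA expects attn_mask to be bool or to match the query dtype

# Mirrors the line added above: cast the mask to the query dtype before calling SDPA.
out = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=mask.to(query.dtype)
)
print(out.dtype)  # torch.float16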
