Skip to content

Commit fe77316

Browse files
authored
Add support decilm (#899)
1 parent abfb4f4 commit fe77316

File tree

6 files changed

+190
-1
lines changed

6 files changed

+190
-1
lines changed

docs/source/openvino/models.mdx

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Here is the list of the supported architectures :
4040
- Data2VecVision
4141
- Deberta
4242
- Deberta-v2
43+
- DeciLM
4344
- Deit
4445
- DistilBert
4546
- Electra

optimum/exporters/openvino/model_configs.py

+55
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
ChatGLMModelPatcher,
5454
CodeGenModelPatcher,
5555
DBRXModelPatcher,
56+
DeciLMModelPatcher,
5657
FalconModelPatcher,
5758
Gemma2ModelPatcher,
5859
GptNeoxJapaneseModelPatcher,
@@ -1018,3 +1019,57 @@ def patch_model_for_export(
10181019
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
10191020
) -> "ModelPatcher":
10201021
return Gemma2ModelPatcher(self, model, model_kwargs=model_kwargs)
1022+
1023+
1024+
class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
1025+
def __init__(
1026+
self,
1027+
task: str,
1028+
normalized_config: NormalizedTextConfig,
1029+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
1030+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
1031+
random_batch_size_range: Optional[Tuple[int, int]] = None,
1032+
random_sequence_length_range: Optional[Tuple[int, int]] = None,
1033+
**kwargs,
1034+
):
1035+
super().__init__(
1036+
task=task,
1037+
normalized_config=normalized_config,
1038+
batch_size=batch_size,
1039+
sequence_length=sequence_length,
1040+
random_batch_size_range=random_batch_size_range,
1041+
random_sequence_length_range=random_sequence_length_range,
1042+
)
1043+
self.num_key_value_heads_per_layer = normalized_config.num_key_value_heads_per_layer
1044+
1045+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1046+
past_key_values = []
1047+
1048+
for layer_id in range(self.num_layers):
1049+
shape = (
1050+
self.batch_size,
1051+
self.num_key_value_heads_per_layer[layer_id],
1052+
self.sequence_length,
1053+
self.hidden_size // self.num_attention_heads,
1054+
)
1055+
past_key_values.append(
1056+
(
1057+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
1058+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
1059+
)
1060+
)
1061+
return past_key_values
1062+
1063+
1064+
@register_in_tasks_manager("deci", *["text-generation", "text-generation-with-past"], library_name="transformers")
1065+
class DeciOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
1066+
DEFAULT_ONNX_OPSET = 14
1067+
1068+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DeciDummyPastKeyValuesGenerator)
1069+
DUMMY_PKV_GENERATOR_CLASS = DeciDummyPastKeyValuesGenerator
1070+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
1071+
1072+
def patch_model_for_export(
1073+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
1074+
) -> "ModelPatcher":
1075+
return DeciLMModelPatcher(self, model, model_kwargs=model_kwargs)

optimum/exporters/openvino/model_patcher.py

+131
Original file line numberDiff line numberDiff line change
@@ -2467,3 +2467,134 @@ def patched_forward(*args, **kwargs):
24672467
return outputs
24682468

24692469
self.patched_forward = patched_forward
2470+
2471+
2472+
def _decilm_attn_forward(
2473+
self,
2474+
hidden_states: torch.Tensor,
2475+
attention_mask: Optional[torch.Tensor] = None,
2476+
position_ids: Optional[torch.LongTensor] = None,
2477+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
2478+
output_attentions: bool = False,
2479+
use_cache: bool = False,
2480+
**kwargs,
2481+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
2482+
# decilm contains bug in attention calculation for case if past key values is not None
2483+
def rotate_half(x):
2484+
"""Rotates half the hidden dims of the input."""
2485+
x1 = x[..., : x.shape[-1] // 2]
2486+
x2 = x[..., x.shape[-1] // 2 :]
2487+
return torch.cat((-x2, x1), dim=-1)
2488+
2489+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
2490+
"""Applies Rotary Position Embedding to the query and key tensors.
2491+
2492+
Args:
2493+
q (`torch.Tensor`): The query tensor.
2494+
k (`torch.Tensor`): The key tensor.
2495+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
2496+
sin (`torch.Tensor`): The sine part of the rotary embedding.
2497+
position_ids (`torch.Tensor`):
2498+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
2499+
used to pass offsetted position ids when working with a KV-cache.
2500+
unsqueeze_dim (`int`, *optional*, defaults to 1):
2501+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
2502+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
2503+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
2504+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
2505+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
2506+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
2507+
Returns:
2508+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
2509+
"""
2510+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
2511+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
2512+
q_embed = (q * cos) + (rotate_half(q) * sin)
2513+
k_embed = (k * cos) + (rotate_half(k) * sin)
2514+
return q_embed, k_embed
2515+
2516+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
2517+
"""
2518+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
2519+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
2520+
"""
2521+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
2522+
if n_rep == 1:
2523+
return hidden_states
2524+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
2525+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
2526+
2527+
bsz, q_len, _ = hidden_states.size()
2528+
if self.pretraining_tp > 1:
2529+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
2530+
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
2531+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
2532+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
2533+
2534+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
2535+
query_states = torch.cat(query_states, dim=-1)
2536+
2537+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
2538+
key_states = torch.cat(key_states, dim=-1)
2539+
2540+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
2541+
value_states = torch.cat(value_states, dim=-1)
2542+
2543+
else:
2544+
query_states = self.q_proj(hidden_states)
2545+
key_states = self.k_proj(hidden_states)
2546+
value_states = self.v_proj(hidden_states)
2547+
2548+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
2549+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
2550+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
2551+
2552+
kv_seq_len = key_states.shape[-2]
2553+
if past_key_value is not None:
2554+
kv_seq_len += past_key_value[0].shape[-2]
2555+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
2556+
2557+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
2558+
2559+
if past_key_value is not None:
2560+
# reuse k, v, self_attention
2561+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
2562+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
2563+
2564+
past_key_value = (key_states, value_states) if use_cache else None
2565+
2566+
# repeat k/v heads if n_kv_heads < n_heads
2567+
key_states = repeat_kv(key_states, self.num_key_value_groups)
2568+
value_states = repeat_kv(value_states, self.num_key_value_groups)
2569+
attn_output = F.scaled_dot_product_attention(
2570+
query_states, key_states, value_states, is_causal=attention_mask is None, attn_mask=attention_mask
2571+
)
2572+
2573+
# modified, in original implementation .transpose(1, 2) missed
2574+
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
2575+
2576+
if self.pretraining_tp > 1:
2577+
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
2578+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
2579+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
2580+
else:
2581+
attn_output = self.o_proj(attn_output)
2582+
2583+
attn_weights = None
2584+
2585+
return attn_output, attn_weights, past_key_value
2586+
2587+
2588+
class DeciLMModelPatcher(DecoderModelPatcher):
2589+
def __enter__(self):
2590+
super().__enter__()
2591+
2592+
for layer in self._model.model.layers:
2593+
layer.self_attn._orig_forward = layer.self_attn.forward
2594+
layer.self_attn.forward = types.MethodType(_decilm_attn_forward, layer.self_attn)
2595+
2596+
def __exit__(self, exc_type, exc_value, traceback):
2597+
super().__exit__(exc_type, exc_value, traceback)
2598+
2599+
for layer in self._model.model.layers:
2600+
layer.self_attn.forward = layer.self_attn._orig_forward

optimum/intel/openvino/modeling_base.py

-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ def _compile_model(
219219
ov_config: Optional[Dict[str, str]] = None,
220220
model_save_dir: Union[str, Path] = None,
221221
):
222-
logger.info(f"Compiling the model to {device} ...")
223222
if isinstance(model, str):
224223
model = Path(model)
225224
ov_config = ov_config or {}

tests/openvino/test_modeling.py

+2
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
754754
"internlm",
755755
"jais",
756756
"glm4",
757+
"decilm",
757758
)
758759

759760
if is_transformers_version(">=", "4.40.0"):
@@ -791,6 +792,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
791792
"arctic",
792793
"glm4",
793794
"exaone",
795+
"decilm",
794796
)
795797

796798
@parameterized.expand(SUPPORTED_ARCHITECTURES)

tests/openvino/utils_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"dbrx": "katuni4ka/tiny-random-dbrx",
4545
"deberta": "hf-internal-testing/tiny-random-deberta",
4646
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
47+
"decilm": "katuni4ka/tiny-random-decilm",
4748
"deit": "hf-internal-testing/tiny-random-DeiTModel",
4849
"convnext": "hf-internal-testing/tiny-random-convnext",
4950
"convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",

0 commit comments

Comments
 (0)