
Commit 7d7de7c

add support for granite and granitemoe models (#1099)
* add support for granite and granitemoe models
* add tests and docs
* add models to test cases
1 parent bb1c68a commit 7d7de7c
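
Once this support is available in a release, Granite and GraniteMoE checkpoints can be exported and run with optimum-intel like any other decoder-only architecture. A minimal usage sketch (the checkpoint id below is only an illustrative example and is not part of this commit):

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "ibm-granite/granite-3.0-2b-instruct"  # example Granite checkpoint, not referenced in this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))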

File tree: 5 files changed (+103, −0)

docs/source/openvino/models.mdx (+2)

@@ -58,6 +58,8 @@ Here is the list of the supported architectures :
 - GPT-NeoX-Japanese
 - Gemma
 - Gemma2
+- Granite
+- GraniteMoE
 - Hubert
 - IBert
 - InternLM

optimum/exporters/openvino/model_configs.py (+28)

@@ -72,6 +72,7 @@
     GptNeoModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
+    GraniteMoEModelPatcher,
     IBertModelPatcher,
     InputEmbeddingPatcher,
     InternLM2Patcher,
@@ -2554,3 +2555,30 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 )
 class GLMOpenVINOConfig(LlamaOpenVINOConfig):
     MIN_TRANSFORMERS_VERSION = "4.46.0"
+
+
+@register_in_tasks_manager(
+    "granite",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GraniteOpenVINOConfig(LlamaOpenVINOConfig):
+    MIN_TRANSFORMERS_VERSION = "4.45.0"
+
+
+@register_in_tasks_manager(
+    "granitemoe", *["text-generation", "text-generation-with-past"], library_name="transformers"
+)
+class GraniteMoEOpenVINOConfig(LlamaOpenVINOConfig):
+    MIN_TRANSFORMERS_VERSION = "4.45.0"
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return GraniteMoEModelPatcher(self, model, model_kwargs=model_kwargs)
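
The register_in_tasks_manager decorators above make the new configs discoverable by model type. A rough sketch of the lookup this enables, assuming the standard TasksManager API and that the OpenVINO exporter module has been imported so the registrations have run:

from optimum.exporters.tasks import TasksManager
import optimum.exporters.openvino  # importing the exporter package should trigger the registrations above

# resolve the export config registered for the "granite" model type
config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="openvino",
    model_type="granite",
    task="text-generation-with-past",
    library_name="transformers",
)
print(config_constructor)  # expected to point at GraniteOpenVINOConfig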

optimum/exporters/openvino/model_patcher.py (+69)

@@ -3581,3 +3581,72 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.blocks:
             block.forward = block._orig_forward
             block.attn.forward = block.attn._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321
+def _granite_moe_topk_gating_forward(self, hidden_states):
+    # compute the top_k routing decision
+    logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
+    top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
+    top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
+
+    # compute the number of inputs given to each expert
+    zeros = torch.zeros(
+        [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
+    )  # [num_tokens, num_experts]
+    gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
+    expert_size = gates.long().sum(0)  # [num_experts,]
+    # difference with the original: removed expert_size = expert_size.tolist() due to incorrect tracing
+
+    # sort and group input tokens according to expert assignment
+    top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
+    _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
+    batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]
+
+    # gather the gate values for grouped input tokens
+    top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
+    batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]
+
+    return index_sorted_experts, batch_index, batch_gates, expert_size, logits
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L281
+def _granite_moe_parallel_experts_forward(self, inputs, expert_size):
+    output_list = []
+    # differences with the original:
+    # 1) expert_size is a tensor instead of a list of ints after the gating patch, so the original inputs.split(expert_size) cannot be used
+    # 2) expert input splits are obtained one by one via index_start:next_index slicing instead of precomputing the splits once before the loop
+    index_start = torch.tensor(0, dtype=torch.int64)
+    for i in range(self.num_experts):
+        next_index = index_start + expert_size[i]
+        output_list.append(F.linear(inputs[index_start:next_index], self.weight[i]))
+        index_start = next_index
+    results = torch.cat(output_list, dim=0)
+    return results
+
+
+class GraniteMoEModelPatcher(LlamaModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward
+            block_sparse_moe.router.forward = types.MethodType(
+                _granite_moe_topk_gating_forward, block_sparse_moe.router
+            )
+            block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward
+            block_sparse_moe.input_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.input_linear
+            )
+            block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward
+            block_sparse_moe.output_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.output_linear
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
+            block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
+            block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
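
The two patched functions above exist because the original GraniteMoE code converts expert_size to a Python list and uses inputs.split(expert_size), both of which bake concrete token counts into the traced graph. A self-contained sketch with toy shapes (not part of the commit) of the slicing scheme the patch keeps dynamic:

import torch
import torch.nn.functional as F

inputs = torch.randn(6, 4)          # tokens already sorted by expert assignment, [num_tokens, hidden]
weight = torch.randn(2, 3, 4)       # [num_experts, out_features, in_features]
expert_size = torch.tensor([4, 2])  # tokens routed to each expert, kept as a tensor

# inputs.split(expert_size) would need Python ints; slicing with a running
# tensor index keeps the per-expert split traceable with dynamic sizes
output_list = []
index_start = torch.tensor(0, dtype=torch.int64)
for i in range(weight.size(0)):
    next_index = index_start + expert_size[i]
    output_list.append(F.linear(inputs[index_start:next_index], weight[i]))
    index_start = next_index

result = torch.cat(output_list, dim=0)
print(result.shape)  # torch.Size([6, 3])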

tests/openvino/test_modeling.py (+2)

@@ -916,6 +916,8 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "mistral-nemo",
         "minicpm3",
         "glm",
+        "granite",
+        "granite-moe",
     )

     # gptq and awq install disabled for windows test environment

tests/openvino/utils_tests.py (+2)

@@ -72,6 +72,8 @@
     "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
     "gpt_neox_japanese": "hf-internal-testing/tiny-random-GPTNeoXJapaneseForCausalLM",
     "gptj": "hf-internal-testing/tiny-random-GPTJModel",
+    "granite": "katuni4ka/tiny-random-granite",
+    "granite-moe": "katuni4ka/tiny-random-granite-moe",
     "hubert": "hf-internal-testing/tiny-random-HubertModel",
     "ibert": "hf-internal-testing/tiny-random-ibert",
     "internlm": "katuni4ka/tiny-random-internlm",
