
Commit 6e3b010

Add support for DBRX
1 parent a85eae6 commit 6e3b010

4 files changed, +128 −0 lines changed

optimum/exporters/openvino/model_configs.py (+67 lines)

@@ -46,6 +46,7 @@
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     CodeGenModelPatcher,
+    DBRXModelPatcher,
     GemmaModelPatcher,
     InternLM2Patcher,
     InternLMModelPatcher,
@@ -752,3 +753,69 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return CodeGenModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+class DBRXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.num_key_value_heads = normalized_config.num_key_value_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        v_shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.hidden_size // self.num_attention_heads,
+        )
+        k_shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.hidden_size // self.num_attention_heads * 2,
+        )
+        return [
+            (
+                self.random_float_tensor(k_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(v_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
+@register_in_tasks_manager(
+    "dbrx",
+    *["text-generation", "text-generation-with-past"],
+    library_name="transformers",
+)
+class DBRXOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
+        num_attention_heads="n_heads",
+        hidden_size="d_model",
+        num_layers="n_layers",
+        num_key_value_heads="attn_config.kv_n_heads",
+        allow_new=True,
+    )
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DBRXDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DBRXDummyPastKeyValuesGenerator
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return DBRXModelPatcher(self, model, model_kwargs=model_kwargs)
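Note: with the "dbrx" architecture registered as above, a DBRX checkpoint goes through the existing OpenVINO causal-LM path. A minimal sketch, assuming optimum-intel with the OpenVINO extra and a transformers release that includes DBRX are installed; the checkpoint id below is only illustrative and the full model needs enough memory to convert:

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "databricks/dbrx-instruct"  # illustrative; any DBRX checkpoint should work

# export=True converts the PyTorch model to OpenVINO IR using the registered DBRX config
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))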

optimum/exporters/openvino/model_patcher.py (+59 lines)

@@ -1356,3 +1356,62 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.transformer.h:
             if hasattr(layer.attn, "_orig_attn"):
                 layer.attn._attn = layer.attn._orig_attn
+
+
+def _dbrx_experts_forward(
+    self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+):
+    bsz, q_len, hidden_size = x.shape
+    x = x.view(-1, hidden_size)
+    out = torch.zeros_like(x)
+
+    expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+    # Chunk experts at once to avoid storing full parameter multiple times in autograd
+    w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+    v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+    w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
+    for expert_idx in range(0, self.moe_num_experts):
+        topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+
+        token_list = token_idx
+        topk_list = topk_idx
+
+        expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+        expert_out = (
+            self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
+            * top_weights[token_list, topk_list, None]
+        )
+
+        out.index_add_(0, token_idx, expert_out)
+
+    out = out.reshape(bsz, q_len, hidden_size)
+    return out
+
+
+class DBRXModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        for block in self._model.transformer.blocks:
+            rotary_emb = block.norm_attn_norm.attn.rotary_emb
+            if rotary_emb.inv_freq is None:
+                inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
+                rotary_emb.inv_freq = inv_freq
+            block.ffn.experts._orig_forward = block.ffn.experts.forward
+            block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.transformer.blocks:
+            block.ffn.experts.forward = block.ffn.experts._orig_forward
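Note: the patcher swaps DbrxExperts.forward for a version that keeps the routing indices as tensors (token_list / topk_list stay on-device, which is friendlier to tracing) and materializes rotary_emb.inv_freq as a concrete tensor; both changes are reverted in __exit__. For readers unfamiliar with the idiom, here is a self-contained toy illustration of the same types.MethodType swap-and-restore pattern; the Toy module and _patched_forward are made up for this example and are not part of the patch:

# Toy illustration of the swap-and-restore patching pattern used by DBRXModelPatcher.
import types

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


def _patched_forward(self, x):
    # replacement body that the patcher temporarily installs
    return x + 2


toy = Toy()
toy._orig_forward = toy.forward                        # stash the original bound method
toy.forward = types.MethodType(_patched_forward, toy)  # install the replacement, as __enter__ does
print(toy(torch.tensor([1.0])))  # tensor([3.]) while patched
toy.forward = toy._orig_forward                         # restore, as __exit__ does
print(toy(torch.tensor([1.0])))  # tensor([2.]) after restore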

tests/openvino/test_modeling.py (+1 line)

@@ -562,6 +562,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "aquila2",
         "xverse",
         "internlm",
+        "dbrx",
    )
    GENERATION_LENGTH = 100
    REMOTE_CODE_MODELS = (

tests/openvino/utils_tests.py (+1 line)

@@ -41,6 +41,7 @@
     "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
     "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",
     "data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
+    "dbrx": "yujiepan/dbrx-tiny-random",
     "deberta": "hf-internal-testing/tiny-random-deberta",
     "deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
     "deit": "hf-internal-testing/tiny-random-deit",
