
Commit b64a5c8

apply review comments
1 parent a255a08 commit b64a5c8

2 files changed: +14 -40 lines changed


optimum/exporters/openvino/model_configs.py (+2 -39)
@@ -764,43 +764,6 @@ def patch_model_for_export(
         return CodeGenModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
-class DBRXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
-    def __init__(
-        self,
-        task: str,
-        normalized_config: NormalizedTextConfig,
-        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
-        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
-        random_batch_size_range: Optional[Tuple[int, int]] = None,
-        random_sequence_length_range: Optional[Tuple[int, int]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            task=task,
-            normalized_config=normalized_config,
-            batch_size=batch_size,
-            sequence_length=sequence_length,
-            random_batch_size_range=random_batch_size_range,
-            random_sequence_length_range=random_sequence_length_range,
-        )
-        self.num_key_value_heads = normalized_config.num_key_value_heads
-
-    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        shape = (
-            self.batch_size,
-            self.num_key_value_heads,
-            self.sequence_length,
-            self.hidden_size // self.num_attention_heads,
-        )
-        return [
-            (
-                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
-                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
-            )
-            for _ in range(self.num_layers)
-        ]
-
-
 @register_in_tasks_manager(
     "dbrx",
     *["text-generation", "text-generation-with-past"],
@@ -815,8 +778,8 @@ class DBRXOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
         num_key_value_heads="attn_config.kv_n_heads",
         allow_new=True,
     )
-    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DBRXDummyPastKeyValuesGenerator)
-    DUMMY_PKV_GENERATOR_CLASS = DBRXDummyPastKeyValuesGenerator
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
 
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
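
The dedicated DBRX dummy generator could be dropped because MistralDummyPastKeyValuesGenerator already builds the grouped-query key/value cache layout that the removed class assembled by hand. A minimal sketch of that layout, using illustrative sizes rather than the real DBRX configuration:

    import torch

    # illustrative sizes only, not the actual DBRX configuration
    batch_size, sequence_length = 2, 16
    num_attention_heads, num_key_value_heads, hidden_size = 48, 8, 6144
    head_dim = hidden_size // num_attention_heads

    # one (key, value) pair per layer, each of shape
    # (batch_size, num_key_value_heads, sequence_length, head_dim)
    past_key = torch.rand(batch_size, num_key_value_heads, sequence_length, head_dim)
    past_value = torch.rand(batch_size, num_key_value_heads, sequence_length, head_dim)
    print(past_key.shape)  # torch.Size([2, 8, 16, 128])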

optimum/exporters/openvino/model_patcher.py (+12 -1)
@@ -1468,6 +1468,7 @@ def _dbrx_experts_forward(
     return out
 
 
+# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
 def _dbrx_update_causal_mask_legacy(
     self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
 ) -> Optional[torch.Tensor]:
@@ -1479,6 +1480,9 @@ def _dbrx_update_causal_mask_legacy(
         return None
 
     dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using the minimum of a wider dtype (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
     min_dtype = torch.finfo(torch.float16).min
     sequence_length = input_tensor.shape[1]
     if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"):  # static cache
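
The fill-value choice above can be checked in isolation: the float32 minimum is not representable in float16, so a mask filled with it overflows to -inf once the model actually runs in half precision, while the float16 minimum stays finite. A small sketch of that check:

    import torch

    fill_fp32 = torch.finfo(torch.float32).min  # about -3.4028e38
    fill_fp16 = torch.finfo(torch.float16).min  # -65504.0

    # downcasting the float32 minimum overflows to -inf in float16
    print(torch.tensor(fill_fp32, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
    # the float16 minimum survives the downcast and stays finite
    print(torch.tensor(fill_fp16, dtype=torch.float16))  # tensor(-65504., dtype=torch.float16)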
@@ -1487,7 +1491,9 @@ def _dbrx_update_causal_mask_legacy(
         target_length = (
             attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
         )
-
+    # difference with original modeling
+    # removed target_length = int(target_length)
+    # casting to int leads to constant folding during tracing, which makes it impossible to run the model on sequences of a different length
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
     if sequence_length != 1:
         causal_mask = torch.triu(causal_mask, diagonal=1)
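
Why removing the int() cast matters can be reproduced with a toy trace: converting a traced tensor to a Python int bakes the value into the graph as a constant, so the mask width stops following the input length. A sketch with a hypothetical make_mask helper, not the actual model code:

    import torch

    def make_mask(cache_position):
        # int() turns the traced value into a Python constant,
        # so the mask width is frozen at trace time
        target_length = int(cache_position[-1] + 1)
        return torch.zeros(1, target_length)

    traced = torch.jit.trace(make_mask, (torch.arange(4),))
    print(traced(torch.arange(4)).shape)  # torch.Size([1, 4])
    print(traced(torch.arange(8)).shape)  # still torch.Size([1, 4]): the length was folded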
@@ -1535,6 +1541,7 @@ def _dbrx_update_causal_mask_legacy(
     return causal_mask
 
 
+# adapted from https://github.com/huggingface/transformers/blob/1b3dba9417eebe16b7c206d1dfca6a4c7f11dbec/src/transformers/models/dbrx/modeling_dbrx.py#L1204
 def _dbrx_update_causal_mask_latest(
     self,
     attention_mask: torch.Tensor,
@@ -1631,18 +1638,22 @@ def _dbrx_update_causal_mask_latest(
 class DBRXModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
+        # dbrx has some accuracy issues with bf16 with transformers >= 4.40
+        # fill the causal mask in a slightly different way to avoid overflow on some platforms
         self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask
         self._model.transformer._update_causal_mask = types.MethodType(
             _dbrx_update_causal_mask, self._model.transformer
         )
 
         for block in self._model.transformer.blocks:
             rotary_emb = block.norm_attn_norm.attn.rotary_emb
+            # initialize inv_freq for torchscript tracing
             if rotary_emb.inv_freq is None:
                 inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
                 rotary_emb.inv_freq = inv_freq
+            # remove the continue operator from the iteration loop over experts
             block.ffn.experts._orig_forward = block.ffn.experts.forward
             block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
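
For context, the patcher swaps methods at the instance level with types.MethodType, keeping a handle to the original implementation so it can be restored when the patcher exits. A stripped-down sketch of that mechanism, with placeholder class and function names rather than the real optimum-intel objects:

    import types

    class Transformer:
        def _update_causal_mask(self, attention_mask):
            return attention_mask  # original behaviour

    def _patched_update_causal_mask(self, attention_mask):
        # stand-in for _dbrx_update_causal_mask with the same signature
        return attention_mask

    model = Transformer()
    # keep a reference to the original so it can be restored on __exit__
    model._orig_update_causal_mask = model._update_causal_mask
    # rebind the patched function as a bound method on this instance only
    model._update_causal_mask = types.MethodType(_patched_update_causal_mask, model)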
