Commit 333c619

support more models in export

1 parent 9af1b7c commit 333c619

File tree: 2 files changed, +333 -2 lines

optimum/exporters/openvino/model_configs.py (+71 lines)
@@ -35,6 +35,7 @@
     ChatGLMModelPatcher,
     GemmaModelPatcher,
     MixtralModelPatcher,
+    OLMoModelPatcher,
     QwenModelPatcher,
 )
 
@@ -400,3 +401,73 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager("olmo", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class OLMoOpenVINOConfig(TextDecoderOnnxConfig):
+    # OLMo does not require the position_ids input.
+    DEFAULT_ONNX_OPSET = 13
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return OLMoModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("internlm2", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.num_key_value_heads_per_layer = normalized_config.num_key_value_heads_per_layer
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        past_key_values = []
+
+        for layer_id in range(self.num_layers):
+            shape = (
+                self.batch_size,
+                self.num_key_value_heads_per_layer[layer_id],
+                self.sequence_length,
+                self.hidden_size // self.num_attention_heads,
+            )
+            past_key_values.append(
+                (
+                    self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                    self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                )
+            )
+        return past_key_values
+
+
+@register_in_tasks_manager("deci", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class DeciOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DeciDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DeciDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
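
Not part of this commit, but for context: once the configs above are registered with the tasks manager, the new architectures can be exported through the usual optimum-intel entry points. A minimal sketch in Python, assuming an OLMo checkpoint such as allenai/OLMo-1B (the model ID and the trust_remote_code flag are illustrative assumptions, not taken from this diff):

from optimum.intel import OVModelForCausalLM

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly,
# picking up the "text-generation-with-past" config registered above.
ov_model = OVModelForCausalLM.from_pretrained(
    "allenai/OLMo-1B",       # assumed example checkpoint
    export=True,
    trust_remote_code=True,  # assumed to be needed for OLMo's custom modeling code
)
ov_model.save_pretrained("olmo-1b-openvino")

The same conversion should also be possible from the command line via optimum-cli export openvino --model <model_id> <output_dir>.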

optimum/exporters/openvino/model_patcher.py (+262, -2 lines)
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import logging as log
+import math
 import types
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -509,5 +510,264 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
         # the model initializes its buffers during the first inference call
-        if self._model.lm_head.first_flag:
+        if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
+
+
+class OlmoOutput(NamedTuple):
+    logits: torch.FloatTensor
+    """
+    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
+    for the next token *before* normalization via (log) softmax.
+    """
+
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+    """
+    Attention keys and values from each block.
+    """
+
+    hidden_states: Optional[Tuple[torch.Tensor]]
+    """
+    Hidden states from each block.
+    """
+
+
+def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
+    """
+    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
+    is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
+    """
+    if check_neg_inf:
+        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
+    if check_pos_inf:
+        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
+
+
+def _olmo_model_forward(
+    self,
+    input_ids: torch.LongTensor,
+    input_embeddings: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    attention_bias: Optional[torch.Tensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None,
+    use_cache: bool = False,
+    last_logits_only: bool = False,
+    output_hidden_states: Optional[bool] = None,
+):
+    output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+    if past_key_values:
+        assert len(past_key_values) == self.config.n_layers
+
+    batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+    if past_key_values is None:
+        past_length = 0
+    else:
+        past_length = past_key_values[0][0].size(-2)
+
+    # Get embeddings of input.
+    # shape: (batch_size, seq_len, d_model)
+    x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
+
+    if not (self.config.alibi or self.config.rope):
+        # Get positional embeddings.
+        # shape: (1, seq_len)
+        pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+        # shape: (1, seq_len, d_model)
+        pos_emb = self.transformer.wpe(pos)  # type: ignore
+        x = pos_emb + x
+
+    # Add input + positional embeddings and apply dropout.
+    # shape: (batch_size, seq_len, d_model)
+    x = self.transformer.emb_drop(x)  # type: ignore
+
+    # Transform the attention mask into what the blocks expect.
+    if attention_mask is not None:
+        # shape: (batch_size, 1, 1, seq_len)
+        attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+        attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+
+    # Merge attention mask with attention bias.
+    if attention_bias is not None or attention_mask is not None or self.config.alibi or past_key_values is not None:
+        if attention_bias is None and self.config.alibi:
+            attention_bias = self.get_causal_attention_bias(
+                past_length + seq_len, x.device
+            ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+        elif attention_bias is None:
+            attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
+        elif attention_bias.dtype in (torch.int8, torch.bool):
+            attention_bias = attention_bias.to(dtype=torch.float)
+            attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+        # Transform to the right shape and data type.
+        mask_len = seq_len
+        if attention_mask is not None:
+            mask_len = attention_mask.shape[-1]
+        elif past_key_values is not None:
+            mask_len = past_key_values[0][0].shape[-2] + seq_len
+        attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+        # Add in the masking bias.
+        if attention_mask is not None:
+            attention_bias = attention_bias + attention_mask
+            # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+            # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect; instead
+            # it can produce NaNs.
+            ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+
+    # decoder layers
+    all_hidden_states = []
+
+    # Apply blocks one-by-one.
+    if self.config.block_group_size == 1:
+        for block_idx, block in enumerate(self.transformer.blocks):
+            if output_hidden_states:
+                # add hidden states
+                all_hidden_states.append(x)
+
+            layer_past = None if past_key_values is None else past_key_values[block_idx]
+            # shape: (batch_size, seq_len, d_model)
+            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.append(cache)
+    else:
+        for group_idx, block_group in enumerate(self.transformer.block_groups):
+            if output_hidden_states:
+                # add hidden states
+                all_hidden_states.append(x)
+
+            layers_past = (
+                None
+                if past_key_values is None
+                else past_key_values[
+                    group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
+                ]
+            )
+            x, cache = block_group(x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.extend(cache)
+
+    if last_logits_only:
+        # shape: (batch_size, 1, d_model)
+        x = x[:, -1, :].unsqueeze(1)
+
+    # Apply final layer norm.
+    # shape: (batch_size, seq_len or 1, d_model)
+    x = self.transformer.ln_f(x)  # type: ignore
+    if output_hidden_states:
+        # add final hidden state post-final-layernorm, following HuggingFace's convention
+        all_hidden_states.append(x)
+
+    # Get logits.
+    # shape: (batch_size, seq_len or 1, vocab_size)
+    if self.config.weight_tying:
+        logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
+    else:
+        logits = self.transformer.ff_out(x)  # type: ignore
+    if self.config.scale_logits:
+        logits.mul_(1 / math.sqrt(self.config.d_model))
+
+    return OlmoOutput(
+        logits=logits,
+        attn_key_values=attn_key_values,
+        hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
+    )  # type: ignore[arg-type]
+
+
+def _olmo_causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
+    att_bias = torch.triu(
+        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
+        diagonal=1,
+    )
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
+    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
+
+
+def _olmo_get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+    if hasattr(self, "causal_bias") and self.causal_bias.shape[-1] >= seq_len:
+        return self.causal_bias.to(device)
+    with torch.autocast(device.type, enabled=False):
+        causal_bias = _olmo_causal_attention_bias(seq_len, device)
+    self.register_buffer("causal_bias", causal_bias)
+    return causal_bias
+
+
+def _olmo_alibi_attention_bias(seq_len: int, config, device: torch.device) -> torch.FloatTensor:
+    # shape: (1, 1, 1, seq_len)
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
+
+    # shape: (1, 1, seq_len, seq_len)
+    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
+    alibi_bias.abs_().mul_(-1)
+
+    # shape: (n_heads,)
+    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
+    m.mul_(config.alibi_bias_max / config.n_heads)
+
+    # shape: (1, n_heads, seq_len, seq_len)
+    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
+
+
+def _olmo_get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+    alibi_bias = getattr(self, "alibi_attention_bias", None)
+    if alibi_bias is not None and alibi_bias.shape[-1] >= seq_len:
+        if alibi_bias.device != device:
+            alibi_bias = alibi_bias.to(device)
+        return alibi_bias
+    with torch.autocast(device.type, enabled=False):
+        alibi_bias = _olmo_alibi_attention_bias(seq_len, self.config, device)
+    self.register_buffer("alibi_attention_bias", alibi_bias)
+    return alibi_bias
+
+
+def _olmo_get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+    if (
+        hasattr(self, "rope_pos_sin")
+        and hasattr(self, "rope_pos_cos")
+        and self.rope_pos_sin.shape[-2] >= seq_len
+        and self.rope_pos_cos.shape[-2] >= seq_len
+    ):
+        return self.rope_pos_sin.to(device)[:, :, :seq_len, :], self.rope_pos_cos.to(device)[:, :, :seq_len, :]
+
+    with torch.autocast(device.type, enabled=False):
+        dim = self.config.d_model // self.config.n_heads
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
+        seq = torch.arange(seq_len, device=device, dtype=torch.float)
+        freqs = torch.einsum("i , j -> i j", seq, inv_freq)
+        positions = torch.cat((freqs, freqs), dim=-1)
+        pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
+
+    self.register_buffer("rope_pos_sin", pos_sin)
+    self.register_buffer("rope_pos_cos", pos_cos)
+    return pos_sin, pos_cos
+
+
+class OLMoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        # The model uses custom cache buffers for storing rotary embeddings and attention biases.
+        # These objects are not traceable, so replace them with plain torch tensors during export.
+        self._model.model._orig_forward = self._model.model.forward
+        self._model.model._orig_get_alibi_attention_bias = self._model.model.get_alibi_attention_bias
+        self._model.model.forward = types.MethodType(_olmo_model_forward, self._model.model)
+        self._model.model.get_alibi_attention_bias = types.MethodType(
+            _olmo_get_alibi_attention_bias, self._model.model
+        )
+        self._model.model.get_alibi_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
+        self._model.model.get_causal_attention_bias = types.MethodType(
+            _olmo_get_causal_attention_bias, self._model.model
+        )
+        self._model.model.get_causal_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
+        for block in self._model.model.transformer.blocks:
+            block.rotary_emb._orig_get_rotary_embedding = block.rotary_emb.get_rotary_embedding
+            block.rotary_emb.get_rotary_embedding = types.MethodType(_olmo_get_rotary_embedding, block.rotary_emb)
+            block.rotary_emb.get_rotary_embedding(self._model.config.max_sequence_length, torch.device("cpu"))
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.model.forward = self._model.model._orig_forward
+        self._model.model.get_alibi_attention_bias = self._model.model._orig_get_alibi_attention_bias
+        for block in self._model.model.transformer.blocks:
+            block.rotary_emb.get_rotary_embedding = block.rotary_emb._orig_get_rotary_embedding
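
Also not part of the commit: a quick standalone sanity check (plain PyTorch, with an assumed toy seq_len) of the causal bias that _olmo_causal_attention_bias precomputes above, which is what replaces the model's non-traceable cache objects during tracing:

import torch

seq_len = 4  # assumed toy value
bias = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.float), diagonal=1)
bias.masked_fill_(bias == 1, torch.finfo(bias.dtype).min)
bias = bias.view(1, 1, seq_len, seq_len)

# Future positions carry the dtype minimum (a large negative bias), while current
# and past positions carry 0, i.e. a standard causal attention mask.
assert bias[0, 0, 0, 1] == torch.finfo(torch.float).min  # query 0 may not attend to key 1
assert bias[0, 0, 3, 0] == 0.0                           # query 3 may attend to key 0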

0 commit comments