
Commit b11becb

support more models in export
1 parent 7c1d38b commit b11becb

File tree

3 files changed: +334 -3 lines changed


optimum/exporters/openvino/convert.py

+1 -1
@@ -345,7 +345,7 @@ def ts_patched_forward(*args, **kwargs):
         input_dict = dict(zip(keys, tuple_input))
         kwargs[input_name] = input_dict
     outputs = patched_forward(*args, **kwargs)
-    return tuple(outputs.values())
+    return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])

 patcher.patched_forward = ts_patched_forward
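The one-line change above normalizes the traced forward's outputs: any value that is a plain Python list (for example, a key/value cache collected as a list of tuples, as in the OLMo patch below) is converted to a tuple before being returned, keeping the output structure in the nested-tuple form that tracing-based conversion handles most predictably. A minimal sketch of the effect, with illustrative tensors that are not taken from the repository:

import torch

# Stand-in for the dict a patched forward might return; shapes are illustrative.
outputs = {
    "logits": torch.zeros(1, 3, 8),
    "past_key_values": [(torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4))],
}

# Same expression as the new return line: lists become tuples, tensors pass through.
flattened = tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])

assert isinstance(flattened[1], tuple)  # the list-valued cache is now a tuple of pairs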

optimum/exporters/openvino/model_configs.py

+71
@@ -35,6 +35,7 @@
     ChatGLMModelPatcher,
     GemmaModelPatcher,
     MixtralModelPatcher,
+    OLMoModelPatcher,
     QwenModelPatcher,
 )

@@ -400,3 +401,73 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager("olmo", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class OLMoOpenVINOConfig(TextDecoderOnnxConfig):
+    # OLMo does not require position_ids input.
+    DEFAULT_ONNX_OPSET = 13
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return OLMoModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("internlm2", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.num_key_value_heads_per_layer = normalized_config.num_key_value_heads_per_layer
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        past_key_values = []
+
+        for layer_id in range(self.num_layers):
+            shape = (
+                self.batch_size,
+                self.num_key_value_heads_per_layer[layer_id],
+                self.sequence_length,
+                self.hidden_size // self.num_attention_heads,
+            )
+            past_key_values.append(
+                (
+                    self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                    self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                )
+            )
+        return past_key_values
+
+
+@register_in_tasks_manager("deci", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class DeciOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DeciDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DeciDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
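With these configurations registered, the newly supported architectures (olmo, internlm2, deci) go through the usual optimum-intel export entry points. A hedged usage sketch that is not part of the commit; the checkpoint name is illustrative, and OLMo checkpoints that ship custom modeling code need trust_remote_code=True (and possibly the model's own pip package):

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "allenai/OLMo-1B"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# export=True converts the PyTorch model to OpenVINO IR using the registered export config
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("OpenVINO export check:", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))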

optimum/exporters/openvino/model_patcher.py

+262 -2
@@ -13,8 +13,9 @@
 # limitations under the License.

 import logging as log
+import math
 import types
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -513,5 +514,264 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
         # model has first inference buffers initialization
-        if self._model.lm_head.first_flag:
+        if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
+
+
+class OlmoOutput(NamedTuple):
+    logits: torch.FloatTensor
+    """
+    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
+    for the next token *before* normalization via (log) softmax.
+    """
+
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+    """
+    Attention keys and values from each block.
+    """
+
+    hidden_states: Optional[Tuple[torch.Tensor]]
+    """
+    Hidden states from each block.
+    """
+
+
+def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
+    """
+    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
+    is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
+    """
+    if check_neg_inf:
+        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
+    if check_pos_inf:
+        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
+
+
+def _olmo_model_forward(
+    self,
+    input_ids: torch.LongTensor,
+    input_embeddings: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    attention_bias: Optional[torch.Tensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None,
+    use_cache: bool = False,
+    last_logits_only: bool = False,
+    output_hidden_states: Optional[bool] = None,
+):
+    output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+    if past_key_values:
+        assert len(past_key_values) == self.config.n_layers
+
+    batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+    if past_key_values is None:
+        past_length = 0
+    else:
+        past_length = past_key_values[0][0].size(-2)
+
+    # Get embeddings of input.
+    # shape: (batch_size, seq_len, d_model)
+    x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
+
+    if not (self.config.alibi or self.config.rope):
+        # Get positional embeddings.
+        # shape: (1, seq_len)
+        pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+        # shape: (1, seq_len, d_model)
+        pos_emb = self.transformer.wpe(pos)  # type: ignore
+        x = pos_emb + x
+
+    # Add input + positional embeddings and apply dropout.
+    # shape: (batch_size, seq_len, d_model)
+    x = self.transformer.emb_drop(x)  # type: ignore
+
+    # Transform the attention mask into what the blocks expect.
+    if attention_mask is not None:
+        # shape: (batch_size, 1, 1, seq_len)
+        attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+        attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+
+    # Merge attention mask with attention bias.
+    if attention_bias is not None or attention_mask is not None or self.config.alibi or past_key_values is not None:
+        if attention_bias is None and self.config.alibi:
+            attention_bias = self.get_causal_attention_bias(
+                past_length + seq_len, x.device
+            ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+        elif attention_bias is None:
+            attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
+        elif attention_bias.dtype in (torch.int8, torch.bool):
+            attention_bias = attention_bias.to(dtype=torch.float)
+            attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+        # Transform to the right shape and data type.
+        mask_len = seq_len
+        if attention_mask is not None:
+            mask_len = attention_mask.shape[-1]
+        elif past_key_values is not None:
+            mask_len = past_key_values[0][0].shape[-2] + seq_len
+        attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+        # Add in the masking bias.
+        if attention_mask is not None:
+            attention_bias = attention_bias + attention_mask
+            # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+            # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
+            # it can produce NaNs.
+            ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+
+    # decoder layers
+    all_hidden_states = []
+
+    # Apply blocks one-by-one.
+    if self.config.block_group_size == 1:
+        for block_idx, block in enumerate(self.transformer.blocks):
+            if output_hidden_states:
+                # add hidden states
+                all_hidden_states.append(x)
+
+            layer_past = None if past_key_values is None else past_key_values[block_idx]
+            # shape: (batch_size, seq_len, d_model)
+            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.append(cache)
+    else:
+        for group_idx, block_group in enumerate(self.transformer.block_groups):
+            if output_hidden_states:
+                # add hidden states
+                all_hidden_states.append(x)
+
+            layers_past = (
+                None
+                if past_key_values is None
+                else past_key_values[
+                    group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
+                ]
+            )
+            x, cache = block_group(x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.extend(cache)
+
+    if last_logits_only:
+        # shape: (batch_size, 1, d_model)
+        x = x[:, -1, :].unsqueeze(1)
+
+    # Apply final layer norm.
+    # shape: (batch_size, seq_len or 1, d_model)
+    x = self.transformer.ln_f(x)  # type: ignore
+    if output_hidden_states:
+        # add final hidden state post-final-layernorm, following HuggingFace's convention
+        all_hidden_states.append(x)
+
+    # Get logits.
+    # shape: (batch_size, seq_len or 1, vocab_size)
+    if self.config.weight_tying:
+        logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
+    else:
+        logits = self.transformer.ff_out(x)  # type: ignore
+    if self.config.scale_logits:
+        logits.mul_(1 / math.sqrt(self.config.d_model))
+
+    return OlmoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)  # type: ignore[arg-type]
+
+
+def _olmo_causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
+    att_bias = torch.triu(
+        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
+        diagonal=1,
+    )
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
+    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
+
+
+def _olmo_get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+    if hasattr(self, "causal_bias") and self.causal_bias.shape[-1] >= seq_len:
+        return self.causal_bias.to(device)
+    with torch.autocast(device.type, enabled=False):
+        causal_bias = _olmo_causal_attention_bias(seq_len, device)
+    self.register_buffer("causal_bias", causal_bias)
+    return causal_bias
+
+
+def _olmo_alibi_attention_bias(seq_len: int, config, device: torch.device) -> torch.FloatTensor:
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
+
+    # shape: (1, 1, seq_len, seq_len)
+    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
+    alibi_bias.abs_().mul_(-1)
+
+    # shape: (n_heads,)
+    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
+    m.mul_(config.alibi_bias_max / config.n_heads)
+
+    # shape: (1, n_heads, seq_len, seq_len)
+    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
+
+
+def _olmo_get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+    alibi_bias = getattr(self, "alibi_attention_bias", None)
+    if alibi_bias is not None and alibi_bias.shape[-1] >= seq_len:
+        if alibi_bias.device != device:
+            alibi_bias = alibi_bias.to(device)
+        return alibi_bias
+    with torch.autocast(device.type, enabled=False):
+        alibi_bias = _olmo_alibi_attention_bias(seq_len, self.config, device)
+    self.register_buffer("alibi_attention_bias", alibi_bias)
+    return alibi_bias
+
+
+def _olmo_get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+    if (
+        hasattr(self, "rope_pos_sin")
+        and hasattr(self, "rope_pos_cos")
+        and self.rope_pos_sin.shape[-2] >= seq_len
+        and self.rope_pos_cos.shape[-2] >= seq_len
+    ):
+        return self.rope_pos_sin.to(device)[:, :, :seq_len, :], self.rope_pos_cos.to(device)[:, :, :seq_len, :]
+
+    with torch.autocast(device.type, enabled=False):
+        dim = self.config.d_model // self.config.n_heads
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
+        seq = torch.arange(seq_len, device=device, dtype=torch.float)
+        freqs = torch.einsum("i , j -> i j", seq, inv_freq)
+        positions = torch.cat((freqs, freqs), dim=-1)
+        pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]

+    self.register_buffer("rope_pos_sin", pos_sin)
+    self.register_buffer("rope_pos_cos", pos_cos)
+    return pos_sin, pos_cos
+
+
+class OLMoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        # model uses custom cache buffers for storing rotary embeddings and attention biases.
+        # these objects are non-traceable, replace them with standard torch tensors during export
+        self._model.model._orig_forward = self._model.model.forward
+        self._model.model._orig_get_alibi_attention_bias = self._model.model.get_alibi_attention_bias
+        self._model.model.forward = types.MethodType(_olmo_model_forward, self._model.model)
+        self._model.model.get_alibi_attention_bias = types.MethodType(
+            _olmo_get_alibi_attention_bias, self._model.model
+        )
+        self._model.model.get_alibi_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
+        self._model.model.get_causal_attention_bias = types.MethodType(
+            _olmo_get_causal_attention_bias, self._model.model
+        )
+        self._model.model.get_causal_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
+        for block in self._model.model.transformer.blocks:
+            block.rotary_emb._orig_get_rotary_embedding = block.rotary_emb.get_rotary_embedding
+            block.rotary_emb.get_rotary_embedding = types.MethodType(_olmo_get_rotary_embedding, block.rotary_emb)
+            block.rotary_emb.get_rotary_embedding(self._model.config.max_sequence_length, torch.device("cpu"))
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.model.forward = self._model.model._orig_forward
+        self._model.model.get_alibi_attention_bias = self._model.model._orig_get_alibi_attention_bias
+        for block in self._model.model.transformer.blocks:
+            block.rotary_emb.get_rotary_embedding = block.rotary_emb._orig_get_rotary_embedding
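Inside `_olmo_model_forward`, masking works through additive biases rather than boolean masks: the causal bias fills positions above the diagonal with the dtype minimum, the padding mask contributes another dtype-minimum term, and `ensure_finite_` clamps the resulting `-inf` values that `F.scaled_dot_product_attention` could otherwise turn into NaNs. A small self-contained illustration of those steps (values and shapes are illustrative, not from the repository):

import torch

seq_len = 4
# Causal bias as in _olmo_causal_attention_bias: dtype-min above the diagonal, 0 elsewhere.
bias = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.float), diagonal=1)
bias.masked_fill_(bias == 1, torch.finfo(bias.dtype).min)

# Additive padding mask as built in _olmo_model_forward; the last position is padding.
padding_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
additive_mask = (1.0 - padding_mask) * torch.finfo(torch.float).min

merged = bias + additive_mask  # min + min overflows to -inf where both masks apply
merged.masked_fill_(merged == float("-inf"), torch.finfo(merged.dtype).min)  # what ensure_finite_ does
print(torch.softmax(merged, dim=-1))  # future and padded positions receive ~0 attention weight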

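More broadly, `OLMoModelPatcher` follows the same pattern as the other patchers in this file: bind trace-friendly replacements onto the model with `types.MethodType` on `__enter__`, pre-materialize the cached tensors as buffers, and restore the original bound methods on `__exit__`. A generic, self-contained sketch of that pattern with made-up class and method names (it is not the OLMo modeling code):

import types

import torch


class TinyBlock(torch.nn.Module):
    def get_bias(self, seq_len):
        # stand-in for a Python-side cache lookup that a tracer cannot follow
        return torch.zeros(1, 1, seq_len, seq_len)


def _patched_get_bias(self, seq_len):
    # slice a precomputed tensor registered as a buffer instead of hitting the cache
    return self.bias_buffer[:, :, :seq_len, :seq_len]


class TinyPatcher:
    def __init__(self, block, max_len=8):
        self.block = block
        self.max_len = max_len

    def __enter__(self):
        self.block._orig_get_bias = self.block.get_bias
        self.block.register_buffer("bias_buffer", torch.zeros(1, 1, self.max_len, self.max_len))
        self.block.get_bias = types.MethodType(_patched_get_bias, self.block)
        return self.block

    def __exit__(self, exc_type, exc_value, traceback):
        self.block.get_bias = self.block._orig_get_bias


block = TinyBlock()
with TinyPatcher(block):
    assert block.get_bias(4).shape == (1, 1, 4, 4)  # patched path, buffer-backed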