@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import logging as log
-import math
 import types
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -516,262 +515,3 @@ def __init__(
         # model has first inference buffers initialization
         if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
-
-
-class OlmoOutput(NamedTuple):
-    logits: torch.FloatTensor
-    """
-    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
-    for the next token *before* normalization via (log) softmax.
-    """
-
-    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
-    """
-    Attention keys and values from each block.
-    """
-
-    hidden_states: Optional[Tuple[torch.Tensor]]
-    """
-    Hidden states from each block.
-    """
-
-
-def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
-    """
-    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
-    is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
-    """
-    if check_neg_inf:
-        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
-    if check_pos_inf:
-        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
-
-
-def _olmo_model_forward(
-    self,
-    input_ids: torch.LongTensor,
-    input_embeddings: Optional[torch.FloatTensor] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    attention_bias: Optional[torch.Tensor] = None,
-    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None,
-    use_cache: bool = False,
-    last_logits_only: bool = False,
-    output_hidden_states: Optional[bool] = None,
-):
-    output_hidden_states = output_hidden_states if output_hidden_states is not None else False
-
-    if past_key_values:
-        assert len(past_key_values) == self.config.n_layers
-
-    batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
-    if past_key_values is None:
-        past_length = 0
-    else:
-        past_length = past_key_values[0][0].size(-2)
-
-    # Get embeddings of input.
-    # shape: (batch_size, seq_len, d_model)
-    x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
-
-    if not (self.config.alibi or self.config.rope):
-        # Get positional embeddings.
-        # shape: (1, seq_len)
-        pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
-        # shape: (1, seq_len, d_model)
-        pos_emb = self.transformer.wpe(pos)  # type: ignore
-        x = pos_emb + x
-
-    # Add input + positional embeddings and apply dropout.
-    # shape: (batch_size, seq_len, d_model)
-    x = self.transformer.emb_drop(x)  # type: ignore
-
-    # Transform the attention mask into what the blocks expect.
-    if attention_mask is not None:
-        # shape: (batch_size, 1, 1, seq_len)
-        attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
-        attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
-
-    # Merge attention mask with attention bias.
-    if attention_bias is not None or attention_mask is not None or self.config.alibi or past_key_values is not None:
-        if attention_bias is None and self.config.alibi:
-            attention_bias = self.get_causal_attention_bias(
-                past_length + seq_len, x.device
-            ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
-        elif attention_bias is None:
-            attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
-        elif attention_bias.dtype in (torch.int8, torch.bool):
-            attention_bias = attention_bias.to(dtype=torch.float)
-            attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
-
-        # Transform to the right shape and data type.
-        mask_len = seq_len
-        if attention_mask is not None:
-            mask_len = attention_mask.shape[-1]
-        elif past_key_values is not None:
-            mask_len = past_key_values[0][0].shape[-2] + seq_len
-        attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
-
-        # Add in the masking bias.
-        if attention_mask is not None:
-            attention_bias = attention_bias + attention_mask
-            # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
-            # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
-            # it can produce NaNs.
-            ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
-
-    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
-
-    # decoder layers
-    all_hidden_states = []
-
-    # Apply blocks one-by-one.
-    if self.config.block_group_size == 1:
-        for block_idx, block in enumerate(self.transformer.blocks):
-            if output_hidden_states:
-                # add hidden states
-                all_hidden_states.append(x)
-
-            layer_past = None if past_key_values is None else past_key_values[block_idx]
-            # shape: (batch_size, seq_len, d_model)
-            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
-            if attn_key_values is not None:
-                assert cache is not None
-                attn_key_values.append(cache)
-    else:
-        for group_idx, block_group in enumerate(self.transformer.block_groups):
-            if output_hidden_states:
-                # add hidden states
-                all_hidden_states.append(x)
-
-            layers_past = (
-                None
-                if past_key_values is None
-                else past_key_values[
-                    group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
-                ]
-            )
-            x, cache = block_group(x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache)
-            if attn_key_values is not None:
-                assert cache is not None
-                attn_key_values.extend(cache)
-
-    if last_logits_only:
-        # shape: (batch_size, 1, d_model)
-        x = x[:, -1, :].unsqueeze(1)
-
-    # Apply final layer norm.
-    # shape: (batch_size, seq_len or 1, d_model)
-    x = self.transformer.ln_f(x)  # type: ignore
-    if output_hidden_states:
-        # add final hidden state post-final-layernorm, following HuggingFace's convention
-        all_hidden_states.append(x)
-
-    # Get logits.
-    # shape: (batch_size, seq_len or 1, vocab_size)
-    if self.config.weight_tying:
-        logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
-    else:
-        logits = self.transformer.ff_out(x)  # type: ignore
-    if self.config.scale_logits:
-        logits.mul_(1 / math.sqrt(self.config.d_model))
-
-    return OlmoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)  # type: ignore[arg-type]
-
-
-def _olmo_causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
-    att_bias = torch.triu(
-        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
-        diagonal=1,
-    )
-    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
-    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
-
-
-def _olmo_get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
-    if hasattr(self, "causal_bias") and self.causal_bias.shape[-1] >= seq_len:
-        return self.causal_bias.to(device)
-    with torch.autocast(device.type, enabled=False):
-        causal_bias = _olmo_causal_attention_bias(seq_len, device)
-    self.register_buffer("causal_bias", causal_bias)
-    return causal_bias
-
-
-def _olmo_alibi_attention_bias(seq_len: int, config, device: torch.device) -> torch.FloatTensor:
-    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
-
-    # shape: (1, 1, seq_len, seq_len)
-    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
-    alibi_bias.abs_().mul_(-1)
-
-    # shape: (n_heads,)
-    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
-    m.mul_(config.alibi_bias_max / config.n_heads)
-
-    # shape: (1, n_heads, seq_len, seq_len)
-    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
-
-
-def _olmo_get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
-    alibi_bias = getattr(self, "alibi_attention_bias", None)
-    if alibi_bias is not None and alibi_bias.shape[-1] >= seq_len:
-        if alibi_bias.device != device:
-            alibi_bias = alibi_bias.to(device)
-        return alibi_bias
-    with torch.autocast(device.type, enabled=False):
-        alibi_bias = _olmo_alibi_attention_bias(seq_len, self.config, device)
-    self.register_buffer("alibi_attention_bias", alibi_bias)
-    return alibi_bias
-
-
-def _olmo_get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-    if (
-        hasattr(self, "rope_pos_sin")
-        and hasattr(self, "rope_pos_cos")
-        and self.rope_pos_sin.shape[-2] >= seq_len
-        and self.rope_pos_cos.shape[-2] >= seq_len
-    ):
-        return self.rope_pos_sin.to(device)[:, :, :seq_len, :], self.rope_pos_cos.to(device)[:, :, :seq_len, :]
-
-    with torch.autocast(device.type, enabled=False):
-        dim = self.config.d_model // self.config.n_heads
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
-        seq = torch.arange(seq_len, device=device, dtype=torch.float)
-        freqs = torch.einsum("i , j -> i j", seq, inv_freq)
-        positions = torch.cat((freqs, freqs), dim=-1)
-        pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
-
-    self.register_buffer("rope_pos_sin", pos_sin)
-    self.register_buffer("rope_pos_cos", pos_cos)
-    return pos_sin, pos_cos
-
-
-class OLMoModelPatcher(DecoderModelPatcher):
-    def __enter__(self):
-        super().__enter__()
-        # model uses custom cache buffers for storing rotary_embeddings and attention biases.
-        # these objects are non-traceable, so replace them with standard torch tensors during export
-        self._model.model._orig_forward = self._model.model.forward
-        self._model.model._orig_get_alibi_attention_bias = self._model.model.get_alibi_attention_bias
-        self._model.model.forward = types.MethodType(_olmo_model_forward, self._model.model)
-        self._model.model.get_alibi_attention_bias = types.MethodType(
-            _olmo_get_alibi_attention_bias, self._model.model
-        )
-        self._model.model.get_alibi_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
-        self._model.model.get_causal_attention_bias = types.MethodType(
-            _olmo_get_causal_attention_bias, self._model.model
-        )
-        self._model.model.get_causal_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
-        for block in self._model.model.transformer.blocks:
-            block.rotary_emb._orig_get_rotary_embedding = block.rotary_emb.get_rotary_embedding
-            block.rotary_emb.get_rotary_embedding = types.MethodType(_olmo_get_rotary_embedding, block.rotary_emb)
-            block.rotary_emb.get_rotary_embedding(self._model.config.max_sequence_length, torch.device("cpu"))
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        self._model.model.forward = self._model.model._orig_forward
-        self._model.model.get_alibi_attention_bias = self._model.model._orig_get_alibi_attention_bias
-        for block in self._model.model.transformer.blocks:
-            block.rotary_emb.get_rotary_embedding = block.rotary_emb._orig_get_rotary_embedding
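
For reference, the attention-bias math implemented by the removed helpers can be exercised on its own. Below is a minimal standalone sketch of the causal and ALiBi bias construction mirroring `_olmo_causal_attention_bias` and `_olmo_alibi_attention_bias`; the `n_heads` and `alibi_bias_max` values are illustrative placeholders, not taken from a real OLMo config.

```python
# Standalone sketch of the causal + ALiBi attention-bias construction removed above.
# The config values used at the bottom (n_heads=2, alibi_bias_max=8.0) are made up for illustration.
import torch


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.Tensor:
    # dtype.min above the diagonal (future positions), 0 elsewhere; shape (1, 1, seq_len, seq_len).
    bias = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.float), diagonal=1)
    bias.masked_fill_(bias == 1, torch.finfo(bias.dtype).min)
    return bias.view(1, 1, seq_len, seq_len)


def alibi_attention_bias(seq_len: int, n_heads: int, alibi_bias_max: float, device: torch.device) -> torch.Tensor:
    # Negative relative distances -|i - j|, scaled per head by a power of two; shape (1, n_heads, seq_len, seq_len).
    pos = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device)
    bias = (pos.view(1, 1, 1, seq_len) - pos.view(1, 1, seq_len, 1)).abs().mul(-1)
    m = torch.arange(1, n_heads + 1, dtype=torch.float, device=device) * (alibi_bias_max / n_heads)
    return bias * (1.0 / (2 ** m.view(1, n_heads, 1, 1)))


if __name__ == "__main__":
    device = torch.device("cpu")
    full_bias = causal_attention_bias(4, device) + alibi_attention_bias(4, n_heads=2, alibi_bias_max=8.0, device=device)
    print(full_bias.shape)  # torch.Size([1, 2, 4, 4])
```

Masked positions use `torch.finfo(dtype).min` rather than `float("-inf")`, which is why the patched forward also calls `ensure_finite_`: adding the mask and the bias can still underflow to `-inf`, and `F.scaled_dot_product_attention` can turn `-inf` entries into NaNs.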
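`OLMoModelPatcher` itself works by binding replacement functions onto live module instances with `types.MethodType` in `__enter__` and restoring the saved originals in `__exit__`. A minimal sketch of that pattern, with a made-up `ToyModel`/`ToyPatcher` standing in for the real OLMo and `DecoderModelPatcher` classes:

```python
# Sketch of the method-swapping pattern used by the patcher above.
# ToyModel and ToyPatcher are hypothetical stand-ins, not the real OLMo / DecoderModelPatcher classes.
import types


class ToyModel:
    def forward(self, x):
        return x + 1


def _patched_forward(self, x):
    # Bound as an instance method, so `self` is the ToyModel being patched.
    return x * 10


class ToyPatcher:
    def __init__(self, model: ToyModel):
        self._model = model

    def __enter__(self):
        # Save the original bound method, then shadow it on the instance.
        self._model._orig_forward = self._model.forward
        self._model.forward = types.MethodType(_patched_forward, self._model)
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        # Put the original bound method back.
        self._model.forward = self._model._orig_forward


model = ToyModel()
with ToyPatcher(model) as patched:
    print(patched.forward(2))  # 20 while the patch is active
print(model.forward(2))  # 3 once the original forward is restored
```

Patching instances rather than classes keeps the change scoped to the exported model; the real patcher additionally pre-populates the `causal_bias`, `alibi_attention_bias`, and `rope_pos_sin`/`rope_pos_cos` buffers inside `__enter__` so they already exist as plain tensors before export tracing.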