@@ -13,8 +13,9 @@
 # limitations under the License.

 import logging as log
+import math
 import types
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -509,5 +510,264 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
         # model has first inference buffers initialization
-        if self._model.lm_head.first_flag:
+        if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
+
+
+class OlmoOutput(NamedTuple):
+    logits: torch.FloatTensor
+    """
+    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
+    for the next token *before* normalization via (log) softmax.
+    """
+
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+    """
+    Attention keys and values from each block.
+    """
+
+    hidden_states: Optional[Tuple[torch.Tensor]]
+    """
+    Hidden states from each block.
+    """
+
+
+def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
+    """
+    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
+    is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
+    """
+    if check_neg_inf:
+        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
+    if check_pos_inf:
+        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
+
+
+def _olmo_model_forward(
+    self,
+    input_ids: torch.LongTensor,
+    input_embeddings: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    attention_bias: Optional[torch.Tensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]] = None,
+    use_cache: bool = False,
+    last_logits_only: bool = False,
+    output_hidden_states: Optional[bool] = None,
+):
+    output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+    if past_key_values:
+        assert len(past_key_values) == self.config.n_layers
+
+    batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+    if past_key_values is None:
+        past_length = 0
+    else:
+        past_length = past_key_values[0][0].size(-2)
+
+    # Get embeddings of input.
+    # shape: (batch_size, seq_len, d_model)
+    x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
+
+    if not (self.config.alibi or self.config.rope):
+        # Get positional embeddings.
+        # shape: (1, seq_len)
+        pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+        # shape: (1, seq_len, d_model)
+        pos_emb = self.transformer.wpe(pos)  # type: ignore
+        x = pos_emb + x
+
+    # Add input + positional embeddings and apply dropout.
+    # shape: (batch_size, seq_len, d_model)
+    x = self.transformer.emb_drop(x)  # type: ignore
+
+    # Transform the attention mask into what the blocks expect.
+    if attention_mask is not None:
+        # shape: (batch_size, 1, 1, seq_len)
+        attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+        attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+
+    # Merge attention mask with attention bias.
+    if attention_bias is not None or attention_mask is not None or self.config.alibi or past_key_values is not None:
+        if attention_bias is None and self.config.alibi:
+            attention_bias = self.get_causal_attention_bias(
+                past_length + seq_len, x.device
+            ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+        elif attention_bias is None:
+            attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
+        elif attention_bias.dtype in (torch.int8, torch.bool):
+            attention_bias = attention_bias.to(dtype=torch.float)
+            attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+        # Transform to the right shape and data type.
+        mask_len = seq_len
+        if attention_mask is not None:
+            mask_len = attention_mask.shape[-1]
+        elif past_key_values is not None:
+            mask_len = past_key_values[0][0].shape[-2] + seq_len
+        attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+        # Add in the masking bias.
+        if attention_mask is not None:
+            attention_bias = attention_bias + attention_mask
+            # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+            # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
+            # it can produce NaNs.
+            ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+
+    # decoder layers
+    all_hidden_states = []
+
+    # Apply blocks one-by-one.
+    if self.config.block_group_size == 1:
+        for block_idx, block in enumerate(self.transformer.blocks):
+            if output_hidden_states:
+                # add hidden states
+                all_hidden_states.append(x)
+
+            layer_past = None if past_key_values is None else past_key_values[block_idx]
+            # shape: (batch_size, seq_len, d_model)
+            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.append(cache)
+    else:
+        for group_idx, block_group in enumerate(self.transformer.block_groups):
+            if output_hidden_states:
+                # add hidden states
+                all_hidden_states.append(x)
+
+            layers_past = (
+                None
+                if past_key_values is None
+                else past_key_values[
+                    group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
+                ]
+            )
+            x, cache = block_group(x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.extend(cache)
+
+    if last_logits_only:
+        # shape: (batch_size, 1, d_model)
+        x = x[:, -1, :].unsqueeze(1)
+
+    # Apply final layer norm.
+    # shape: (batch_size, seq_len or 1, d_model)
+    x = self.transformer.ln_f(x)  # type: ignore
+    if output_hidden_states:
+        # add final hidden state post-final-layernorm, following HuggingFace's convention
+        all_hidden_states.append(x)
+
+    # Get logits.
+    # shape: (batch_size, seq_len or 1, vocab_size)
+    if self.config.weight_tying:
+        logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
+    else:
+        logits = self.transformer.ff_out(x)  # type: ignore
+    if self.config.scale_logits:
+        logits.mul_(1 / math.sqrt(self.config.d_model))
+
+    return OlmoOutput(
+        logits=logits,
+        attn_key_values=attn_key_values,
+        hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
+    )  # type: ignore[arg-type]
+
+
+def _olmo_causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
+    att_bias = torch.triu(
+        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
+        diagonal=1,
+    )
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
+    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
+
+
+def _olmo_get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+    if hasattr(self, "causal_bias") and self.causal_bias.shape[-1] >= seq_len:
+        return self.causal_bias.to(device)
+    with torch.autocast(device.type, enabled=False):
+        causal_bias = _olmo_causal_attention_bias(seq_len, device)
+    self.register_buffer("causal_bias", causal_bias)
+    return causal_bias
+
+
+def _olmo_alibi_attention_bias(seq_len: int, config, device: torch.device) -> torch.FloatTensor:
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
+
| 697 | + """ |
| 698 | + A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities |
| 699 | + for the next token *before* normalization via (log) softmax. |
| 700 | + """ |
+    # shape: (1, 1, seq_len, seq_len)
+    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
+    alibi_bias.abs_().mul_(-1)
+
+    # shape: (n_heads,)
+    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
+    m.mul_(config.alibi_bias_max / config.n_heads)
+
+    # shape: (1, n_heads, seq_len, seq_len)
+    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
+
+
+def _olmo_get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+    alibi_bias = getattr(self, "alibi_attention_bias", None)
+    if alibi_bias is not None and alibi_bias.shape[-1] >= seq_len:
+        if alibi_bias.device != device:
+            alibi_bias = alibi_bias.to(device)
+        return alibi_bias
+    with torch.autocast(device.type, enabled=False):
+        alibi_bias = _olmo_alibi_attention_bias(seq_len, self.config, device)
+    self.register_buffer("alibi_attention_bias", alibi_bias)
+    return alibi_bias
+
+
+def _olmo_get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+    if (
+        hasattr(self, "rope_pos_sin")
+        and hasattr(self, "rope_pos_cos")
+        and self.rope_pos_sin.shape[-2] >= seq_len
+        and self.rope_pos_cos.shape[-2] >= seq_len
+    ):
+        return self.rope_pos_sin.to(device)[:, :, :seq_len, :], self.rope_pos_cos.to(device)[:, :, :seq_len, :]
+
+    with torch.autocast(device.type, enabled=False):
+        dim = self.config.d_model // self.config.n_heads
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
+        seq = torch.arange(seq_len, device=device, dtype=torch.float)
+        freqs = torch.einsum("i , j -> i j", seq, inv_freq)
+        positions = torch.cat((freqs, freqs), dim=-1)
+        pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
+
+    self.register_buffer("rope_pos_sin", pos_sin)
+    self.register_buffer("rope_pos_cos", pos_cos)
+    return pos_sin, pos_cos
+
+
+class OLMoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        # The model caches rotary embeddings and attention biases in custom buffer objects.
+        # These objects are not traceable, so replace them with standard torch tensors during export.
+        self._model.model._orig_forward = self._model.model.forward
+        self._model.model._orig_get_alibi_attention_bias = self._model.model.get_alibi_attention_bias
+        self._model.model.forward = types.MethodType(_olmo_model_forward, self._model.model)
+        self._model.model.get_alibi_attention_bias = types.MethodType(
+            _olmo_get_alibi_attention_bias, self._model.model
+        )
+        self._model.model.get_alibi_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
+        self._model.model.get_causal_attention_bias = types.MethodType(
+            _olmo_get_causal_attention_bias, self._model.model
+        )
+        self._model.model.get_causal_attention_bias(self._model.config.max_sequence_length, torch.device("cpu"))
+        for block in self._model.model.transformer.blocks:
+            block.rotary_emb._orig_get_rotary_embedding = block.rotary_emb.get_rotary_embedding
+            block.rotary_emb.get_rotary_embedding = types.MethodType(_olmo_get_rotary_embedding, block.rotary_emb)
+            block.rotary_emb.get_rotary_embedding(self._model.config.max_sequence_length, torch.device("cpu"))
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.model.forward = self._model.model._orig_forward
+        self._model.model.get_alibi_attention_bias = self._model.model._orig_get_alibi_attention_bias
+        for block in self._model.model.transformer.blocks:
+            block.rotary_emb.get_rotary_embedding = block.rotary_emb._orig_get_rotary_embedding
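
For reference, the patcher added above follows the context-manager protocol of the other patchers in this module: the replacement methods are bound on `__enter__` and the originals are restored on `__exit__`. The sketch below is illustrative only and not part of the patch; it assumes the `(config, model, model_kwargs)` constructor shared by the other patchers, and the `export_config`, `pt_model`, and `example_inputs` names as well as the `torch.jit.trace` call are placeholders for whatever export path actually applies the patcher.

# Illustrative sketch only: placeholder names, not part of the patch.
example_inputs = (torch.ones((1, 10), dtype=torch.int64),)  # input_ids only

patcher = OLMoModelPatcher(export_config, pt_model, model_kwargs={})
with patcher:
    # Inside the context, forward() and the bias/rotary helpers are replaced with the
    # tensor-based versions defined in the patch, so tracing sees plain torch tensors
    # instead of the model's custom cache objects.
    traced = torch.jit.trace(pt_model, example_inputs, strict=False)
# On exit, the original methods are restored.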