From 6cf2d1c9e889cb155ee27a6059ce2852f11cc539 Mon Sep 17 00:00:00 2001
From: Nathan Frey
Date: Wed, 5 Mar 2025 21:36:59 -0500
Subject: [PATCH] configure lr scheduler (#41)

* update configs

* fix HF datasets tests, add mocking (#42)

* add mocking

* rename

* update configs

* get_scheduler

* refactor

---------

Co-authored-by: Karina Zadorozhny
Co-authored-by: freyn6
---
 .../hydra_config/lr_scheduler/default.yaml    |  7 ++
 src/lobster/hydra_config/model/clm.yaml       | 11 +++
 src/lobster/hydra_config/model/mlm.yaml       | 11 +++
 .../hydra_config/model/modern_bert.yaml       | 16 +++-
 src/lobster/hydra_config/train.yaml           |  1 +
 src/lobster/model/_cbmlm.py                   |  8 +-
 src/lobster/model/_clm.py                     | 48 +++++++++---
 .../model/_conditioanalclassifiermlm.py       |  8 +-
 src/lobster/model/_conditioanalmlm.py         |  8 +-
 src/lobster/model/_dyab.py                    |  3 +
 src/lobster/model/_lobster_fold.py            | 11 ++-
 src/lobster/model/_mgm.py                     | 11 ++-
 src/lobster/model/_mlm.py                     | 45 +++++++++--
 src/lobster/model/_seq2seq.py                 | 11 ++-
 src/lobster/model/hyena/_hyena.py             | 12 ++-
 src/lobster/model/modern_bert/_modern_bert.py | 78 +++++++++++++++++--
 .../model/modern_bert/test__modern_bert.py    | 70 ++++++++++++++---
 17 files changed, 276 insertions(+), 83 deletions(-)
 create mode 100644 src/lobster/hydra_config/lr_scheduler/default.yaml

diff --git a/src/lobster/hydra_config/lr_scheduler/default.yaml b/src/lobster/hydra_config/lr_scheduler/default.yaml
new file mode 100644
index 0000000..cf7ee70
--- /dev/null
+++ b/src/lobster/hydra_config/lr_scheduler/default.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - _self_
+
+scheduler:
+  _target_: "transformers.optimization.get_linear_schedule_with_warmup"
+  num_warmup_steps: ${model.num_warmup_steps}
+  num_training_steps: ${model.num_training_steps}
\ No newline at end of file
diff --git a/src/lobster/hydra_config/model/clm.yaml b/src/lobster/hydra_config/model/clm.yaml
index 09c5a64..f50e6e1 100644
--- a/src/lobster/hydra_config/model/clm.yaml
+++ b/src/lobster/hydra_config/model/clm.yaml
@@ -9,3 +9,14 @@ max_length: 512
 num_training_steps: ${trainer.max_steps}
 num_key_value_heads: null
 attention_bias: false
+
+# Model-specific configuration parameters
+model_kwargs:
+  embedding_layer: linear_pos
+  hidden_act: gelu
+
+# Scheduler-specific configuration parameters
+scheduler_kwargs:
+  # Any specific scheduler parameters would go here
+  # For example:
+  # min_lr: 1e-7 # For cosine_with_min_lr scheduler
diff --git a/src/lobster/hydra_config/model/mlm.yaml b/src/lobster/hydra_config/model/mlm.yaml
index ec03bbb..0e46b2b 100644
--- a/src/lobster/hydra_config/model/mlm.yaml
+++ b/src/lobster/hydra_config/model/mlm.yaml
@@ -8,3 +8,14 @@ num_warmup_steps: 10_000
 tokenizer_dir: pmlm_tokenizer
 max_length: 512
 num_training_steps: ${trainer.max_steps}
+
+# Model-specific configuration parameters
+model_kwargs:
+  embedding_layer: linear_pos
+  hidden_act: gelu
+
+# Scheduler-specific configuration parameters
+scheduler_kwargs:
+  # Any specific scheduler parameters would go here
+  # For example:
+  # min_lr: 1e-7 # For cosine_with_min_lr scheduler
diff --git a/src/lobster/hydra_config/model/modern_bert.yaml b/src/lobster/hydra_config/model/modern_bert.yaml
index d8ad6ae..7e18af2 100644
--- a/src/lobster/hydra_config/model/modern_bert.yaml
+++ b/src/lobster/hydra_config/model/modern_bert.yaml
@@ -1,12 +1,24 @@
 _target_: lobster.model.modern_bert.FlexBERT
+# Base model parameters
 lr: 1e-3
 num_training_steps: ${trainer.max_steps}
 model_name: UME_mini
 num_warmup_steps: 10_000
 max_length: 512
 tokenizer_dir: pmlm_tokenizer
-embedding_layer: linear_pos
-hidden_act: gelu
 mask_percentage: 0.25
+scheduler: "constant_with_warmup"
 ckpt_path: null
+
+# Model-specific configuration parameters
+model_kwargs:
+  embedding_layer: linear_pos
+  hidden_act: gelu
+
+# Scheduler-specific configuration parameters
+scheduler_kwargs:
+  # Any specific scheduler parameters would go here
+  # For example:
+  # min_lr: 1e-7 # For cosine_with_min_lr scheduler
+
diff --git a/src/lobster/hydra_config/train.yaml b/src/lobster/hydra_config/train.yaml
index 4ce6d17..4458724 100644
--- a/src/lobster/hydra_config/train.yaml
+++ b/src/lobster/hydra_config/train.yaml
@@ -7,6 +7,7 @@ defaults:
   - trainer: default.yaml
   - setup: default.yaml
   - paths: default.yaml
+  - lr_scheduler: default.yaml
   - plugins: null
   - experiment: null

diff --git a/src/lobster/model/_cbmlm.py b/src/lobster/model/_cbmlm.py
index 919a59b..b392577 100644
--- a/src/lobster/model/_cbmlm.py
+++ b/src/lobster/model/_cbmlm.py
@@ -9,8 +9,8 @@
 import torch
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
+from hydra.utils import instantiate
 from transformers.configuration_utils import PretrainedConfig
-from transformers.optimization import get_linear_schedule_with_warmup

 from lobster.tokenization import CUSTOM_TOKENIZER, PmlmConceptTokenizerTransform, PmlmTokenizer

@@ -291,11 +291,7 @@ def configure_optimizers(self):
         optimizer = torch.optim.AdamW(
             self.model.parameters(), lr=self._lr, betas=(self._beta1, self._beta2), eps=self._eps
         )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        scheduler = instantiate(self.scheduler_cfg, optimizer=optimizer)
         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

         return {"optimizer": optimizer, "lr_scheduler": scheduler}
diff --git a/src/lobster/model/_clm.py b/src/lobster/model/_clm.py
index 165c908..5f67a50 100644
--- a/src/lobster/model/_clm.py
+++ b/src/lobster/model/_clm.py
@@ -1,11 +1,10 @@
 import importlib.resources
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple, Union

 import lightning.pytorch as pl
 import torch
 from torch.nn import CrossEntropyLoss
-from transformers import LlamaConfig, LlamaForCausalLM, pipeline
-from transformers.optimization import get_linear_schedule_with_warmup
+from transformers import LlamaConfig, LlamaForCausalLM, get_scheduler, pipeline

 from lobster.tokenization import PmlmTokenizer, PmlmTokenizerTransform
 from lobster.transforms import Transform
@@ -29,6 +28,20 @@ def __init__(
         max_length: int = 512,
         num_key_value_heads: int = None,
         attention_bias: bool = False,
+        scheduler: Literal[
+            "linear",
+            "cosine",
+            "cosine_with_restarts",
+            "polynomial",
+            "constant",
+            "constant_with_warmup",
+            "inverse_sqrt",
+            "reduce_lr_on_plateau",
+            "cosine_with_min_lr",
+            "warmup_stable_decay",
+        ] = "constant_with_warmup",
+        model_kwargs: dict = None,
+        scheduler_kwargs: dict = None,
     ):
         """
         Prescient Protein Causal Language Model.
@@ -40,6 +53,12 @@ def __init__(
             Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head
             Attention (MHA); if `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise
             GQA is used.
+        scheduler: str, optional
+            The type of learning rate scheduler.
+        model_kwargs: dict, optional
+            Additional keyword arguments to pass to the model.
+        scheduler_kwargs: dict, optional
+            Additional keyword arguments to pass to the scheduler.
         """
         super().__init__()

@@ -55,6 +74,9 @@ def __init__(
         self._tokenizer_dir = tokenizer_dir
         self._max_length = max_length
         self._attention_bias = attention_bias
+        self.scheduler = scheduler
+        self.scheduler_kwargs = scheduler_kwargs or {}
+        model_kwargs = model_kwargs or {}

         if self._tokenizer_dir is not None:
             path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
@@ -73,6 +95,7 @@ def __init__(
         self._num_key_value_heads = num_key_value_heads

         config = LlamaConfig(
+            **config_args,
             mask_token_id=self.tokenizer.mask_token_id,
             pad_token_id=self.tokenizer.pad_token_id,
             cls_token_id=self.tokenizer.cls_token_id,
@@ -81,7 +104,7 @@ def __init__(
             max_position_embeddings=self._max_length,
             num_key_value_heads=self._num_key_value_heads,
             attention_bias=self._attention_bias,
-            **config_args,
+            **model_kwargs,
         )
         self.model = LlamaForCausalLM(config)
         self.config = self.model.config
@@ -95,8 +118,6 @@ def training_step(self, batch, batch_idx):
         self.log("train_loss", loss, sync_dist=True)
         self.log("train_perplexity", ppl, sync_dist=True)

-        # self.log("loss", loss, batch_size=len(batch["input_ids"]), sync_dist=True)
-
         return {"loss": loss}

     def validation_step(self, batch, batch_idx):
@@ -115,11 +136,16 @@ def configure_optimizers(self):
             eps=self._eps,
         )

-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        # Create base kwargs for the scheduler
+        scheduler_params = {
+            "num_warmup_steps": self._num_warmup_steps,
+            "num_training_steps": self._num_training_steps,
+        }
+
+        # Add any additional scheduler kwargs from initialization
+        scheduler_params.update(self.scheduler_kwargs)
+
+        scheduler = get_scheduler(self.scheduler, optimizer, **scheduler_params)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/_conditioanalclassifiermlm.py b/src/lobster/model/_conditioanalclassifiermlm.py
index c918f15..5b520bb 100644
--- a/src/lobster/model/_conditioanalclassifiermlm.py
+++ b/src/lobster/model/_conditioanalclassifiermlm.py
@@ -5,8 +5,8 @@
 import pandas as pd
 import torch
 import torch.nn.functional as F
+from hydra.utils import instantiate
 from transformers.configuration_utils import PretrainedConfig
-from transformers.optimization import get_linear_schedule_with_warmup

 from lobster.tokenization import CUSTOM_TOKENIZER, PmlmConceptTokenizerTransform, PmlmTokenizer
 from lobster.transforms import Transform
@@ -218,11 +218,7 @@ def configure_optimizers(self):
         optimizer = torch.optim.AdamW(
             self.model.parameters(), lr=self._lr, betas=(self._beta1, self._beta2), eps=self._eps
         )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        scheduler = instantiate(self.scheduler_cfg, optimizer=optimizer)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/_conditioanalmlm.py b/src/lobster/model/_conditioanalmlm.py
index 70746e8..ee9d40b 100644
--- a/src/lobster/model/_conditioanalmlm.py
+++ b/src/lobster/model/_conditioanalmlm.py
@@ -4,8 +4,8 @@
 import lightning.pytorch as pl
 import pandas as pd
 import torch
+from hydra.utils import instantiate
 from transformers.configuration_utils import PretrainedConfig
-from transformers.optimization import get_linear_schedule_with_warmup

 from lobster.tokenization import CUSTOM_TOKENIZER, PmlmConceptTokenizerTransform, PmlmTokenizer
 from lobster.transforms import Transform
@@ -180,11 +180,7 @@ def configure_optimizers(self):
         optimizer = torch.optim.AdamW(
             self.model.parameters(), lr=self._lr, betas=(self._beta1, self._beta2), eps=self._eps
         )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        scheduler = instantiate(self.scheduler_cfg, optimizer=optimizer)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/_dyab.py b/src/lobster/model/_dyab.py
index 93ebb29..46e6fd3 100644
--- a/src/lobster/model/_dyab.py
+++ b/src/lobster/model/_dyab.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 import torch.optim as optim
 import torchvision.models as models
+from omegaconf import DictConfig
 from torchmetrics import (
     MeanAbsoluteError,
     R2Score,
@@ -34,6 +35,7 @@ def __init__(
         diff_channel_0: Literal["diff", "add", "mul", "div"] = "diff",
         diff_channel_1: Optional[Literal["sub", "add", "mul", "div"]] = None,
         diff_channel_2: Optional[Literal["diff", "add", "mul", "div"]] = None,
+        scheduler_cfg: DictConfig = None,
     ):
         """
         DyAb head.
@@ -78,6 +80,7 @@ def __init__(
         self._diff_channel_0 = diff_channel_0
         self._diff_channel_1 = diff_channel_1
         self._diff_channel_2 = diff_channel_2
+        self.scheduler_cfg = scheduler_cfg

         if model_name is None and checkpoint is None:
             model_name = "esm2_t6_8M_UR50D"
diff --git a/src/lobster/model/_lobster_fold.py b/src/lobster/model/_lobster_fold.py
index e2785cf..5fcf07e 100644
--- a/src/lobster/model/_lobster_fold.py
+++ b/src/lobster/model/_lobster_fold.py
@@ -4,9 +4,10 @@
 import lightning.pytorch as pl
 import torch
+from hydra.utils import instantiate
+from omegaconf import DictConfig
 from transformers import AutoTokenizer, EsmForProteinFolding
 from transformers.configuration_utils import PretrainedConfig
-from transformers.optimization import get_linear_schedule_with_warmup

 from lobster.extern.openfold_utils import atom14_to_atom37, backbone_loss
 from lobster.transforms import AutoTokenizerTransform, Transform
@@ -33,6 +34,7 @@ def __init__(
         tokenizer_dir: Optional[str] = "pmlm_tokenizer",
         max_length: int = 512,
         cache_dir: str = None,
+        scheduler_cfg: DictConfig = None,
     ):
         """
         Prescient Protein Language Model for Folding.
@@ -62,6 +64,7 @@ def __init__(
         self._num_warmup_steps = num_warmup_steps
         self._tokenizer_dir = tokenizer_dir
         self._max_length = max_length
+        self.scheduler_cfg = scheduler_cfg

         cache_dir = cache_dir or "~/.cache/huggingface/datasets"
         self._cache_dir = cache_dir
@@ -176,11 +179,7 @@ def configure_optimizers(self):
             betas=(self._beta1, self._beta2),
             eps=self._eps,
         )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        scheduler = instantiate(self.scheduler_cfg, optimizer=optimizer)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/_mgm.py b/src/lobster/model/_mgm.py
index f155c6d..c8972e5 100644
--- a/src/lobster/model/_mgm.py
+++ b/src/lobster/model/_mgm.py
@@ -4,9 +4,10 @@
 import lightning.pytorch as pl
 import pandas as pd
 import torch
+from hydra.utils import instantiate
+from omegaconf import DictConfig
 from transformers import AutoTokenizer, EsmForMaskedLM
 from transformers.configuration_utils import PretrainedConfig
-from transformers.optimization import get_linear_schedule_with_warmup

 from lobster.tokenization import MgmTokenizer, MgmTokenizerTransform, PmlmTokenizer, PmlmTokenizerTransform
 from lobster.transforms import AutoTokenizerTransform, Transform
@@ -34,6 +35,7 @@ def __init__(
         tokenizer_dir: Optional[str] = "mgm_tokenizer",
         max_length: int = 512,
         position_embedding_type: Literal["rotary", "absolute"] = "rotary",
+        scheduler_cfg: DictConfig = None,
     ):
         """
         Multi-granularity model (MGM).
@@ -68,6 +70,7 @@ def __init__(
         self._tokenizer_dir = tokenizer_dir
         self._max_length = max_length
         self._position_embedding_type = position_embedding_type
+        self.scheduler_cfg = scheduler_cfg

         if model_name and "esm2" in model_name:
             self.tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}", do_lower_case=False)
@@ -239,11 +242,7 @@ def configure_optimizers(self):
             betas=(self._beta1, self._beta2),
             eps=self._eps,
         )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        scheduler = instantiate(self.scheduler_cfg, optimizer=optimizer)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/_mlm.py b/src/lobster/model/_mlm.py
index fdc6d50..46e54c6 100644
--- a/src/lobster/model/_mlm.py
+++ b/src/lobster/model/_mlm.py
@@ -6,9 +6,8 @@
 import lightning.pytorch as pl
 import pandas as pd
 import torch
-from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmForMaskedLM
+from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmForMaskedLM, get_scheduler
 from transformers.configuration_utils import PretrainedConfig
-from transformers.optimization import get_linear_schedule_with_warmup

 from lobster.tokenization import PmlmTokenizer, PmlmTokenizerTransform
 from lobster.transforms import AutoTokenizerTransform, Transform
@@ -37,6 +36,20 @@ def __init__(
         max_length: int = 512,
         position_embedding_type: Literal["rotary", "absolute"] = "rotary",
         use_bfloat16: bool = False,
+        scheduler: Literal[
+            "linear",
+            "cosine",
+            "cosine_with_restarts",
+            "polynomial",
+            "constant",
+            "constant_with_warmup",
+            "inverse_sqrt",
+            "reduce_lr_on_plateau",
+            "cosine_with_min_lr",
+            "warmup_stable_decay",
+        ] = "constant_with_warmup",
+        model_kwargs: dict = None,
+        scheduler_kwargs: dict = None,
     ):
         """
         Prescient Protein Masked Language Model.
@@ -54,6 +67,12 @@ def __init__(
             tokenizer_dir: a tokenizer saved to src/lobster/assets
             max_length: max sequence length the model will see
             use_bfloat16: use bfloat16 instead of float32 for ESM-C model weights
+        scheduler: str, optional
+            The type of learning rate scheduler.
+        model_kwargs: dict, optional
+            Additional keyword arguments to pass to the model.
+        scheduler_kwargs: dict, optional
+            Additional keyword arguments to pass to the scheduler.
         """
         super().__init__()

@@ -72,6 +91,9 @@ def __init__(
         self._max_length = max_length
         self._position_embedding_type = position_embedding_type
         self._use_esmc = False
+        self.scheduler = scheduler
+        self.scheduler_kwargs = scheduler_kwargs or {}
+        model_kwargs = model_kwargs or {}

         load_pretrained = config is None and model_name not in PMLM_CONFIG_ARGS

@@ -133,13 +155,14 @@ def __init__(
             assert model_name in PMLM_CONFIG_ARGS
             config_args = PMLM_CONFIG_ARGS[model_name]
             config = PMLMConfig(
+                **config_args,
                 attention_probs_dropout_prob=0.0,
                 mask_token_id=self.tokenizer.mask_token_id,
                 pad_token_id=self.tokenizer.pad_token_id,
                 position_embedding_type=self._position_embedding_type,
                 vocab_size=len(self.tokenizer.get_vocab()),
                 max_position_embeddings=self._max_length,
-                **config_args,
+                **model_kwargs,
             )

             self.model = LMBaseForMaskedLM(config)
@@ -238,11 +261,17 @@ def configure_optimizers(self):
             betas=(self._beta1, self._beta2),
             eps=self._eps,
         )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+
+        # Create base kwargs for the scheduler
+        scheduler_params = {
+            "num_warmup_steps": self._num_warmup_steps,
+            "num_training_steps": self._num_training_steps,
+        }
+
+        # Add any additional scheduler kwargs from initialization
+        scheduler_params.update(self.scheduler_kwargs)
+
+        scheduler = get_scheduler(self.scheduler, optimizer, **scheduler_params)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/_seq2seq.py b/src/lobster/model/_seq2seq.py
index e24c89b..b084f3f 100644
--- a/src/lobster/model/_seq2seq.py
+++ b/src/lobster/model/_seq2seq.py
@@ -3,6 +3,8 @@

 import lightning.pytorch as pl
 import torch
+from hydra.utils import instantiate
+from omegaconf import DictConfig
 from transformers import (
     AutoModelForSeq2SeqLM,
     T5Config,
@@ -11,7 +13,6 @@
     T5Tokenizer,
 )
 from transformers.configuration_utils import PretrainedConfig
-from transformers.optimization import get_linear_schedule_with_warmup

 from lobster.transforms import Transform

@@ -33,6 +34,7 @@ def __init__(
         transform_fn: Union[Callable, Transform, None] = None,
         config: Union[PretrainedConfig, T5Config, None] = None,
         ckpt_path: str = None,
+        scheduler_cfg: DictConfig = None,
     ):
         """
         Prescient Protein T5 Model.
@@ -63,6 +65,7 @@ def __init__(
         self._num_training_steps = num_training_steps
         self._num_warmup_steps = num_warmup_steps
         self.tokenizer = T5Tokenizer.from_pretrained(f"Rostlab/{model_name}", do_lower_case=False)
+        self.scheduler_cfg = scheduler_cfg

         if is_training:
             self.model = T5ForConditionalGeneration.from_pretrained(f"Rostlab/{model_name}")
@@ -106,11 +109,7 @@ def configure_optimizers(self):
         optimizer = torch.optim.AdamW(
             self.model.parameters(), lr=self._lr, betas=(self._beta1, self._beta2), eps=self._eps
         )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        scheduler = instantiate(self.scheduler_cfg, optimizer=optimizer)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/hyena/_hyena.py b/src/lobster/model/hyena/_hyena.py
index 170d053..ea6fdd2 100644
--- a/src/lobster/model/hyena/_hyena.py
+++ b/src/lobster/model/hyena/_hyena.py
@@ -3,10 +3,10 @@

 import lightning.pytorch as pl
 import torch
+from hydra.utils import instantiate
+from omegaconf import DictConfig

 # from transformers import LlamaConfig, LlamaForCausalLM, pipeline
-from transformers.optimization import get_linear_schedule_with_warmup
-
 # from lobster.tokenization import PmlmTokenizer, PmlmTokenizerTransform
 from lobster.tokenization import HyenaTokenizer, HyenaTokenizerTransform
 from lobster.transforms import Transform
@@ -29,6 +29,7 @@ def __init__(
         tokenizer_dir: Optional[str] = "hyena_tokenizer",
         ckpt_path: str = None,
         max_length: int = 1024,
+        scheduler_cfg: DictConfig = None,
     ):
         """
         Prescient HyenaDNA Causal Language Model.
@@ -51,6 +52,7 @@ def __init__(
         self._ckpt_path = ckpt_path
         self._tokenizer_dir = tokenizer_dir
         self._max_length = max_length
+        self.scheduler_cfg = scheduler_cfg

         if self._tokenizer_dir is not None:
             path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
@@ -101,11 +103,7 @@ def configure_optimizers(self):
             eps=self._eps,
         )

-        scheduler = get_linear_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
-        )
+        scheduler = instantiate(self.scheduler_cfg, optimizer=optimizer)

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

diff --git a/src/lobster/model/modern_bert/_modern_bert.py b/src/lobster/model/modern_bert/_modern_bert.py
index 3109198..578a9a3 100644
--- a/src/lobster/model/modern_bert/_modern_bert.py
+++ b/src/lobster/model/modern_bert/_modern_bert.py
@@ -4,9 +4,10 @@
 import lightning.pytorch as pl
 import torch
 from torch import nn
-from transformers.optimization import get_linear_schedule_with_warmup
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-import lightning.fabric.utilities.throughput
+from omegaconf import DictConfig, OmegaConf
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, get_scheduler
+import lightning.fabric.utilities.throughput
+from omegaconf import DictConfig, OmegaConf

 from ._config import FlexBertConfig
 from ._model import FlexBertModel, FlexBertPredictionHead
@@ -41,8 +42,58 @@ def __init__(
         num_warmup_steps: int = 1_000,
         mask_percentage: float = 0.25,
         max_length: int = 512,
-        **model_kwargs,
+        scheduler: Literal["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup",
+                           "inverse_sqrt", "reduce_lr_on_plateau", "cosine_with_min_lr", "warmup_stable_decay",
+        ] = "constant_with_warmup",
+        model_kwargs: dict = None,
+        scheduler_kwargs: dict = None,
+        **kwargs,
     ):
+        """FlexBERT model for unsupervised pretraining.
+
+        Parameters
+        ----------
+        model_name: str
+            One of the keys in `FLEXBERT_CONFIG_ARGS`.
+        vocab_size: int, optional
+            The size of the vocabulary. Required if `tokenizer` is not provided.
+        pad_token_id: int, optional
+            The ID of the padding token. Required if `tokenizer` is not provided.
+        mask_token_id: int, optional
+            The ID of the mask token. Required if `tokenizer` is not provided.
+        cls_token_id: int, optional
+            The ID of the classification token. Required if `tokenizer` is not provided.
+        eos_token_id: int, optional
+            The ID of the end-of-sequence token. Required if `tokenizer` is not provided.
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, optional
+            A pretrained tokenizer. Required if `vocab_size`, `pad_token_id`, `mask_token_id`, `cls_token_id`, and
+            `eos_token_id` are not provided.
+        lr: float, optional
+            The learning rate.
+        beta1: float, optional
+            The beta1 parameter for the Adam optimizer.
+        beta2: float, optional
+            The beta2 parameter for the Adam optimizer.
+        eps: float, optional
+            The epsilon parameter for the Adam optimizer.
+        num_training_steps: int, optional
+            The total number of training steps.
+        num_warmup_steps: int, optional
+            The number of warmup steps.
+        mask_percentage: float, optional
+            The percentage of tokens to mask.
+        max_length: int, optional
+            The maximum sequence length.
+        scheduler: str, optional
+            The type of learning rate scheduler.
+        model_kwargs: dict, optional
+            Additional keyword arguments to pass to the model.
+        scheduler_kwargs: dict, optional
+            Additional keyword arguments to pass to the scheduler.
+        kwargs
+            Additional keyword arguments.
+        """
+
         super().__init__()
         self._model_name = model_name
         self._lr = lr
@@ -53,8 +104,11 @@ def __init__(
         self._num_warmup_steps = num_warmup_steps
         self._mask_percentage = mask_percentage
         self.max_length = max_length
+        self.scheduler = scheduler
+        self.scheduler_kwargs = scheduler_kwargs or {}

         config_args = FLEXBERT_CONFIG_ARGS[model_name]
+        model_kwargs = model_kwargs or {}

         self.config = FlexBertConfig(
             **config_args,
@@ -114,11 +168,19 @@ def configure_optimizers(self):
             eps=self._eps,
         )

-        # TODO: Make this configurable
-        scheduler = get_linear_schedule_with_warmup(
+        # Create base kwargs for the scheduler
+        scheduler_params = {
+            "num_warmup_steps": self._num_warmup_steps,
+            "num_training_steps": self._num_training_steps,
+        }
+
+        # Add any additional scheduler kwargs from initialization
+        scheduler_params.update(self.scheduler_kwargs)
+
+        scheduler = get_scheduler(
+            self.scheduler,
             optimizer,
-            num_warmup_steps=self._num_warmup_steps,
-            num_training_steps=self._num_training_steps,
+            **scheduler_params
         )

         scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
diff --git a/tests/lobster/model/modern_bert/test__modern_bert.py b/tests/lobster/model/modern_bert/test__modern_bert.py
index 9365627..3f8ec7f 100644
--- a/tests/lobster/model/modern_bert/test__modern_bert.py
+++ b/tests/lobster/model/modern_bert/test__modern_bert.py
@@ -1,7 +1,8 @@
 from importlib.util import find_spec
+from pathlib import Path

-import torch
-from torch import Size, Tensor
+from hydra.utils import instantiate
+from omegaconf import OmegaConf

 _FLASH_ATTN_AVAILABLE = False

@@ -12,14 +13,61 @@


 class TestFlexBERT:
-    def test_sequences_to_latents(self):
-        if _FLASH_ATTN_AVAILABLE and torch.cuda.is_available():
-            model = FlexBERT(model_name="UME_mini").cuda()
+    # def test_sequences_to_latents(self):
+    #     if _FLASH_ATTN_AVAILABLE and torch.cuda.is_available():
+    #         model = FlexBERT(model_name="UME_mini").cuda()

-            inputs = ["ACDAC", "ACDAC"]
-            outputs = model.sequences_to_latents(inputs)
+    #         inputs = ["ACDAC", "ACDAC"]
+    #         outputs = model.sequences_to_latents(inputs)

-            assert len(outputs) == 2
-            assert isinstance(outputs[0], Tensor)
-            assert outputs[-1].shape == Size([512, 252])  # L, d_model
-            assert outputs[0].device == model.device
+    #         assert len(outputs) == 2
+    #         assert isinstance(outputs[0], Tensor)
+    #         assert outputs[-1].shape == Size([512, 252])  # L, d_model
+    #         assert outputs[0].device == model.device
+
+    def test_hydra_instantiate(self):
+        if not _FLASH_ATTN_AVAILABLE:
+            import pytest
+
+            pytest.skip("flash_attn not available")
+
+        # Define path to the config file
+        config_path = (
+            Path(__file__).parents[4] / "src" / "lobster" / "hydra_config" / "model" / "modern_bert.yaml"
+        )
+
+        # Load the config directly from YAML
+        config = OmegaConf.load(config_path)
+
+        # Need to resolve the trainer.max_steps variable for testing
+        config.num_training_steps = 10_000  # Set to a fixed value for testing
+
+        # Add missing required parameters for proper instantiation
+        config.vocab_size = 30522  # Standard BERT vocab size
+        config.pad_token_id = 0
+        config.mask_token_id = 103
+        config.cls_token_id = 101
+        config.eos_token_id = 102
+
+        # Instantiate the model using the loaded config
+        model = instantiate(config)
+
+        # Test basic model properties
+        assert isinstance(model, FlexBERT)
+        assert model._model_name == "UME_mini"
+        assert model._lr == 1e-3
+        assert model._beta1 == 0.9  # Default value
+        assert model._beta2 == 0.98  # Default value
+        assert model._eps == 1e-12  # Default value
+        assert model._num_training_steps == 10_000
+        assert model._num_warmup_steps == 10_000  # From the YAML
+        assert model._mask_percentage == 0.25
+        assert model.max_length == 512  # Note: Updated attribute name
+        assert model.scheduler == "constant_with_warmup"  # Note: Updated attribute name
+
+        # Test that model_kwargs were correctly passed to the config
+        assert model.config.embedding_layer == "linear_pos"
+        assert model.config.hidden_act == "gelu"
+
+        # Test scheduler_kwargs is initialized correctly
+        assert hasattr(model, "scheduler_kwargs")
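
For anyone exercising the new scheduler plumbing outside of Hydra and Lightning, the sketch below mirrors the pattern the updated configure_optimizers methods follow: build the base warmup/step kwargs, merge in any scheduler_kwargs supplied by the model config, and let transformers.get_scheduler resolve the scheduler by name. It is a minimal sketch, not part of the patch: the toy nn.Linear module, optimizer hyperparameters, and step counts are illustrative stand-ins; only get_scheduler and the "constant_with_warmup" default come from the changes above.

import torch
from transformers import get_scheduler

# Stand-in module and optimizer; the hyperparameters here are illustrative only.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-12)

# Base kwargs first, then merge any extras from the config's scheduler_kwargs,
# exactly as the updated configure_optimizers methods do.
scheduler_params = {"num_warmup_steps": 1_000, "num_training_steps": 10_000}
scheduler_kwargs = {}  # populated from the hydra model config in the real modules
scheduler_params.update(scheduler_kwargs)

# "constant_with_warmup" is the default scheduler name set by this patch.
scheduler = get_scheduler("constant_with_warmup", optimizer, **scheduler_params)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())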