diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py
index 12a0a3d..c383153 100644
--- a/aitextgen/aitextgen.py
+++ b/aitextgen/aitextgen.py
@@ -709,10 +709,9 @@ def train(
             gpus=n_gpu,
             max_steps=num_steps,
             gradient_clip_val=max_grad_norm,
-            checkpoint_callback=False,
+            enable_checkpointing=False,  # checkpoint_callback removed in pytorch_lightning v1.7
             logger=loggers if loggers else False,
-            weights_summary=None,
-            progress_bar_refresh_rate=progress_bar_refresh_rate,  # ignored
+            enable_model_summary=False,  # weights_summary and progress_bar_refresh_rate are removed in pytorch_lightning v1.7
             callbacks=[
                 ATGProgressBar(
                     save_every,
diff --git a/aitextgen/train.py b/aitextgen/train.py
index 2e289cb..ce54a4a 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -141,6 +141,12 @@ def on_train_start(self, trainer, pl_module):
     def on_train_end(self, trainer, pl_module):
         self.main_progress_bar.close()
         self.unfreeze_layers(pl_module)
+
+    def get_metrics(self, trainer, pl_module):
+        # don't show the version number
+        items = super().get_metrics(trainer, pl_module)
+        items.pop("v_num", None)
+        return items
 
     def on_batch_end(self, trainer, pl_module):
         super().on_batch_end(trainer, pl_module)
@@ -150,7 +156,8 @@ def on_batch_end(self, trainer, pl_module):
         if self.steps == 0 and self.gpu:
             torch.cuda.empty_cache()
 
-        current_loss = float(trainer.progress_bar_dict["loss"])
+        metrics = self.get_metrics(trainer, pl_module)
+        current_loss = float(metrics["loss"])
         self.steps += 1
         avg_loss = 0
         if current_loss == current_loss:  # don't add if current_loss is NaN
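
For context, a minimal standalone sketch (not part of the patch) of the pytorch_lightning >= 1.7 API the changes above target: the renamed Trainer flags and the get_metrics override used to hide the "v_num" entry. The class name QuietProgressBar is hypothetical and only stands in for ATGProgressBar.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar


class QuietProgressBar(TQDMProgressBar):
    # Same idea as ATGProgressBar.get_metrics above: drop Lightning's
    # "v_num" entry so it is not rendered in the progress bar.
    def get_metrics(self, trainer, pl_module):
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items


trainer = pl.Trainer(
    max_steps=100,
    gradient_clip_val=1.0,
    enable_checkpointing=False,   # replaces checkpoint_callback=False
    enable_model_summary=False,   # replaces weights_summary=None
    logger=False,
    callbacks=[QuietProgressBar()],
)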