This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 15ba9b7

Best model after epoch (#46)
1 parent 053646a commit 15ba9b7

File tree

2 files changed: +18 −9 lines

src/transformers/trainer.py

Lines changed: 12 additions & 3 deletions
@@ -1725,7 +1725,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
             torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
 
         # Determine the new best metric / best model checkpoint
-        if metrics is not None and self.args.metric_for_best_model is not None:
+        if (
+            metrics is not None
+            and self.args.metric_for_best_model is not None
+            and self.args.best_model_after_epoch is not None
+            and self.state.epoch > self.args.best_model_after_epoch
+        ):
             metric_to_check = self.args.metric_for_best_model
             if not metric_to_check.startswith("eval_"):
                 metric_to_check = f"eval_{metric_to_check}"
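A note on behavior: because the new condition also requires best_model_after_epoch is not None, leaving the argument at its default of None disables best-metric bookkeeping entirely in this fork. A minimal, runnable sketch of the gate, using hypothetical Args/State stand-ins (only the condition itself comes from the diff above):

# Hypothetical stand-ins for self.args and self.state.
class Args:
    metric_for_best_model = "accuracy"
    best_model_after_epoch = 2  # skip best-model tracking through epoch 2

class State:
    epoch = 3.0  # the Trainer tracks the current epoch as a float

args, state = Args(), State()
metrics = {"eval_accuracy": 0.91}

if (
    metrics is not None
    and args.metric_for_best_model is not None
    and args.best_model_after_epoch is not None
    and state.epoch > args.best_model_after_epoch
):
    print("past best_model_after_epoch: best-metric bookkeeping runs")
else:
    print("too early (or best_model_after_epoch is None): checkpoint saved, best model untouched")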
@@ -2661,7 +2666,9 @@ def prediction_step(
                 logits = smp_nested_concat(logits_mb)
         else:
             if has_labels:
-                with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
+                with self.autocast_smart_context_manager(
+                    enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
+                ):
                     loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                 loss = loss.mean().detach()

@@ -2671,7 +2678,9 @@ def prediction_step(
                     logits = outputs[1:]
             else:
                 loss = None
-                with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
+                with self.autocast_smart_context_manager(
+                    enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
+                ):
                     outputs = model(**inputs)
                 if isinstance(outputs, dict):
                     logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
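The two prediction_step hunks above are pure line-wrapping changes to long with statements; behavior is unchanged. For readers unfamiliar with the helper, a rough sketch of what an autocast "smart" context manager of this shape might reduce to — an illustrative approximation under the assumption of CUDA mixed precision, not the Trainer's actual implementation:

import contextlib

import torch

def autocast_smart_context_manager(enabled: bool = True):
    # Illustrative approximation: yield a CUDA autocast context when mixed
    # precision is usable, otherwise a no-op context, so call sites can
    # always use a single `with` statement.
    if enabled and torch.cuda.is_available():
        return torch.cuda.amp.autocast(enabled=enabled)
    return contextlib.nullcontext()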

src/transformers/training_args.py

Lines changed: 6 additions & 6 deletions
@@ -640,6 +640,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether or not to load the best model found during training at the end of training."},
     )
+    best_model_after_epoch: int = field(
+        default=None,
+        metadata={"help": "Epoch after which best model will be saved."},
+    )
     metric_for_best_model: Optional[str] = field(
         default=None, metadata={"help": "The metric to use to compare two different models."}
     )
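A usage sketch for the new argument, assuming this fork's TrainingArguments (the values below are hypothetical; evaluation and saving must run each epoch so that _save_checkpoint receives metrics):

from transformers import TrainingArguments

# Hypothetical configuration: track eval F1 as the best-model metric, but
# only start recording "best" checkpoints once training has passed epoch 2.
args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    best_model_after_epoch=2,
)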
@@ -748,12 +752,8 @@ class TrainingArguments:
         metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
     )
     modifier_log_frequency: float = field(
-        default = 0.1,
-        metadata={
-            "help": (
-                "How often to log SparseML modifier data, in number of epochs or fraction of epochs"
-            )
-        }
+        default=0.1,
+        metadata={"help": ("How often to log SparseML modifier data, in number of epochs or fraction of epochs")},
     )
 
     def __post_init__(self):
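The field's metadata flows through to command-line parsing. A small sketch, assuming this fork's TrainingArguments (the help string above becomes the --modifier_log_frequency help text, and a fractional epoch parses because the field is a float):

from transformers import HfArgumentParser, TrainingArguments

# Parse CLI-style arguments into the dataclass; --output_dir is required.
parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--modifier_log_frequency", "0.25"]
)
print(args.modifier_log_frequency)  # 0.25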
