@@ -1725,7 +1725,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
1725 1725             torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
1726 1726
1727 1727         # Determine the new best metric / best model checkpoint
1728      -       if metrics is not None and self.args.metric_for_best_model is not None:
     1728 +       if (
     1729 +           metrics is not None
     1730 +           and self.args.metric_for_best_model is not None
     1731 +           and self.args.best_model_after_epoch is not None
     1732 +           and self.state.epoch > self.args.best_model_after_epoch
     1733 +       ):
1729 1734             metric_to_check = self.args.metric_for_best_model
1730 1735             if not metric_to_check.startswith("eval_"):
1731 1736                 metric_to_check = f"eval_{metric_to_check}"
@@ -2661,7 +2666,9 @@ def prediction_step(
2661 2666             logits = smp_nested_concat(logits_mb)
2662 2667         else:
2663 2668             if has_labels:
2664      -               with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
     2669 +               with self.autocast_smart_context_manager(
     2670 +                   enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
     2671 +               ):
2665 2672                     loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
2666 2673                     loss = loss.mean().detach()
@@ -2671,7 +2678,9 @@ def prediction_step(
2671 2678                 logits = outputs[1:]
2672 2679             else:
2673 2680                 loss = None
2674      -               with self.autocast_smart_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
     2681 +               with self.autocast_smart_context_manager(
     2682 +                   enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
     2683 +               ):
2675 2684                     outputs = model(**inputs)
2676 2685                 if isinstance(outputs, dict):
2677 2686                     logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
0 commit comments