Commit 7a0c966

Apply fsdp forward register when fsdp is enabled

Parent: 68c8800

File tree

1 file changed: +1 −3 lines

src/transformers/trainer.py

Lines changed: 1 addition & 3 deletions
@@ -2338,9 +2338,7 @@ def _inner_training_loop(

         if self.is_fsdp_enabled:
             self.model = self.model_wrapped = model
-
-            # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA
-            if is_fsdp2:
+            # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA
             dist.fsdp.register_fsdp_forward_method(self.model, "generate")

         # for the rest of this function `model` is the outside model, whether it was wrapped or not
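Some context on why this registration matters: FSDP2 (`fully_shard`) only installs its parameter all-gather hooks around a module's `forward`. A method like `generate()` called directly on the module bypasses those hooks, so sharded `DTensor` parameters meet regular tensors and `model.generate()` fails with the `got mixed torch.Tensor and DTensor` error; `torch.distributed.fsdp.register_fsdp_forward_method(module, "generate")` tells FSDP to treat `generate` like `forward`. The toy sketch below is a hypothetical analogy of that pattern (it is not PyTorch's implementation): a wrapper whose hooks run only for registered method names.

```python
# Hypothetical sketch of the "register an extra forward method" pattern.
# Not PyTorch code: Wrapper stands in for an FSDP-wrapped module, and the
# "pre->...->post" markers stand in for the parameter all-gather hooks.

class Model:
    def forward(self, x):
        return f"forward({x})"

    def generate(self, x):
        return f"generate({x})"


class Wrapper:
    def __init__(self, module):
        self._module = module
        # By default only forward() runs the pre/post hooks.
        self._hooked = {"forward"}

    def register_forward_method(self, name):
        # Analogue of register_fsdp_forward_method: run hooks for `name` too.
        self._hooked.add(name)

    def call(self, name, *args):
        method = getattr(self._module, name)
        if name in self._hooked:
            # Hooks ran: parameters would be gathered before the call.
            return f"pre->{method(*args)}->post"
        # Hooks skipped: the analogue of the mixed Tensor/DTensor failure.
        return method(*args)


w = Wrapper(Model())
print(w.call("generate", "x"))   # unhooked: hooks skipped
w.register_forward_method("generate")
print(w.call("generate", "x"))   # now wrapped like forward()
```

In the commit above, dropping the `if is_fsdp2:` guard means the registration runs whenever `self.is_fsdp_enabled` is true, matching the commit message.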

0 commit comments