Skip to content
This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 053646a

Browse files
authored
Fix DataParallel validation forward signatures (#47)
* Fix: DataParallel validation forward signatures * Update: generalize forward_fn selection * nit: space
1 parent 5afbd46 commit 053646a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/transformers/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2435,11 +2435,12 @@ def evaluation_loop(
24352435

24362436
observed_num_examples = 0
24372437
# Main evaluation loop
2438+
module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward
24382439
for step, inputs in enumerate(dataloader):
24392440
inputs = {
24402441
k: inputs[k]
24412442
for k in inputs
2442-
if k in list(inspect.signature(model.forward).parameters.keys())
2443+
if k in list(inspect.signature(module_forward_fn).parameters.keys())
24432444
}
24442445

24452446
# Update the observed num examples

0 commit comments

Comments
 (0)