
Commit 7f5c209

logic to select tf32 API as per Pytorch version (#42428)
* logic to select tf32 API as per Pytorch version
* new method added into `__all__`
* make style and quality ran
* added global setting for tf32
* added support for MUSA as well
* make style and quality run
* cleared >= 2.9.0 torch version logic
1 parent 01823d7 commit 7f5c209

File tree

5 files changed: +33, −21 lines

conftest.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -31,6 +31,7 @@
     patch_testing_methods_to_collect_info,
     patch_torch_compile_force_graph,
 )
+from transformers.utils import enable_tf32
 
 
 NOT_DEVICE_TESTS = {
@@ -137,11 +138,9 @@ def check_output(self, want, got, optionflags):
 doctest.DocTestParser = HfDocTestParser
 
 if is_torch_available():
-    import torch
-
     # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
     # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615
-    torch.backends.cudnn.allow_tf32 = False
+    enable_tf32(False)
 
     # patch `torch.compile`: if `TORCH_COMPILE_FORCE_FULLGRAPH=1` (or values considered as true, e.g. yes, y, etc.),
    # the patched version will always run with `fullgraph=True`.
```
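
Previously the CI flipped `torch.backends.cudnn.allow_tf32` directly; the new helper picks whichever API the installed PyTorch supports. A minimal sketch of what the CI call now does (assuming a build of transformers that includes this commit):

```python
from transformers.utils import enable_tf32

# Disable TF32 for CI, regardless of PyTorch version: on torch >= 2.9 this
# sets torch.backends.fp32_precision = "ieee"; on older versions it clears
# the legacy allow_tf32 flags (mudnn on MUSA, cuda.matmul/cudnn elsewhere).
enable_tf32(False)
```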

src/transformers/training_args.py

Lines changed: 5 additions & 17 deletions

```diff
@@ -53,7 +53,7 @@
     requires_backends,
 )
 from .utils.generic import strtobool
-from .utils.import_utils import is_optimum_neuron_available
+from .utils.import_utils import enable_tf32, is_optimum_neuron_available
 
 
 logger = logging.get_logger(__name__)
@@ -379,7 +379,7 @@ class TrainingArguments:
             metric values.
         tf32 (`bool`, *optional*):
             Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends
-            on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to
+            on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32` (for PyTorch 2.9+, `torch.backends.cuda.matmul.fp32_precision`). For more details please refer to
             the [TF32](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) documentation. This is an
             experimental API and it may change.
         ddp_backend (`str`, *optional*):
@@ -1601,32 +1601,20 @@ def __post_init__(self):
                     f"Setting TF32 in {device_str} backends to speedup torch compile, you won't see any improvement"
                     " otherwise."
                 )
-                if is_torch_musa_available():
-                    torch.backends.mudnn.allow_tf32 = True
-                else:
-                    torch.backends.cuda.matmul.allow_tf32 = True
-                    torch.backends.cudnn.allow_tf32 = True
+                enable_tf32(True)
             else:
                 logger.warning(
                     "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
                 )
         if is_torch_available() and self.tf32 is not None:
             if self.tf32:
                 if is_torch_tf32_available():
-                    if is_torch_musa_available():
-                        torch.backends.mudnn.allow_tf32 = True
-                    else:
-                        torch.backends.cuda.matmul.allow_tf32 = True
-                        torch.backends.cudnn.allow_tf32 = True
+                    enable_tf32(True)
                 else:
                     raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
             else:
                 if is_torch_tf32_available():
-                    if is_torch_musa_available():
-                        torch.backends.mudnn.allow_tf32 = False
-                    else:
-                        torch.backends.cuda.matmul.allow_tf32 = False
-                        torch.backends.cudnn.allow_tf32 = False
+                    enable_tf32(False)
                 # no need to assert on else
 
         if self.report_to == "all" or self.report_to == ["all"]:
```
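
As a usage sketch (the output directory name here is hypothetical), requesting TF32 through `TrainingArguments` now funnels into the shared helper during `__post_init__`:

```python
from transformers import TrainingArguments

# tf32=True triggers enable_tf32(True) in __post_init__; on unsupported
# setups (pre-Ampere GPU, cuda < 11, or torch < 1.7) it raises ValueError
# instead of silently setting backend flags.
args = TrainingArguments(output_dir="out", tf32=True)
```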

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -108,6 +108,7 @@
     _LazyModule,
     check_torch_load_is_safe,
     direct_transformers_import,
+    enable_tf32,
     get_torch_version,
     is_accelerate_available,
     is_apex_available,
```

src/transformers/utils/import_utils.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -508,6 +508,29 @@ def is_torch_tf32_available() -> bool:
     return True
 
 
+@lru_cache
+def enable_tf32(enable: bool) -> None:
+    """
+    Set TF32 mode using the appropriate PyTorch API.
+    For PyTorch 2.9+, uses the new fp32_precision API.
+    For older versions, uses the legacy allow_tf32 flags.
+    Args:
+        enable: Whether to enable TF32 mode
+    """
+    import torch
+
+    pytorch_version = version.parse(get_torch_version())
+    if pytorch_version >= version.parse("2.9.0"):
+        precision_mode = "tf32" if enable else "ieee"
+        torch.backends.fp32_precision = precision_mode
+    else:
+        if is_torch_musa_available():
+            torch.backends.mudnn.allow_tf32 = enable
+        else:
+            torch.backends.cuda.matmul.allow_tf32 = enable
+            torch.backends.cudnn.allow_tf32 = enable
+
+
 @lru_cache
 def is_torch_flex_attn_available() -> bool:
     return is_torch_available() and version.parse(get_torch_version()) >= version.parse("2.5.0")
```
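
A quick way to check which API the helper used, mirroring the version branch above (a sketch that assumes a CUDA build of PyTorch):

```python
import torch
from packaging import version

from transformers.utils import enable_tf32

enable_tf32(True)

# Mirror the branch inside enable_tf32 to inspect the effective setting.
if version.parse(torch.__version__) >= version.parse("2.9.0"):
    print(torch.backends.fp32_precision)          # "tf32"
else:
    print(torch.backends.cuda.matmul.allow_tf32)  # True
    print(torch.backends.cudnn.allow_tf32)        # True
```

Note that `enable_tf32` is wrapped in `@lru_cache`, so a repeated call with the same argument is a cached no-op; if other code resets the backend flags in between, calling the helper again with the same value will not re-apply them.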

utils/modular_model_detector.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -117,6 +117,7 @@
 
 import transformers
 from transformers import AutoModel, AutoTokenizer
+from transformers.utils import enable_tf32
 from transformers.utils import logging as transformers_logging
 
 
@@ -247,7 +248,7 @@ def __init__(self, hub_dataset: str):
         logging.getLogger(name).setLevel(logging.ERROR)
         huggingface_hub_logging.set_verbosity_error()
         transformers_logging.set_verbosity_error()
-        torch.backends.cuda.matmul.allow_tf32 = True
+        enable_tf32(True)
         torch.set_grad_enabled(False)
 
         self.models_root = MODELS_ROOT
```
