6 changes: 5 additions & 1 deletion conftest.py
@@ -138,10 +138,14 @@ def check_output(self, want, got, optionflags):

 if is_torch_available():
     import torch
+    from packaging import version

     # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
     # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615
-    torch.backends.cudnn.allow_tf32 = False
+    if version.parse(torch.__version__) >= version.parse("2.9.0"):
+        torch.backends.cudnn.conv.fp32_precision = "ieee"
+    else:
+        torch.backends.cudnn.allow_tf32 = False

     # patch `torch.compile`: if `TORCH_COMPILE_FORCE_FULLGRAPH=1` (or values considered as true, e.g. yes, y, etc.),
     # the patched version will always run with `fullgraph=True`.
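A quick way for reviewers to see which TF32 control surface a local torch build exposes — a hypothetical snippet mirroring the version gate above, not part of this diff:

```python
# Hypothetical check, not part of this PR: inspect the TF32 control API.
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("2.9.0"):
    # PyTorch >= 2.9 exposes per-backend fp32 precision controls.
    print("cudnn conv fp32_precision:", torch.backends.cudnn.conv.fp32_precision)
else:
    # Older releases only expose the boolean TF32 switch.
    print("cudnn allow_tf32:", torch.backends.cudnn.allow_tf32)
```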
10 changes: 8 additions & 2 deletions docs/source/en/perf_train_gpu_one.md
@@ -150,8 +150,14 @@ tf32 is enabled by default on NVIDIA Ampere GPUs, but you can also add the code

 ```py
 import torch
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
+from packaging import version
+
+if version.parse(torch.__version__) >= version.parse("2.9.0"):
+    torch.backends.cuda.matmul.fp32_precision = "tf32"
+    torch.backends.cudnn.conv.fp32_precision = "tf32"
+else:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 ```

 Configure [tf32()](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.tf32) in [`TrainingArguments`] to enable mixed precision training with tf32 mode.
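For the Trainer path mentioned in the context line above, a minimal sketch (assuming the standard `TrainingArguments` import; `output_dir` is a placeholder):

```py
# Minimal sketch: enable TF32 through TrainingArguments instead of
# touching torch.backends directly. output_dir is a placeholder value.
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="out", tf32=True)
```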
10 changes: 8 additions & 2 deletions docs/source/ja/perf_train_gpu_one.md
@@ -152,8 +152,14 @@ training_args = TrainingArguments(bf16=True, **default_args)

 ```python
 import torch
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
+from packaging import version
+
+if version.parse(torch.__version__) >= version.parse("2.9.0"):
+    torch.backends.cuda.matmul.fp32_precision = "tf32"
+    torch.backends.cudnn.conv.fp32_precision = "tf32"
+else:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 ```

10 changes: 8 additions & 2 deletions docs/source/ko/perf_train_gpu_one.md
@@ -149,8 +149,14 @@ tf32 is enabled by default on NVIDIA Ampere GPUs, but fp32

 ```py
 import torch
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
+from packaging import version
+
+if version.parse(torch.__version__) >= version.parse("2.9.0"):
+    torch.backends.cuda.matmul.fp32_precision = "tf32"
+    torch.backends.cudnn.conv.fp32_precision = "tf32"
+else:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 ```

 To enable mixed precision training in tf32 mode, set the [tf32()](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.tf32) option in [`TrainingArguments`].
26 changes: 20 additions & 6 deletions src/transformers/training_args.py
@@ -23,6 +23,8 @@
 from functools import cached_property
 from typing import Any

+from packaging import version
+
 from .debug_utils import DebugOption
 from .trainer_utils import (
     FSDPOption,
@@ -1604,8 +1606,12 @@ def __post_init__(self):
                 if is_torch_musa_available():
                     torch.backends.mudnn.allow_tf32 = True
                 else:
-                    torch.backends.cuda.matmul.allow_tf32 = True
-                    torch.backends.cudnn.allow_tf32 = True
+                    if version.parse(torch.__version__) >= version.parse("2.9.0"):
+                        torch.backends.cuda.matmul.fp32_precision = "tf32"
+                        torch.backends.cudnn.conv.fp32_precision = "tf32"
+                    else:
+                        torch.backends.cuda.matmul.allow_tf32 = True
+                        torch.backends.cudnn.allow_tf32 = True
             else:
                 logger.warning(
                     "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
@@ -1616,17 +1622,25 @@
                 if is_torch_musa_available():
                     torch.backends.mudnn.allow_tf32 = True
                 else:
-                    torch.backends.cuda.matmul.allow_tf32 = True
-                    torch.backends.cudnn.allow_tf32 = True
+                    if version.parse(torch.__version__) >= version.parse("2.9.0"):
+                        torch.backends.cuda.matmul.fp32_precision = "tf32"
+                        torch.backends.cudnn.conv.fp32_precision = "tf32"
+                    else:
+                        torch.backends.cuda.matmul.allow_tf32 = True
+                        torch.backends.cudnn.allow_tf32 = True
             else:
                 raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
         else:
             if is_torch_tf32_available():
                 if is_torch_musa_available():
                     torch.backends.mudnn.allow_tf32 = False
                 else:
-                    torch.backends.cuda.matmul.allow_tf32 = False
-                    torch.backends.cudnn.allow_tf32 = False
+                    if version.parse(torch.__version__) >= version.parse("2.9.0"):
+                        torch.backends.cuda.matmul.fp32_precision = "ieee"
+                        torch.backends.cudnn.conv.fp32_precision = "ieee"
+                    else:
+                        torch.backends.cuda.matmul.allow_tf32 = False
+                        torch.backends.cudnn.allow_tf32 = False
             # no need to assert on else

         if self.report_to == "all" or self.report_to == ["all"]:
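Note that these branches run inside `__post_init__`, so the chosen fp32 precision is applied as soon as the arguments object is constructed; an illustrative sketch (paths and values are placeholders):

```py
# Illustrative only: constructing TrainingArguments applies the TF32 setting
# immediately in __post_init__, so later torch code sees the chosen precision.
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", tf32=False)  # requests IEEE fp32 on Ampere+
```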
67 changes: 67 additions & 0 deletions src/transformers/utils/tf32_utils.py
@@ -0,0 +1,67 @@
+# Defensive TF32 toggling helper.
+# Usage: from transformers.utils.tf32_utils import set_tf32_mode; set_tf32_mode(True|False)
+
+
+def set_tf32_mode(enable: bool) -> None:
+    """
+    Safely toggle TF32 on the current environment.
+
+    - No-ops on CPU-only runners.
+    - Tries the new PyTorch >= 2.9 API (torch.backends.cuda.matmul.fp32_precision /
+      torch.backends.cudnn.conv.fp32_precision).
+    - Falls back to the old API (torch.backends.cuda.matmul.allow_tf32 /
+      torch.backends.cudnn.allow_tf32).
+    - Handles MUSA/mudnn allow_tf32 if present.
+    - Swallows exceptions to avoid failing tests on exotic environments.
+    """
+    try:
+        import torch
+
+        # If CUDA isn't available, bail out early. This avoids AttributeError on CPU CI.
+        if not getattr(torch.cuda, "is_available", lambda: False)():
+            return
+
+        # MUSA uses mudnn.allow_tf32 in some environments; try setting it if present.
+        try:
+            if hasattr(torch.backends, "mudnn"):
+                try:
+                    torch.backends.mudnn.allow_tf32 = bool(enable)
+                except Exception:
+                    # Some builds may not expose this; ignore failures.
+                    pass
+        except Exception:
+            # Defensive outer catch for unusual torch builds.
+            pass
+
+        # Safely access cuda.matmul and cudnn.conv where available.
+        cuda_backend = getattr(torch.backends, "cuda", None)
+        matmul = getattr(cuda_backend, "matmul", None)
+        cudnn_backend = getattr(torch.backends, "cudnn", None)
+        cudnn_conv = getattr(cudnn_backend, "conv", None)
+
+        # New API (PyTorch >= 2.9): fp32_precision = "tf32" / "ieee".
+        if matmul is not None and hasattr(matmul, "fp32_precision"):
+            try:
+                matmul.fp32_precision = "tf32" if enable else "ieee"
+            except Exception:
+                pass
+        elif matmul is not None and hasattr(matmul, "allow_tf32"):
+            try:
+                matmul.allow_tf32 = bool(enable)
+            except Exception:
+                pass
+
+        # cudnn.conv may have fp32_precision (new API) or allow_tf32 (old API).
+        if cudnn_conv is not None and hasattr(cudnn_conv, "fp32_precision"):
+            try:
+                cudnn_conv.fp32_precision = "tf32" if enable else "ieee"
+            except Exception:
+                pass
+        elif hasattr(torch.backends, "cudnn") and hasattr(torch.backends.cudnn, "allow_tf32"):
+            try:
+                torch.backends.cudnn.allow_tf32 = bool(enable)
+            except Exception:
+                pass
+
+    except Exception:
+        # Never raise here: toggling TF32 should never break tests or examples.
+        return
7 changes: 6 additions & 1 deletion utils/modular_model_detector.py
@@ -247,7 +247,12 @@ def __init__(self, hub_dataset: str):
             logging.getLogger(name).setLevel(logging.ERROR)
         huggingface_hub_logging.set_verbosity_error()
         transformers_logging.set_verbosity_error()
-        torch.backends.cuda.matmul.allow_tf32 = True
+        from packaging import version
+
+        if version.parse(torch.__version__) >= version.parse("2.9.0"):
+            torch.backends.cuda.matmul.fp32_precision = "tf32"
+        else:
+            torch.backends.cuda.matmul.allow_tf32 = True
         torch.set_grad_enabled(False)

         self.models_root = MODELS_ROOT
Expand Down