System Info
- transformers: transformers==4.41.1
- pytorch: 2.3.0+cu121
- python: 3.10
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
In [1]: import torch.nn.functional as F
In [2]: import torch
In [3]: data = torch.load("reproduce_data.pt", map_location='cuda')
In [4]: with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1, 2)))
   ...:
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],  # <----- without padding
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],  # <----- with padding
        [ True,  True,  True,  ...,  True,  True,  True],  # <----- with padding
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')
In [5]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1, 2)))
   ...:
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')
# it's okay with the official math kernel
In [6]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1, 2)))
   ...:
/root/.local/share/conda/envs/bytednlp/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:342: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.
warnings.warn(
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],  # <----- with padding, which is correct
        [False, False, False,  ..., False, False, False],  # <----- with padding, which is correct
        [False, False, False,  ..., False, False, False]], device='cuda:0')
In [7]: mask = data['attn_mask']
# use torch.finfo(dtype).min / 2 in place of float('-inf') as the masking value
In [10]: mask2 = mask.masked_fill(mask.bool(), torch.finfo(mask.dtype).min / 2)
In [11]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
...: print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], mask2, data['dropout_p'], data['is_causal']).isnan().any((1,2)))
...:
/root/.local/share/conda/envs/bytednlp/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:342: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.
warnings.warn(
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],  # <----- with padding; the result is correct
        [False, False, False,  ..., False, False, False],  # <----- with padding; the result is correct
        [False, False, False,  ..., False, False, False]], device='cuda:0')
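As an aside, the FutureWarning above points to the newer torch.nn.attention.sdpa_kernel context manager. Below is a minimal sketch of the same backend selection with that API; the tensor shapes, dtype and device are placeholders and are not taken from reproduce_data.pt, and it assumes a PyTorch version (2.2+) where torch.nn.attention is available:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# placeholder inputs just to exercise the context manager
q = torch.randn(2, 8, 16, 64, device='cuda', dtype=torch.bfloat16)
k = torch.randn(2, 8, 16, 64, device='cuda', dtype=torch.bfloat16)
v = torch.randn(2, 8, 16, 64, device='cuda', dtype=torch.bfloat16)

# same selection as sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.isnan().any())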
Expected behavior
The output should not contain NaN when using bfloat16 with SDPA enabled.
I think it is safe to use torch.finfo(dtype).min / 2 instead of torch.finfo(dtype).min.
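For illustration, here is a minimal sketch of building an additive attention mask with that fill value (the padding mask below is made up for the example; this is not the actual transformers masking code). The presumed failure mode is that the fused kernels add torch.finfo(dtype).min to attention scores that are already negative, which can overflow to -inf in bfloat16; a row that is entirely padding then becomes a softmax over -inf only and comes back as NaN, whereas min / 2 leaves enough headroom for the sums to stay finite:

import torch

dtype = torch.bfloat16
# True marks positions to ignore; the second row is pure padding
padding_mask = torch.tensor([[False, False, True],
                             [True,  True,  True]])

# additive bias: 0 for visible positions, finfo(dtype).min / 2 for masked ones
attn_bias = torch.zeros(padding_mask.shape, dtype=dtype)
attn_bias = attn_bias.masked_fill(padding_mask, torch.finfo(dtype).min / 2)

# even for the fully masked row the biased scores stay finite
scores = torch.randn(padding_mask.shape).to(dtype)
print((scores + attn_bias).isfinite().all())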