
sdpa for bert causes nan when using bfloat16 with padding #31035

@Leoyzen

Description


System Info

  • transformers: transformers==4.41.1
  • pytorch: 2.3.0+cu121
  • python: 3.10

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

In [1]: import torch.nn.functional as F

In [2]: import torch

In [3]: data = torch.load("reproduce_data.pt", map_location='cuda')

In [4]: with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1,2)))
   ...:
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],               # <----- without padding
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],               # <----- with padding
        [ True,  True,  True,  ...,  True,  True,  True],               # <----- with padding
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

In [5]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1,2)))
   ...:
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

# it's okay with the official math kernel
In [6]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1,2)))
   ...:
/root/.local/share/conda/envs/bytednlp/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:342: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.
  warnings.warn(
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],               # <----- with padding, which is correct
        [False, False, False,  ..., False, False, False],               # <----- with padding, which is correct
        [False, False, False,  ..., False, False, False]], device='cuda:0')

In [7]: mask = data['attn_mask']

# we use min / 2 in place of float('-inf') as the mask fill value
In [10]: mask2 = mask.masked_fill(mask.bool(), torch.finfo(mask.dtype).min / 2)

In [11]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], mask2, data['dropout_p'], data['is_causal']).isnan().any((1,2)))
    ...:
/root/.local/share/conda/envs/bytednlp/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:342: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.
  warnings.warn(
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],               # <----- with padding, the result is correct
        [False, False, False,  ..., False, False, False],               # <----- with padding, the result is correct
        [False, False, False,  ..., False, False, False]], device='cuda:0')
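
As a side note, the deprecation warning above points to torch.nn.attention.sdpa_kernel() as the replacement context manager. Below is a minimal, self-contained sketch of the same check using that API (PyTorch >= 2.3). The synthetic tensors, shapes, and the check_nan helper are my own, not taken from the report above, and whether NaN actually appears will depend on the selected backend and hardware. The flash backend is left out because it does not accept a dense attn_mask.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Assumes a CUDA device, as in the report above.
torch.manual_seed(0)
batch, heads, seq, dim = 4, 12, 16, 64
q = torch.randn(batch, heads, seq, dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Key-padding mask: the last two sequences in the batch have their second half padded.
pad = torch.zeros(batch, 1, 1, seq, dtype=torch.bool, device="cuda")
pad[-2:, ..., seq // 2:] = True

def check_nan(fill_value, backend):
    # Additive mask, broadcast over heads and query positions like BERT's extended mask,
    # materialized to the full (batch, heads, seq, seq) shape.
    attn_mask = torch.zeros(batch, heads, seq, seq, device="cuda", dtype=torch.bfloat16)
    attn_mask = attn_mask.masked_fill(pad, fill_value)
    with sdpa_kernel([backend]):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask)
    return out.isnan().any().item()

for backend in (SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION):
    for fill in (torch.finfo(torch.bfloat16).min, torch.finfo(torch.bfloat16).min / 2):
        print(backend, fill, check_nan(fill, backend))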

Expected behavior

The output should not contain NaN when using bfloat16 with SDPA enabled.

I think it is safe to use torch.finfo(dtype).min / 2 instead of torch.finfo(dtype).min: presumably, in bfloat16 adding finfo.min to already-negative attention scores can overflow to -inf, after which the fused kernels produce NaN, while halving the fill value leaves enough headroom.
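
To illustrate the proposed change, here is a hypothetical helper (for illustration only, not the actual transformers implementation) that expands a 2D padding mask into the additive 4D mask handed to SDPA, filling masked positions with finfo.min / 2 instead of finfo.min:

import torch

def expand_padding_mask(padding_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # padding_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    # Returns an additive mask of shape (batch, 1, 1, seq_len) that broadcasts
    # over heads and query positions.
    inverted = 1.0 - padding_mask[:, None, None, :].to(dtype)
    # finfo.min / 2 leaves headroom so that score + mask cannot overflow to -inf
    # in bfloat16, which appears to be what trips the fused SDPA kernels.
    return inverted * (torch.finfo(dtype).min / 2)

# example: mask = expand_padding_mask(torch.tensor([[1, 1, 1, 0, 0]]), torch.bfloat16)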
