
Conversation

@Bissmella Bissmella commented Nov 21, 2025

What does this PR do?

This is a draft implementation of the unified SP (sequence-parallel) attention approach, combining Ulysses and ring attention.

  • Implements _all_to_all_dim_exchange with custom scatter and gather indices (a rough sketch of the exchange idea appears after this list)
  • Implements TemplatedUnifiedAttention

Core implementation complete, needs:

  • Testing
  • Validation
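As mentioned above, here is a rough sketch of what the dim-exchange can look like for the default scatter_idx=2, gather_idx=1 case (scatter heads, gather sequence). This is an illustration of the technique only, not the PR's exact code:

import torch
import torch.distributed as dist

def _all_to_all_dim_exchange_sketch(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    # Illustrative only: covers the Ulysses pre-attention exchange for a
    # (batch, seq_len_local, num_heads, head_dim) shard.
    world_size = dist.get_world_size(group)
    if scatter_idx == 2 and gather_idx == 1:
        batch_size, seq_len_local, num_heads, head_dim = x.shape
        num_heads_local = num_heads // world_size
        # Split the head dim into world_size groups and move the group axis to the
        # front so all_to_all_single exchanges one head group per rank.
        inp = (
            x.reshape(batch_size, seq_len_local, world_size, num_heads_local, head_dim)
            .permute(2, 0, 1, 3, 4)
            .contiguous()
        )
        out = torch.empty_like(inp)
        dist.all_to_all_single(out, inp, group=group)
        # The leading axis now indexes source ranks, i.e. sequence chunks in rank order.
        return out.permute(1, 0, 2, 3, 4).reshape(
            batch_size, world_size * seq_len_local, num_heads_local, head_dim
        )
    raise NotImplementedError("Sketch only covers scatter_idx=2, gather_idx=1.")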

@sayakpaul
Member

It would be nice to get a testing script so that we can quickly check things.

@KarthikSundar2002

I added a basic test script with a simple forward and backward op. Would it be better to have a test script that exercises flash_attention forward and backward?
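For reference, a rough sketch of such a forward/backward shape test, assuming a CPU/gloo setup; the import path for SeqAllToAllDim is a guess and should point at wherever the PR defines it:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers.models.attention_dispatch import SeqAllToAllDim  # hypothetical path, adjust to the PR


def _worker(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    group = dist.group.WORLD

    batch_size, seq_len_local, num_heads, head_dim = 2, 8, 4 * world_size, 16
    x = torch.randn(batch_size, seq_len_local, num_heads, head_dim, requires_grad=True)

    # Forward exchange: scatter heads (dim 2), gather sequence (dim 1).
    y = SeqAllToAllDim.apply(group, x, 2, 1)
    assert y.shape == (batch_size, seq_len_local * world_size, num_heads // world_size, head_dim)

    # Backward: the gradient should come back in the original local shard shape.
    y.sum().backward()
    assert x.grad is not None and x.grad.shape == x.shape

    if rank == 0:
        print("forward/backward shapes OK")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(4,), nprocs=4)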

@Bissmella Bissmella force-pushed the unified-SP-attention branch from a244006 to 9dee8f8 on November 24, 2025 at 10:54
@Bissmella Bissmella marked this pull request as ready for review on November 24, 2025 at 10:56
@Bissmella Bissmella force-pushed the unified-SP-attention branch from 9dee8f8 to 9ebcff5 on November 24, 2025 at 23:00
@sayakpaul
Member

Let us know if this is ready for a review!

@Bissmella
Author

Yep, ready for review! I tested it with a 4-process setup (2×2 mesh, on CPU) and everything checks out: shapes look good and gradients flow correctly. Looking forward to feedback and happy to address any issues.

@sayakpaul sayakpaul left a comment

Thanks for getting started on this!

Comment on lines +93 to +96
# if self.ring_degree > 1 and self.ulysses_degree > 1:
# raise ValueError(
# "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
# )
Member

Maybe this needs to be removed?

x = _wait_tensor(x)
return x

def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
Member

Could we also add some basic docstrings to this function? It would help readability, and some commentary on what the function is doing would also be helpful.
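A possible starting point for such a docstring, based on the shape semantics visible elsewhere in the diff (wording is only a suggestion):

def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """Redistribute `x` across the process group by scattering one dim and gathering another.

    For the default `scatter_idx=2, gather_idx=1` (the Ulysses pre-attention step), a local
    shard of shape `(batch, seq_len_local, num_heads, head_dim)` becomes
    `(batch, seq_len_local * world_size, num_heads // world_size, head_dim)`: each rank ends
    up with the full sequence for a subset of heads. Swapping the indices performs the
    inverse exchange after attention.

    Args:
        x: Local input shard.
        scatter_idx: Dimension that gets split across ranks.
        gather_idx: Dimension that gets concatenated across ranks.
        group: Process group to use; defaults to the global group.
    """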

Comment on lines +1032 to +1035
B, S_LOCAL, H, D = x.shape
S = S_LOCAL * group_world_size
H_LOCAL = H // group_world_size

Member

(nit): prefer using fully qualified variable names.
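For example (illustrative renaming only):

batch_size, seq_len_local, num_heads, head_dim = x.shape
seq_len = seq_len_local * group_world_size
num_heads_local = num_heads // group_world_size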

grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
Member

Why the change here?
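(For context, in case it explains the extra None: torch.autograd.Function.backward must return exactly one value per argument of forward, so adding an argument to forward requires a matching None, e.g.:)

import torch

class _Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, some_flag):
        # Two forward inputs: one tensor, one non-tensor flag.
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward input; non-tensor inputs get None.
        return grad_out * 2, None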

Comment on lines +1331 to +1333
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
Member

Do you think it's better to have SeqAllToAllDim accept the QKV tensors as inputs rather than having them like this?
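Something along these lines, for instance; this is a rough sketch that reuses the PR's _all_to_all_dim_exchange and assumes it also handles the swapped index pair needed in the backward pass:

import torch

class SeqAllToAllQKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, group, query, key, value, scatter_idx, gather_idx):
        ctx.group, ctx.scatter_idx, ctx.gather_idx = group, scatter_idx, gather_idx
        return tuple(
            _all_to_all_dim_exchange(t, scatter_idx, gather_idx, group)
            for t in (query, key, value)
        )

    @staticmethod
    def backward(ctx, grad_query, grad_key, grad_value):
        # The inverse exchange swaps the scatter and gather dimensions.
        grads = tuple(
            _all_to_all_dim_exchange(g, ctx.gather_idx, ctx.scatter_idx, ctx.group)
            for g in (grad_query, grad_key, grad_value)
        )
        return (None, *grads, None, None)

# Usage (names illustrative):
# query, key, value = SeqAllToAllQKV.apply(ulysses_group, query, key, value, scatter_idx, gather_idx)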

Comment on lines +1328 to +1329
scatter_idx = 2
gather_idx = 1
Member

Let's make this configurable.
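For example, the indices could be plumbed through as keyword arguments with the current values as defaults (the surrounding signature is illustrative):

def _templated_unified_attention(query, key, value, *, ulysses_group=None, scatter_idx: int = 2, gather_idx: int = 1):
    # Defaults preserve the current behaviour; callers can override per layout.
    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
    ...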

@sayakpaul
Member

I am trying with the following code:

import torch
from torch import distributed as dist
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device

device = setup_distributed()
    
# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")
if dist.is_initialized():
    dist.destroy_process_group()

Run the above with torchrun --nproc-per-node 4 check_unified_attention.py.

And it leads to:
https://pastebin.com/A7KkvXH2
