|
| 1 | +import os |
| 2 | +import click |
| 3 | +from math import ceil |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.multiprocessing as mp |
| 7 | +import torch.distributed as dist |
| 8 | +from torch.amp import autocast |
| 9 | +from torch.nn.parallel import DistributedDataParallel as DDP |
| 10 | + |
| 11 | +from ring_attention_pytorch import RingAttention |
| 12 | +from ring_attention_pytorch.distributed import all_gather_variable_dim |
| 13 | + |
| 14 | +from einops import rearrange |
| 15 | + |
| 16 | +from ring_attention_pytorch.ring_attention import apply_rotary_pos_emb |
| 17 | + |
| 18 | +from ring_attention_pytorch.zig_zag_attention import ( |
| 19 | + zig_zag_pad_seq, |
| 20 | + zig_zag_attn, |
| 21 | + zig_zag_shard |
| 22 | +) |
| 23 | + |
| 24 | +def abs_diff(x, y): |
| 25 | + return (x - y).abs().amax() |
| 26 | + |
| 27 | +def setup( |
| 28 | + rank, |
| 29 | + world_size, |
| 30 | + use_cuda |
| 31 | +): |
| 32 | + os.environ['MASTER_ADDR'] = 'localhost' |
| 33 | + os.environ['MASTER_PORT'] = '12355' |
| 34 | + |
| 35 | + backend = "gloo" if not use_cuda else "nccl" |
| 36 | + dist.init_process_group(backend, rank = rank, world_size = world_size) |
| 37 | + |
| 38 | + if use_cuda: |
| 39 | + torch.cuda.set_device(rank) |
| 40 | + |
| 41 | +def cleanup(): |
| 42 | + dist.destroy_process_group() |
| 43 | + |
| 44 | +def start( |
| 45 | + rank, |
| 46 | + world_size, |
| 47 | + batch_size, |
| 48 | + batch_size_var_len, |
| 49 | + seq_len, |
| 50 | + num_sharded_batches, |
| 51 | + dim, |
| 52 | + heads, |
| 53 | + num_grouped_query_heads, |
| 54 | + dim_head, |
| 55 | + use_cuda, |
| 56 | + rotary |
| 57 | +): |
| 58 | + setup(rank, world_size, use_cuda) |
| 59 | + |
| 60 | + attention = RingAttention( |
| 61 | + dim = dim, |
| 62 | + dim_head = dim_head, |
| 63 | + heads = heads, |
| 64 | + num_grouped_query_heads = num_grouped_query_heads, |
| 65 | + causal = True, |
| 66 | + rotary_embed = rotary, |
| 67 | + ring_attn = False, |
| 68 | + use_cuda_kernel = use_cuda |
| 69 | + ) |
| 70 | + |
| 71 | + if batch_size_var_len: |
| 72 | + batch_size = batch_size + rank |
| 73 | + |
| 74 | + seq = torch.randn(batch_size, seq_len, dim) |
| 75 | + |
| 76 | + # move to cuda if needed |
| 77 | + |
| 78 | + if use_cuda: |
| 79 | + seq = seq.cuda(rank) |
| 80 | + attention.cuda(rank) |
| 81 | + |
| 82 | + # separate inputs for ring vs flash |
| 83 | + |
| 84 | + regular_input = seq.clone().requires_grad_() |
| 85 | + zig_zag_input = seq.clone().requires_grad_() |
| 86 | + |
| 87 | + # wrap |
| 88 | + |
| 89 | + ddp_attention = DDP(attention) |
| 90 | + |
| 91 | + # regular |
| 92 | + |
| 93 | + out = ddp_attention(regular_input) |
| 94 | + |
| 95 | + out.mean().backward() |
| 96 | + |
| 97 | + # zig zag |
| 98 | + |
| 99 | + padded_inp, remove_pad = zig_zag_pad_seq(zig_zag_input) |
| 100 | + (padded_inp, q_indices, kv_indices), gather_seq = zig_zag_shard(padded_inp, all_gather_batch = True) |
| 101 | + |
| 102 | + qkv = attention.to_qkv(padded_inp) |
| 103 | + |
| 104 | + q, k, v = rearrange(qkv, 'b n (h d) -> b h n d', d = dim_head).split(attention.qkv_head_breakdown, dim = -3) |
| 105 | + |
| 106 | + if rotary: |
| 107 | + pos_emb = attention.rotary_embed(q_indices) |
| 108 | + |
| 109 | + q = apply_rotary_pos_emb(pos_emb, q, head_dim_first = True) |
| 110 | + k = apply_rotary_pos_emb(pos_emb, k, head_dim_first = True) |
| 111 | + |
| 112 | + # causal mask |
| 113 | + |
| 114 | + causal_mask = q_indices[:, None] >= kv_indices[None, :] |
| 115 | + |
| 116 | + # attention |
| 117 | + |
| 118 | + o = zig_zag_attn( |
| 119 | + q, k, v, |
| 120 | + attn_mask = causal_mask |
| 121 | + ) |
| 122 | + |
| 123 | + o = rearrange(o, 'b h n d -> b n (h d)') |
| 124 | + |
| 125 | + padded_out = attention.to_out(o) |
| 126 | + |
| 127 | + padded_out = gather_seq(padded_out) |
| 128 | + |
| 129 | + zig_zag_out = remove_pad(padded_out) |
| 130 | + |
| 131 | + zig_zag_out.mean().backward() |
| 132 | + |
| 133 | + # validate output is the same for sequence split across machines vs without |
| 134 | + |
| 135 | + if rank == 0: |
| 136 | + out = out.cpu() |
| 137 | + zig_zag_out = zig_zag_out.cpu() |
| 138 | + |
| 139 | + output_atol = 1e-2 if use_cuda else 1e-6 |
| 140 | + |
| 141 | + assert torch.allclose(out, zig_zag_out, atol = output_atol), 'output is not the same' |
| 142 | + |
| 143 | + # validate gradients is the same |
| 144 | + |
| 145 | + regular_input_grad = regular_input.grad |
| 146 | + zig_zag_input_grad = zig_zag_input.grad |
| 147 | + |
| 148 | + assert torch.allclose( |
| 149 | + regular_input_grad, |
| 150 | + zig_zag_input_grad, |
| 151 | + atol = 1e-2 |
| 152 | + ), 'grad is not the same' |
| 153 | + |
| 154 | + print('✅ outputs and gradients are same between zig zag attention and regular attention') |
| 155 | + |
| 156 | + cleanup() |
| 157 | + |
| 158 | +@click.command() |
| 159 | +@click.option('--world-size', default = 8, help = 'number of machines / processes') |
| 160 | +@click.option('--batch-size', default = 2, help = 'test batch size') |
| 161 | +@click.option('--num-sharded-batches', default = 1, help = 'number of sharded batches') |
| 162 | +@click.option('--batch-size-var-len', is_flag = True, help = 'test variable lengthed batch sizes') |
| 163 | +@click.option('--use-cuda', is_flag = True, help = 'whether to test with CUDA and NCCL') |
| 164 | +@click.option('--rotary', is_flag = True, help = 'whether to test with rotary embeddings') |
| 165 | +@click.option('--seq-len', default = 31, help = 'sequence length to test') |
| 166 | +@click.option('--model-dim', default = 8, help = 'model dimensions for testing') |
| 167 | +@click.option('--heads', default = 8, help = 'number of query attention heads') |
| 168 | +@click.option('--num-grouped-query-heads', default = 2, help = 'number of query attention head groups') |
| 169 | +@click.option('--dim-head', default = 16, help = 'model dimensions for testing') |
| 170 | +def test( |
| 171 | + world_size: int, |
| 172 | + batch_size: int, |
| 173 | + num_sharded_batches: int, |
| 174 | + batch_size_var_len: bool, |
| 175 | + use_cuda: bool, |
| 176 | + rotary: bool, |
| 177 | + seq_len: int, |
| 178 | + model_dim: int, |
| 179 | + heads: int, |
| 180 | + num_grouped_query_heads: int, |
| 181 | + dim_head: int, |
| 182 | +): |
| 183 | + assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} must be less than the number of cuda devices {torch.cuda.device_count()}' |
| 184 | + |
| 185 | + mp.spawn( |
| 186 | + start, |
| 187 | + args = ( |
| 188 | + world_size, |
| 189 | + batch_size, |
| 190 | + batch_size_var_len, |
| 191 | + seq_len, |
| 192 | + num_sharded_batches, |
| 193 | + model_dim, |
| 194 | + heads, |
| 195 | + num_grouped_query_heads, |
| 196 | + dim_head, |
| 197 | + use_cuda, |
| 198 | + rotary |
| 199 | + ), |
| 200 | + nprocs = world_size, |
| 201 | + join = True |
| 202 | + ) |
| 203 | + |
| 204 | +if __name__ == '__main__': |
| 205 | + test() |
0 commit comments