
Commit 77e3733

add logic for llama3 context parallelism scheme
1 parent 2a925e5 commit 77e3733

6 files changed: +384 -21 lines changed

README.md

Lines changed: 11 additions & 0 deletions
@@ -195,4 +195,15 @@ $ python assert_tree_attn.py --use-cuda --seq-len 8192
}
```

```bibtex
@article{Dubey2024TheL3,
title = {The Llama 3 Herd of Models},
author = {Abhimanyu Dubey and Abhinav Jauhri and Abhinav Pandey and Abhishek Kadian and Ahmad Al-Dahle and Aiesha Letman and Akhil Mathur and Alan Schelten and Amy Yang and Angela Fan and Anirudh Goyal and Anthony Hartshorn and Aobo Yang and Archi Mitra and Archie Sravankumar and Artem Korenev and Arthur Hinsvark and Arun Rao and Aston Zhang and Aurelien Rodriguez and Austen Gregerson and Ava Spataru and Baptiste Rozi{\`e}re and Bethany Biron and Binh Tang and Bobbie Chern and Charlotte Caucheteux and Chaya Nayak and Chloe Bi and Chris Marra and Chris McConnell and Christian Keller and Christophe Touret and Chunyang Wu and Corinne Wong and Cristian Cant{\'o}n Ferrer and Cyrus Nikolaidis and Damien Allonsius and Daniel Song and Danielle Pintz and Danny Livshits and David Esiobu and Dhruv Choudhary and Dhruv Mahajan and Diego Garcia-Olano and Diego Perino and Dieuwke Hupkes and Egor Lakomkin and Ehab A. AlBadawy and Elina Lobanova and Emily Dinan and Eric Michael Smith and Filip Radenovic and Frank Zhang and Gabriele Synnaeve and Gabrielle Lee and Georgia Lewis Anderson and Graeme Nail and Gr{\'e}goire Mialon and Guanglong Pang and Guillem Cucurell and Hailey Nguyen and Hannah Korevaar and Hu Xu and Hugo Touvron and Iliyan Zarov and Imanol Arrieta Ibarra and Isabel M. Kloumann and Ishan Misra and Ivan Evtimov and Jade Copet and Jaewon Lee and Jan Laurens Geffert and Jana Vranes and Jason Park and Jay Mahadeokar and Jeet Shah and Jelmer van der Linde and Jennifer Billock and Jenny Hong and Jenya Lee and Jeremy Fu and Jianfeng Chi and Jianyu Huang and Jiawen Liu and Jie Wang and Jiecao Yu and Joanna Bitton and Joe Spisak and Jongsoo Park and Joseph Rocca and Joshua Johnstun and Joshua Saxe and Ju-Qing Jia and Kalyan Vasuden Alwala and K. Upasani and Kate Plawiak and Keqian Li and Kenneth Heafield and Kevin Stone and Khalid El-Arini and Krithika Iyer and Kshitiz Malik and Kuenley Chiu and Kunal Bhalla and Lauren Rantala-Yeary and Laurens van der Maaten and Lawrence Chen and Liang Tan and Liz Jenkins and Louis Martin and Lovish Madaan and Lubo Malo and Lukas Blecher and Lukas Landzaat and Luke de Oliveira and Madeline C. Muzzi and Mahesh Babu Pasupuleti and Mannat Singh and Manohar Paluri and Marcin Kardas and Mathew Oldham and Mathieu Rita and Maya Pavlova and Melissa Hall Melanie Kambadur and Mike Lewis and Min Si and Mitesh Kumar Singh and Mona Hassan and Naman Goyal and Narjes Torabi and Nikolay Bashlykov and Nikolay Bogoychev and Niladri S.
Chatterji and Olivier Duchenne and Onur cCelebi and Patrick Alrassy and Pengchuan Zhang and Pengwei Li and Petar Vasi{\'c} and Peter Weng and Prajjwal Bhargava and Pratik Dubal and Praveen Krishnan and Punit Singh Koura and Puxin Xu and Qing He and Qingxiao Dong and Ragavan Srinivasan and Raj Ganapathy and Ramon Calderer and Ricardo Silveira Cabral and Robert Stojnic and Roberta Raileanu and Rohit Girdhar and Rohit Patel and Romain Sauvestre and Ronnie Polidoro and Roshan Sumbaly and Ross Taylor and Ruan Silva and Rui Hou and Rui Wang and Saghar Hosseini and Sahana Chennabasappa and Sanjay Singh and Sean Bell and Seohyun Sonia Kim and Sergey Edunov and Shaoliang Nie and Sharan Narang and Sharath Chandra Raparthy and Sheng Shen and Shengye Wan and Shruti Bhosale and Shun Zhang and Simon Vandenhende and Soumya Batra and Spencer Whitman and Sten Sootla and Stephane Collot and Suchin Gururangan and Sydney Borodinsky and Tamar Herman and Tara Fowler and Tarek Sheasha and Thomas Georgiou and Thomas Scialom and Tobias Speckbacher and Todor Mihaylov and Tong Xiao and Ujjwal Karn and Vedanuj Goswami and Vibhor Gupta and Vignesh Ramanathan and Viktor Kerkez and Vincent Gonguet and Virginie Do and Vish Vogeti and Vladan Petrovic and Weiwei Chu and Wenhan Xiong and Wenyin Fu and Whitney Meers and Xavier Martinet and Xiaodong Wang and Xiaoqing Ellen Tan and Xinfeng Xie and Xuchao Jia and Xuewei Wang and Yaelle Goldschlag and Yashesh Gaur and Yasmine Babaei and Yiqian Wen and Yiwen Song and Yuchen Zhang and Yue Li and Yuning Mao and Zacharie Delpierre Coudert and Zhengxu Yan and Zhengxing Chen and Zoe Papakipos and Aaditya K. Singh and Aaron Grattafiori and Abha Jain and Adam Kelsey and Adam Shajnfeld and Adi Gangidi and Adolfo Victoria and Ahuva Goldstand and Ajay Menon and Ajay Sharma and Alex Boesenberg and Alex Vaughan and Alexei Baevski and Allie Feinstein and Amanda Kallet and Amit Sangani and Anam Yunus and Andrei Lupu and Andres Alvarado and Andrew Caples and Andrew Gu and Andrew Ho and Andrew Poulton and Andrew Ryan and Ankit Ramchandani and Annie Franco and Aparajita Saraf and Arkabandhu Chowdhury and Ashley Gabriel and Ashwin Bharambe and Assaf Eisenman and Azadeh Yazdan and Beau James and Ben Maurer and Ben Leonhardi and Bernie Huang and Beth Loyd and Beto De Paola and Bhargavi Paranjape and Bing Liu and Bo Wu and Boyu Ni and Braden Hancock and Bram Wasti and Brandon Spence and Brani Stojkovic and Brian Gamido and Britt Montalvo and Carl Parker and Carly Burton and Catalina Mejia and Changhan Wang and Changkyu Kim and Chao Zhou and Chester Hu and Ching-Hsiang Chu and Chris Cai and Chris Tindal and Christoph Feichtenhofer and Damon Civin and Dana Beaty and Daniel Kreymer and Shang-Wen Li and Danny Wyatt and David Adkins and David Xu and Davide Testuggine and Delia David and Devi Parikh and Diana Liskovich and Didem Foss and Dingkang Wang and Duc Le and Dustin Holland and Edward Dowling and Eissa Jamil and Elaine Montgomery and Eleonora Presani and Emily Hahn and Emily Wood and Erik Brinkman and Esteban Arcaute and Evan Dunbar and Evan Smothers and Fei Sun and Felix Kreuk and Feng Tian and Firat Ozgenel and Francesco Caggioni and Francisco Guzm'an and Frank J. Kanayet and Frank Seide and Gabriela Medina Florez and Gabriella Schwarz and Gada Badeer and Georgia Swee and Gil Halpern and Govind Thattai and Grant Herman and Grigory G. 
Sizov and Guangyi Zhang and Guna Lakshminarayanan and Hamid Shojanazeri and Han Zou and Hannah Wang and Han Zha and Haroun Habeeb and Harrison Rudolph and Helen Suk and Henry Aspegren and Hunter Goldman and Igor Molybog and Igor Tufanov and Irina-Elena Veliche and Itai Gat and Jake Weissman and James Geboski and James Kohli and Japhet Asher and Jean-Baptiste Gaya and Jeff Marcus and Jeff Tang and Jennifer Chan and Jenny Zhen and Jeremy Reizenstein and Jeremy Teboul and Jessica Zhong and Jian Jin and Jingyi Yang and Joe Cummings and Jon Carvill and Jon Shepard and Jonathan McPhie and Jonathan Torres and Josh Ginsburg and Junjie Wang and Kaixing(Kai) Wu and U KamHou and Karan Saxena and Karthik Prasad and Kartikay Khandelwal and Katayoun Zand and Kathy Matosich and Kaushik Veeraraghavan and Kelly Michelena and Keqian Li and Kun Huang and Kunal Chawla and Kushal Lakhotia and Kyle Huang and Lailin Chen and Lakshya Garg and A Lavender and Leandro Silva and Lee Bell and Lei Zhang and Liangpeng Guo and Licheng Yu and Liron Moshkovich and Luca Wehrstedt and Madian Khabsa and Manav Avalani and Manish Bhatt and Maria Tsimpoukelli and Martynas Mankus and Matan Hasson and Matthew Lennie and Matthias Reso and Maxim Groshev and Maxim Naumov and Maya Lathi and Meghan Keneally and Michael L. Seltzer and Michal Valko and Michelle Restrepo and Mihir Patel and Mik Vyatskov and Mikayel Samvelyan and Mike Clark and Mike Macey and Mike Wang and Miquel Jubert Hermoso and Mo Metanat and Mohammad Rastegari and Munish Bansal and Nandhini Santhanam and Natascha Parks and Natasha White and Navyata Bawa and Nayan Singhal and Nick Egebo and Nicolas Usunier and Nikolay Pavlovich Laptev and Ning Dong and Ning Zhang and Norman Cheng and Oleg Chernoguz and Olivia Hart and Omkar Salpekar and Ozlem Kalinli and Parkin Kent and Parth Parekh and Paul Saab and Pavan Balaji and Pedro Rittner and Philip Bontrager and Pierre Roux and Piotr Doll{\'a}r and Polina Zvyagina and Prashant Ratanchandani and Pritish Yuvraj and Qian Liang and Rachad Alao and Rachel Rodriguez and Rafi Ayub and Raghotham Murthy and Raghu Nayani and Rahul Mitra and Raymond Li and Rebekkah Hogan and Robin Battey and Rocky Wang and Rohan Maheswari and Russ Howes and Ruty Rinott and Sai Jayesh Bondu and Samyak Datta and Sara Chugh and Sara Hunt and Sargun Dhillon and Sasha Sidorov and Satadru Pan and Saurabh Verma and Seiji Yamamoto and Sharadh Ramaswamy and Shaun Lindsay and Sheng Feng and Shenghao Lin and Shengxin Cindy Zha and Shiva Shankar and Shuqiang Zhang and Sinong Wang and Sneha Agarwal and Soji Sajuyigbe and Soumith Chintala and Stephanie Max and Stephen Chen and Steve Kehoe and Steve Satterfield and Sudarshan Govindaprasad and Sumit Gupta and Sung-Bae Cho and Sunny Virk and Suraj Subramanian and Sy Choudhury and Sydney Goldman and Tal Remez and Tamar Glaser and Tamara Best and Thilo Kohler and Thomas Robinson and Tianhe Li and Tianjun Zhang and Tim Matthews and Timothy Chou and Tzook Shaked and Varun Vontimitta and Victoria Ajayi and Victoria Montanez and Vijai Mohan and Vinay Satish Kumar and Vishal Mangla and Vlad Ionescu and Vlad Andrei Poenaru and Vlad T. 
Mihailescu and Vladimir Ivanov and Wei Li and Wenchen Wang and Wenwen Jiang and Wes Bouaziz and Will Constable and Xia Tang and Xiaofang Wang and Xiaojian Wu and Xiaolan Wang and Xide Xia and Xilun Wu and Xinbo Gao and Yanjun Chen and Ye Hu and Ye Jia and Ye Qi and Yenda Li and Yilin Zhang and Ying Zhang and Yossi Adi and Youngjin Nam and Yu Wang and Yuchen Hao and Yundi Qian and Yuzi He and Zach Rait and Zachary DeVito and Zef Rosnbrick and Zhaoduo Wen and Zhenyu Yang and Zhiwei Zhao},
journal = {ArXiv},
year = {2024},
volume = {abs/2407.21783},
url = {https://api.semanticscholar.org/CorpusID:271571434}
}
```

*<a href="http://www.incompleteideas.net/IncIdeas/BitterLesson.html">The Bitter Lesson</a>* - Richard Sutton

assert_zig_zag.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
import os
import click
from math import ceil

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP

from ring_attention_pytorch import RingAttention
from ring_attention_pytorch.distributed import all_gather_variable_dim

from einops import rearrange

from ring_attention_pytorch.ring_attention import apply_rotary_pos_emb

from ring_attention_pytorch.zig_zag_attention import (
    zig_zag_pad_seq,
    zig_zag_attn,
    zig_zag_shard
)

def abs_diff(x, y):
    return (x - y).abs().amax()

def setup(
    rank,
    world_size,
    use_cuda
):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    backend = "gloo" if not use_cuda else "nccl"
    dist.init_process_group(backend, rank = rank, world_size = world_size)

    if use_cuda:
        torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    seq_len,
    num_sharded_batches,
    dim,
    heads,
    num_grouped_query_heads,
    dim_head,
    use_cuda,
    rotary
):
    setup(rank, world_size, use_cuda)

    attention = RingAttention(
        dim = dim,
        dim_head = dim_head,
        heads = heads,
        num_grouped_query_heads = num_grouped_query_heads,
        causal = True,
        rotary_embed = rotary,
        ring_attn = False,
        use_cuda_kernel = use_cuda
    )

    if batch_size_var_len:
        batch_size = batch_size + rank

    seq = torch.randn(batch_size, seq_len, dim)

    # move to cuda if needed

    if use_cuda:
        seq = seq.cuda(rank)
        attention.cuda(rank)

    # separate inputs for regular vs zig zag attention

    regular_input = seq.clone().requires_grad_()
    zig_zag_input = seq.clone().requires_grad_()

    # wrap

    ddp_attention = DDP(attention)

    # regular

    out = ddp_attention(regular_input)

    out.mean().backward()

    # zig zag

    padded_inp, remove_pad = zig_zag_pad_seq(zig_zag_input)
    (padded_inp, q_indices, kv_indices), gather_seq = zig_zag_shard(padded_inp, all_gather_batch = True)

    qkv = attention.to_qkv(padded_inp)

    q, k, v = rearrange(qkv, 'b n (h d) -> b h n d', d = dim_head).split(attention.qkv_head_breakdown, dim = -3)

    if rotary:
        pos_emb = attention.rotary_embed(q_indices)

        q = apply_rotary_pos_emb(pos_emb, q, head_dim_first = True)
        k = apply_rotary_pos_emb(pos_emb, k, head_dim_first = True)

    # causal mask

    causal_mask = q_indices[:, None] >= kv_indices[None, :]

    # attention

    o = zig_zag_attn(
        q, k, v,
        attn_mask = causal_mask
    )

    o = rearrange(o, 'b h n d -> b n (h d)')

    padded_out = attention.to_out(o)

    padded_out = gather_seq(padded_out)

    zig_zag_out = remove_pad(padded_out)

    zig_zag_out.mean().backward()

    # validate output is the same for sequence split across machines vs without

    if rank == 0:
        out = out.cpu()
        zig_zag_out = zig_zag_out.cpu()

        output_atol = 1e-2 if use_cuda else 1e-6

        assert torch.allclose(out, zig_zag_out, atol = output_atol), 'output is not the same'

        # validate gradients are the same

        regular_input_grad = regular_input.grad
        zig_zag_input_grad = zig_zag_input.grad

        assert torch.allclose(
            regular_input_grad,
            zig_zag_input_grad,
            atol = 1e-2
        ), 'grad is not the same'

        print('✅ outputs and gradients are the same between zig zag attention and regular attention')

    cleanup()

@click.command()
@click.option('--world-size', default = 8, help = 'number of machines / processes')
@click.option('--batch-size', default = 2, help = 'test batch size')
@click.option('--num-sharded-batches', default = 1, help = 'number of sharded batches')
@click.option('--batch-size-var-len', is_flag = True, help = 'test variable lengthed batch sizes')
@click.option('--use-cuda', is_flag = True, help = 'whether to test with CUDA and NCCL')
@click.option('--rotary', is_flag = True, help = 'whether to test with rotary embeddings')
@click.option('--seq-len', default = 31, help = 'sequence length to test')
@click.option('--model-dim', default = 8, help = 'model dimensions for testing')
@click.option('--heads', default = 8, help = 'number of query attention heads')
@click.option('--num-grouped-query-heads', default = 2, help = 'number of query attention head groups')
@click.option('--dim-head', default = 16, help = 'dimension per attention head')
def test(
    world_size: int,
    batch_size: int,
    num_sharded_batches: int,
    batch_size_var_len: bool,
    use_cuda: bool,
    rotary: bool,
    seq_len: int,
    model_dim: int,
    heads: int,
    num_grouped_query_heads: int,
    dim_head: int,
):
    assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} must not exceed the number of cuda devices {torch.cuda.device_count()}'

    mp.spawn(
        start,
        args = (
            world_size,
            batch_size,
            batch_size_var_len,
            seq_len,
            num_sharded_batches,
            model_dim,
            heads,
            num_grouped_query_heads,
            dim_head,
            use_cuda,
            rotary
        ),
        nprocs = world_size,
        join = True
    )

if __name__ == '__main__':
    test()
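For orientation, here is a minimal, self-contained sketch of the zig-zag sharding idea this test exercises (the Llama 3 style context parallelism named in the commit message): the padded sequence is cut into 2 × world_size chunks and each rank keeps one chunk from the front plus its mirror from the back, so the causal attention workload is roughly balanced across ranks. The chunk-assignment rule below is an illustrative assumption and not necessarily the exact layout produced by `zig_zag_shard`.

```python
import torch

def zig_zag_chunk_indices(seq_len, world_size, rank):
    # the (padded) sequence is cut into 2 * world_size chunks; rank r keeps chunk r
    # from the front and chunk (2 * world_size - 1 - r) from the back, so each rank
    # holds one early and one late block of the causal sequence
    assert seq_len % (2 * world_size) == 0, 'pad the sequence to a multiple of 2 * world_size first'
    chunks = torch.arange(seq_len).chunk(2 * world_size)
    return torch.cat((chunks[rank], chunks[2 * world_size - 1 - rank]))

if __name__ == '__main__':
    # with 16 positions over 4 ranks, rank 0 holds [0, 1, 14, 15] and rank 3 holds [6, 7, 8, 9]
    for rank in range(4):
        print(rank, zig_zag_chunk_indices(16, world_size = 4, rank = rank).tolist())
```

Pairing the earliest and latest chunks is what balances the work: under a causal mask, early positions attend to few keys and late positions to many, so a rank holding both ends does about as much compute as any other rank.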

ring_attention_pytorch/distributed.py

Lines changed: 2 additions & 0 deletions
@@ -125,3 +125,5 @@ def split_by_rank(x):

    sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
    return out, sizes

all_gather = AllGatherFunction.apply
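For context, exposing `AllGatherFunction.apply` as `all_gather` provides a gather that gradients can flow back through, which the zig-zag test above relies on when it gathers the sharded output before calling `.backward()`. Below is a hedged, generic sketch of such an autograd-aware all-gather (equal shard sizes along dim 0 assumed); it is not the library's `AllGatherFunction`, whose signature and behavior may differ.

```python
import torch
import torch.distributed as dist
from torch.autograd import Function

class NaiveAllGather(Function):
    # gather equal-sized shards along dim 0 from every rank, keeping autograd connectivity

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        ctx.rank, ctx.shard_len = dist.get_rank(), x.shape[0]
        gathered = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x.contiguous())
        return torch.cat(gathered, dim = 0)

    @staticmethod
    def backward(ctx, grad):
        # hand back only the gradient slice for this rank's shard; a production version
        # would typically reduce-scatter so contributions from other ranks are summed in
        start = ctx.rank * ctx.shard_len
        return grad[start:(start + ctx.shard_len)]

naive_all_gather = NaiveAllGather.apply
```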

ring_attention_pytorch/ring_attention.py

Lines changed: 24 additions & 19 deletions
@@ -4,7 +4,7 @@
 import torch
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.nn import Module, ModuleList

 from einops import rearrange, repeat
@@ -125,33 +125,36 @@ def device(self):
     def is_cuda(self):
         return self.inv_freq.is_cuda

-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     @beartype
     def forward(
         self,
-        seq_len: int
+        seq: int | Tensor
     ):
         device = self.device

         pos = None
+        if torch.is_tensor(seq):
+            pos = seq

-        if self.ring:
-            if self.striped:
-                buckets = 1 if self.is_cuda else self.buckets
-                ring_stride = get_world_size() * buckets
+        if not exists(pos):
+            if self.ring:
+                if self.striped:
+                    buckets = 1 if self.is_cuda else self.buckets
+                    ring_stride = get_world_size() * buckets

-                pos = torch.arange(seq_len // buckets, device = device)
-                pos = repeat(pos, 'n -> n b', b = buckets)
+                    pos = torch.arange(seq // buckets, device = device)
+                    pos = repeat(pos, 'n -> n b', b = buckets)

-                pos = pos * ring_stride
-                pos += torch.arange(buckets, device = device) + (get_rank() * buckets)
-                pos = rearrange(pos, 'n b -> (b n)')
+                    pos = pos * ring_stride
+                    pos += torch.arange(buckets, device = device) + (get_rank() * buckets)
+                    pos = rearrange(pos, 'n b -> (b n)')

+                else:
+                    pos = torch.arange(seq, device = device)
+                    pos += seq * get_rank()
             else:
-                pos = torch.arange(seq_len, device = device)
-                pos += seq_len * get_rank()
-        else:
-            pos = torch.arange(seq_len, device = device)
+                pos = torch.arange(seq, device = device)

         pos = pos.type_as(self.inv_freq)
         freqs = einsum('i , j -> i j', pos, self.inv_freq)
@@ -161,9 +164,11 @@ def rotate_half(x):
     x1, x2 = x.chunk(2, dim = -1)
     return torch.cat((-x2, x1), dim=-1)

-@autocast(enabled = False)
-def apply_rotary_pos_emb(pos, t):
-    pos = rearrange(pos, 'n d -> n 1 d')
+@autocast('cuda', enabled = False)
+def apply_rotary_pos_emb(pos, t, head_dim_first = False):
+    if not head_dim_first:
+        pos = rearrange(pos, 'n d -> n 1 d')
+
     return t * pos.cos() + rotate_half(t) * pos.sin()

 # batch to sequence sharding and back
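The new `head_dim_first` flag exists because the zig-zag test applies rotary embeddings to tensors already shaped `(batch, heads, seq, dim_head)`, whereas the original path expects `(batch, seq, heads, dim_head)` and therefore inserts a singleton head axis into the position tensor. A small self-contained sketch of that broadcasting difference (shapes and the stand-in angle tensor are illustrative, not taken from the library):

```python
import torch
from einops import rearrange

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(pos, t, head_dim_first = False):
    if not head_dim_first:
        # t is (batch, seq, heads, dim): give pos a singleton head axis, (n, d) -> (n, 1, d)
        pos = rearrange(pos, 'n d -> n 1 d')
    # when head_dim_first, t is (batch, heads, seq, dim) and (n, d) already lines up
    # with its last two axes, so no reshape is needed
    return t * pos.cos() + rotate_half(t) * pos.sin()

pos = torch.randn(128, 64)            # stand-in rotary angles, (seq, dim_head)
q_bhnd = torch.randn(2, 8, 128, 64)   # heads before sequence, as in the zig zag test
q_bnhd = torch.randn(2, 128, 8, 64)   # sequence before heads, as in the original path
out_a = apply_rotary(pos, q_bhnd, head_dim_first = True)
out_b = apply_rotary(pos, q_bnhd, head_dim_first = False)
assert out_a.shape == q_bhnd.shape and out_b.shape == q_bnhd.shape
```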

ring_attention_pytorch/ring_flash_attention_cuda.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 from torch import nn, einsum, Tensor
 from torch.autograd.function import Function
-from torch.cuda.amp import autocast
+from torch.amp import autocast

 from ring_attention_pytorch.ring import (
     ring_pass,
@@ -352,7 +352,7 @@ def backward(ctx, do):

 ring_flash_attn_cuda_ = RingFlashAttentionCUDAFunction.apply

-@autocast(enabled = False)
+@autocast('cuda', enabled = False)
 @beartype
 def ring_flash_attn_cuda(
     q: Tensor,
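This change mirrors the one in `ring_attention.py`: `torch.cuda.amp.autocast` is deprecated in recent PyTorch releases in favor of `torch.amp.autocast`, which takes the device type as an explicit first argument. A minimal sketch of the new form used as a decorator (the function name and body are illustrative only, not part of the library):

```python
import torch
from torch.amp import autocast

# keep the decorated computation out of mixed precision, as the attention kernels above do
@autocast('cuda', enabled = False)
def attend_full_precision(q, k, v):
    scores = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    return scores.softmax(dim = -1) @ v
```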
