Skip to content

Commit f2f6d7a

Browse files
committed
fix test for cuda version of tree attn decode
1 parent cc1aea8 commit f2f6d7a

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

assert_tree_attn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def start(
6161
k = torch.randn(batch, heads, seq_len, dim)
6262
v = torch.randn(batch, heads, seq_len, dim)
6363

64+
if use_cuda:
65+
q, k, v = tuple(t.cuda(rank) for t in (q, k, v))
66+
6467
# easy forcing all q, k, v to be same across all device
6568

6669
dist.all_reduce(q)

0 commit comments

Comments
 (0)