Skip to content

Commit c42c246

Browse files
authored
Enhance MLA Reduce for Prefill (#1562)
* Divide the workload of a batch whose seqlen > 1 across multiple work groups if there are enough CUs.
* format
* Add simple pipeline
* Fix issues raised by copilot
* Fix trivial issues and add nhead=1, hdim=128 cases.
* simplify template args
* add more group cases
* Remove unnecessary memory check.
1 parent 472513d commit c42c246

File tree

5 files changed

+378
-252
lines changed

5 files changed

+378
-252
lines changed

aiter/mla.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def mla_decode_fwd(
317317
reduce_indptr,
318318
reduce_final_map,
319319
reduce_partial_map,
320+
max_seqlen_q,
320321
o,
321322
final_lse,
322323
)

aiter/ops/attention.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def pa_reduce_v1(
176176
reduce_indptr: torch.Tensor,
177177
reduce_final_map: Optional[torch.Tensor],
178178
reduce_partial_map: torch.Tensor,
179+
max_seqlen_q: int,
179180
final_output: torch.Tensor,
180181
final_lse: Optional[torch.Tensor] = None,
181182
) -> None:
@@ -185,6 +186,7 @@ def pa_reduce_v1(
185186
reduce_indptr,
186187
reduce_final_map,
187188
reduce_partial_map,
189+
max_seqlen_q,
188190
final_output,
189191
final_lse,
190192
)
@@ -252,6 +254,7 @@ def pa_persistent_fwd(
252254
reduce_indptr,
253255
reduce_final_map,
254256
reduce_partial_map,
257+
max_qlen,
255258
output,
256259
final_lse,
257260
)
@@ -757,6 +760,7 @@ def mla_reduce_v1(
757760
reduce_indptr: torch.Tensor,
758761
reduce_final_map: Optional[torch.Tensor],
759762
reduce_partial_map: torch.Tensor,
763+
max_seqlen_q: int,
760764
final_output: torch.Tensor,
761765
final_lse: Optional[torch.Tensor] = None,
762766
) -> None: ...

csrc/include/mla.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ void mla_reduce_v1(const torch::Tensor& partial_output,
6868
const torch::Tensor& reduce_indptr,
6969
const std::optional<torch::Tensor>& reduce_final_map,
7070
const torch::Tensor& reduce_partial_map,
71+
const int max_seqlen_q,
7172
torch::Tensor& final_output,
7273
std::optional<torch::Tensor>& final_lse);
7374

csrc/include/rocm_ops.hpp

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,51 +1406,51 @@ namespace py = pybind11;
14061406
py::arg("stride0"), \
14071407
py::arg("stride1"));
14081408

1409-
#define MLA_METADATA_PYBIND \
1410-
m.def("get_mla_metadata_v1", \
1411-
&get_mla_metadata_v1, \
1412-
"get_mla_metadata_v1", \
1413-
py::arg("seqlens_qo_indptr"), \
1414-
py::arg("seqlens_kv_indptr"), \
1415-
py::arg("num_heads_per_head_k"), \
1416-
py::arg("num_heads_k"), \
1417-
py::arg("is_causal"), \
1418-
py::arg("work_metadata_ptrs"), \
1419-
py::arg("work_info_set"), \
1420-
py::arg("work_indptr"), \
1421-
py::arg("reduce_indptr"), \
1422-
py::arg("reduce_final_map"), \
1423-
py::arg("reduce_partial_map"), \
1424-
py::arg("kv_granularity") = 16, \
1425-
py::arg("max_seqlen_qo") = -1, \
1426-
py::arg("uni_seqlen_qo") = -1, \
1427-
py::arg("fast_mode") = true, \
1428-
py::arg("topk") = -1, \
1429-
py::arg("max_split_per_batch") = -1, \
1430-
py::arg("dtype_q") = std::nullopt, \
1431-
py::arg("dtype_kv") = std::nullopt); \
1409+
#define MLA_METADATA_PYBIND \
1410+
m.def("get_mla_metadata_v1", \
1411+
&get_mla_metadata_v1, \
1412+
"get_mla_metadata_v1", \
1413+
py::arg("seqlens_qo_indptr"), \
1414+
py::arg("seqlens_kv_indptr"), \
1415+
py::arg("num_heads_per_head_k"), \
1416+
py::arg("num_heads_k"), \
1417+
py::arg("is_causal"), \
1418+
py::arg("work_metadata_ptrs"), \
1419+
py::arg("work_info_set"), \
1420+
py::arg("work_indptr"), \
1421+
py::arg("reduce_indptr"), \
1422+
py::arg("reduce_final_map"), \
1423+
py::arg("reduce_partial_map"), \
1424+
py::arg("kv_granularity") = 16, \
1425+
py::arg("max_seqlen_qo") = -1, \
1426+
py::arg("uni_seqlen_qo") = -1, \
1427+
py::arg("fast_mode") = true, \
1428+
py::arg("topk") = -1, \
1429+
py::arg("max_split_per_batch") = -1, \
1430+
py::arg("dtype_q") = std::nullopt, \
1431+
py::arg("dtype_kv") = std::nullopt); \
14321432
m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant);
14331433

1434-
#define PA_METADATA_PYBIND \
1435-
m.def("get_pa_metadata_v1", \
1436-
&get_pa_metadata_v1, \
1437-
"get_pa_metadata_v1", \
1438-
py::arg("seqlens_qo_indptr"), \
1439-
py::arg("pages_kv_indptr"), \
1440-
py::arg("num_heads_per_head_k"), \
1441-
py::arg("num_heads_k"), \
1442-
py::arg("is_causal"), \
1443-
py::arg("work_metadata_ptrs"), \
1444-
py::arg("work_indptr"), \
1445-
py::arg("work_info"), \
1446-
py::arg("reduce_indptr"), \
1447-
py::arg("reduce_final_map"), \
1448-
py::arg("reduce_partial_map"), \
1449-
py::arg("kv_granularity") = 16, \
1450-
py::arg("max_seqlen_qo") = -1, \
1451-
py::arg("uni_seqlen_qo") = -1, \
1452-
py::arg("fast_mode") = true, \
1453-
py::arg("topk") = -1, \
1434+
#define PA_METADATA_PYBIND \
1435+
m.def("get_pa_metadata_v1", \
1436+
&get_pa_metadata_v1, \
1437+
"get_pa_metadata_v1", \
1438+
py::arg("seqlens_qo_indptr"), \
1439+
py::arg("pages_kv_indptr"), \
1440+
py::arg("num_heads_per_head_k"), \
1441+
py::arg("num_heads_k"), \
1442+
py::arg("is_causal"), \
1443+
py::arg("work_metadata_ptrs"), \
1444+
py::arg("work_indptr"), \
1445+
py::arg("work_info"), \
1446+
py::arg("reduce_indptr"), \
1447+
py::arg("reduce_final_map"), \
1448+
py::arg("reduce_partial_map"), \
1449+
py::arg("kv_granularity") = 16, \
1450+
py::arg("max_seqlen_qo") = -1, \
1451+
py::arg("uni_seqlen_qo") = -1, \
1452+
py::arg("fast_mode") = true, \
1453+
py::arg("topk") = -1, \
14541454
py::arg("max_split_per_batch") = -1);
14551455

14561456
#define MLA_REDUCE_PYBIND \
@@ -1462,13 +1462,14 @@ namespace py = pybind11;
14621462
py::arg("reduce_indptr"), \
14631463
py::arg("reduce_final_map"), \
14641464
py::arg("reduce_partial_map"), \
1465+
py::arg("max_seqlen_q"), \
14651466
py::arg("final_output"), \
14661467
py::arg("final_lse") = std::nullopt);
14671468

1468-
#define TOPK_PLAIN_PYBIND \
1469-
m.def("topk_plain", \
1470-
&topk_plain, \
1471-
py::arg("values"), \
1472-
py::arg("topk_ids"), \
1473-
py::arg("topk"), \
1469+
#define TOPK_PLAIN_PYBIND \
1470+
m.def("topk_plain", \
1471+
&topk_plain, \
1472+
py::arg("values"), \
1473+
py::arg("topk_ids"), \
1474+
py::arg("topk"), \
14741475
py::arg("largest"));

0 commit comments

Comments
 (0)