@@ -1406,51 +1406,51 @@ namespace py = pybind11;
14061406 py::arg(" stride0" ), \
14071407 py::arg(" stride1" ));
14081408
1409- #define MLA_METADATA_PYBIND \
1410- m.def(" get_mla_metadata_v1" , \
1411- &get_mla_metadata_v1, \
1412- " get_mla_metadata_v1" , \
1413- py::arg (" seqlens_qo_indptr" ), \
1414- py::arg(" seqlens_kv_indptr" ), \
1415- py::arg(" num_heads_per_head_k" ), \
1416- py::arg(" num_heads_k" ), \
1417- py::arg(" is_causal" ), \
1418- py::arg(" work_metadata_ptrs" ), \
1419- py::arg(" work_info_set" ), \
1420- py::arg(" work_indptr" ), \
1421- py::arg(" reduce_indptr" ), \
1422- py::arg(" reduce_final_map" ), \
1423- py::arg(" reduce_partial_map" ), \
1424- py::arg(" kv_granularity" ) = 16, \
1425- py::arg(" max_seqlen_qo" ) = -1, \
1426- py::arg(" uni_seqlen_qo" ) = -1, \
1427- py::arg(" fast_mode" ) = true, \
1428- py::arg(" topk" ) = -1, \
1429- py::arg(" max_split_per_batch" ) = -1, \
1430- py::arg(" dtype_q" ) = std::nullopt, \
1431- py::arg(" dtype_kv" ) = std::nullopt); \
1409+ #define MLA_METADATA_PYBIND \
1410+ m.def(" get_mla_metadata_v1" , \
1411+ &get_mla_metadata_v1, \
1412+ " get_mla_metadata_v1" , \
1413+ py::arg (" seqlens_qo_indptr" ), \
1414+ py::arg(" seqlens_kv_indptr" ), \
1415+ py::arg(" num_heads_per_head_k" ), \
1416+ py::arg(" num_heads_k" ), \
1417+ py::arg(" is_causal" ), \
1418+ py::arg(" work_metadata_ptrs" ), \
1419+ py::arg(" work_info_set" ), \
1420+ py::arg(" work_indptr" ), \
1421+ py::arg(" reduce_indptr" ), \
1422+ py::arg(" reduce_final_map" ), \
1423+ py::arg(" reduce_partial_map" ), \
1424+ py::arg(" kv_granularity" ) = 16, \
1425+ py::arg(" max_seqlen_qo" ) = -1, \
1426+ py::arg(" uni_seqlen_qo" ) = -1, \
1427+ py::arg(" fast_mode" ) = true, \
1428+ py::arg(" topk" ) = -1, \
1429+ py::arg(" max_split_per_batch" ) = -1, \
1430+ py::arg(" dtype_q" ) = std::nullopt, \
1431+ py::arg(" dtype_kv" ) = std::nullopt); \
14321432 m.def(" get_mla_metadata_v1_no_redundant" , &get_mla_metadata_v1_no_redundant);
14331433
1434- #define PA_METADATA_PYBIND \
1435- m.def(" get_pa_metadata_v1" , \
1436- &get_pa_metadata_v1, \
1437- " get_pa_metadata_v1" , \
1438- py::arg (" seqlens_qo_indptr" ), \
1439- py::arg(" pages_kv_indptr" ), \
1440- py::arg(" num_heads_per_head_k" ), \
1441- py::arg(" num_heads_k" ), \
1442- py::arg(" is_causal" ), \
1443- py::arg(" work_metadata_ptrs" ), \
1444- py::arg(" work_indptr" ), \
1445- py::arg(" work_info" ), \
1446- py::arg(" reduce_indptr" ), \
1447- py::arg(" reduce_final_map" ), \
1448- py::arg(" reduce_partial_map" ), \
1449- py::arg(" kv_granularity" ) = 16, \
1450- py::arg(" max_seqlen_qo" ) = -1, \
1451- py::arg(" uni_seqlen_qo" ) = -1, \
1452- py::arg(" fast_mode" ) = true, \
1453- py::arg(" topk" ) = -1, \
1434+ #define PA_METADATA_PYBIND \
1435+ m.def(" get_pa_metadata_v1" , \
1436+ &get_pa_metadata_v1, \
1437+ " get_pa_metadata_v1" , \
1438+ py::arg (" seqlens_qo_indptr" ), \
1439+ py::arg(" pages_kv_indptr" ), \
1440+ py::arg(" num_heads_per_head_k" ), \
1441+ py::arg(" num_heads_k" ), \
1442+ py::arg(" is_causal" ), \
1443+ py::arg(" work_metadata_ptrs" ), \
1444+ py::arg(" work_indptr" ), \
1445+ py::arg(" work_info" ), \
1446+ py::arg(" reduce_indptr" ), \
1447+ py::arg(" reduce_final_map" ), \
1448+ py::arg(" reduce_partial_map" ), \
1449+ py::arg(" kv_granularity" ) = 16, \
1450+ py::arg(" max_seqlen_qo" ) = -1, \
1451+ py::arg(" uni_seqlen_qo" ) = -1, \
1452+ py::arg(" fast_mode" ) = true, \
1453+ py::arg(" topk" ) = -1, \
14541454 py::arg(" max_split_per_batch" ) = -1);
14551455
14561456#define MLA_REDUCE_PYBIND \
@@ -1462,13 +1462,14 @@ namespace py = pybind11;
14621462 py::arg(" reduce_indptr" ), \
14631463 py::arg(" reduce_final_map" ), \
14641464 py::arg(" reduce_partial_map" ), \
1465+ py::arg(" max_seqlen_q" ), \
14651466 py::arg(" final_output" ), \
14661467 py::arg(" final_lse" ) = std::nullopt);
14671468
1468- #define TOPK_PLAIN_PYBIND \
1469- m.def(" topk_plain" , \
1470- &topk_plain, \
1471- py::arg (" values" ), \
1472- py::arg(" topk_ids" ), \
1473- py::arg(" topk" ), \
1469+ #define TOPK_PLAIN_PYBIND \
1470+ m.def(" topk_plain" , \
1471+ &topk_plain, \
1472+ py::arg (" values" ), \
1473+ py::arg(" topk_ids" ), \
1474+ py::arg(" topk" ), \
14741475 py::arg(" largest" ));
0 commit comments