Skip to content

Commit aae5b49

Browse files
authored
Wjx/opt a4w4 (#1564)
* opt a4w4 moe decode
* update
1 parent d9ffd9a commit aae5b49

File tree

8 files changed

18 additions and 16 deletions (+18 / -16 lines changed)

8 files changed

18 additions and 16 deletions (+18 / -16 lines changed)

aiter/ops/triton/mha.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -784,10 +784,10 @@ def flash_attn_with_kvcache(
784784
window_size: (left, right) local attention window; (-1,-1) = full.
785785
softcap: (float) currently must be 0.0 (backend limitation).
786786
num_splits: 0 or 1 only (backend limitation >1).
787-
rotary_cos/rotary_sin: Optional rotary embeddings (applied if provided) interleaving flag unused here.
787+
rotary_cos/rotary_sin: Optional rotary embeddings (applied if provided) - interleaving flag unused here.
788788
cache_batch_idx/cache_leftpad: Optional indexing / left padding metadata.
789789
block_table: Optional paging table mapping logical blocks for paged KV cache.
790-
alibi_slopes: (nheads,) or (batch,nheads) bias slopes (currently ignored if provided placeholder).
790+
alibi_slopes: (nheads,) or (batch,nheads) bias slopes (currently ignored if provided - placeholder).
791791
rotary_interleaved: Flag kept for parity (currently forwarded as True constant to backend which ignores it).
792792
return_softmax_lse: If True returns (out, lse) else out.
793793

aiter/ops/triton/pa_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def paged_attention_decode(
4848
) -> None:
4949
"""
5050
Paged attention decode with automatic V1/V2 dispatch and quantization support.
51-
V1 for short sequences (8192), V2 with sequence partitioning for longer sequences.
51+
V1 for short sequences (<=8192), V2 with sequence partitioning for longer sequences.
5252
5353
Args:
5454
output (torch.Tensor): Pre-allocated output with shape (num_seqs, num_q_heads, head_dim).

csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token
6565
"Out dtype only support BFloat16/Float16!")
6666

6767
int tokens = hidden_states.size(0);
68-
int sorted_size = sorted_token_ids.size(0);
68+
int sorted_size = std::min(int64_t(tokens * topk * block_m.value()), sorted_token_ids.size(0));
6969
int E = w1.size(0);
7070
int N = w1.size(1) / 2;
7171
int K = hidden_states.size(-1);
@@ -122,7 +122,7 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token
122122
"Out dtype only support BFloat16/Float16!")
123123

124124
int tokens = inter_states.size(0);
125-
int sorted_size = sorted_token_ids.size(0);
125+
int sorted_size = std::min(int64_t(tokens * topk * block_m.value()), sorted_token_ids.size(0));
126126
int E = w1.size(0);
127127
int N = w2.size(1);
128128
int K = inter_states.size(-1);

csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def name(self) -> str:
194194
0: kernelInstanceGEMM1( 256, 32, 128, 128, 1, 4, 3,),
195195
1: kernelInstanceGEMM1( 256, 64, 128, 128, 1, 4, 3,),
196196
2: kernelInstanceGEMM1( 256, 128, 128, 128, 1, 4, 3,),
197+
4: kernelInstanceGEMM1( 64, 32, 32, 128, 1, 1, 3,),
197198
# 3: kernelInstanceGEMM1( 256, 256, 128, 128, 2, 2, 3,),
198199
}
199200

csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4.cuh

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ void ck_moe_stage1_gemm(const hipStream_t& stream,
7777
// : 128;
7878
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
7979
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
80-
static constexpr ck::index_t CShuffleNLane = NPerBlock / 2 / NXDLPerWave; // 64
80+
static constexpr ck::index_t CShuffleNLane =
81+
BLOCKSIZE == 64 ? NPerBlock / NXDLPerWave : NPerBlock / 2 / NXDLPerWave; // 64
8182
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
8283
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
8384
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
@@ -97,17 +98,17 @@ void ck_moe_stage1_gemm(const hipStream_t& stream,
9798
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
9899
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
99100
///###### RCR
100-
< Row, Col, DsLayout, ELayout,
101+
< Row, Col, DsLayout, ELayout,
101102
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
102-
AElementOp, BElementOp, CDEElementOp, GemmSpec,
103-
32, BLOCKSIZE,
103+
AElementOp, BElementOp, CDEElementOp, GemmSpec,
104+
32, BLOCKSIZE,
104105
MPerBlock, NPerBlock, 128,
105106
AK1, BK1,
106107
MNPerXDL, MNPerXDL,
107108
MXDLPerWave, NXDLPerWave,
108109
S<K0_A, K0_M_A, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 1,
109110
S<K0_B, K0_N_B, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 1,
110-
2, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
111+
2, CShuffleNXDLPerWave, S<1, CShuffleNLane, 1, CShuffleMLane>, S<EVec, D0Vec, D1Vec>,
111112
ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on
112113
// clang-format on
113114

@@ -286,10 +287,10 @@ void ck_moe_stage2_gemm(const hipStream_t& stream,
286287
///#####| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
287288
///#####| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
288289
///##### RCR
289-
< Row, Col, DsLayout, ELayout,
290+
< Row, Col, DsLayout, ELayout,
290291
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
291292
AElementOp, BElementOp, CDEElementOp, GemmSpec,
292-
32, BLOCKSIZE,
293+
32, BLOCKSIZE,
293294
MPerBlock, NPerBlock, 128,
294295
AK1, BK1,
295296
MNPerXDL, MNPerXDL,
@@ -365,4 +366,4 @@ void ck_moe_stage2_gemm(const hipStream_t& stream,
365366
void *&num_valid_ids, \
366367
void *&out, \
367368
std::optional<void *> w2_scale, \
368-
std::optional<void *> a2_scale);
369+
std::optional<void *> a2_scale);

csrc/ck_gemm_moe_2stages_codegen/gen_instances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@
181181
{{
182182
if (block_m == 32)
183183
{{
184-
return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 32, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
184+
return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 64, 32, 32, 128/sizeof({A0DataType}), 1, 1, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
185185
}}
186186
else if (block_m == 64)
187187
{{

csrc/cpp_itfs/pa/pa_ragged.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
9090
}
9191
const int64_t query_loc = static_cast<int64_t>(seq_idx * MTP);
9292
const int* block_table_seq = kv_page_indices + kv_indptr[seq_idx];
93-
93+
9494
if constexpr (VERSION_ID == 0) // 0: GOLDEN VERSION
9595
{
9696
_paged_attention_kernel<scalar_t, cache_t, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, NUM_THREADS, ALIBI_ENABLED, GQA_RATIO, MTP, AttentionVariant, false>

csrc/pybind/fused_mrope_rms_pybind.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
33
#include "rocm_ops.hpp"
44
#include "fused_mrope_rms.h"
55

0 commit comments

Comments
 (0)