@@ -77,7 +77,8 @@ void ck_moe_stage1_gemm(const hipStream_t& stream,
7777 // : 128;
7878 static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
7979 static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
80- static constexpr ck::index_t CShuffleNLane = NPerBlock / 2 / NXDLPerWave; // 64
80+ static constexpr ck::index_t CShuffleNLane =
81+ BLOCKSIZE == 64 ? NPerBlock / NXDLPerWave : NPerBlock / 2 / NXDLPerWave; // 64
8182 static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
8283 static constexpr ck::index_t AK1 = 16 / sizeof (A0DataType);
8384 static constexpr ck::index_t BK1 = 16 / sizeof (B0DataType);
@@ -97,17 +98,17 @@ void ck_moe_stage1_gemm(const hipStream_t& stream,
9798// /######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
9899// /######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
99100// /###### RCR
100- < Row, Col, DsLayout, ELayout,
101+ < Row, Col, DsLayout, ELayout,
101102 A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
102- AElementOp, BElementOp, CDEElementOp, GemmSpec,
103- 32 , BLOCKSIZE,
103+ AElementOp, BElementOp, CDEElementOp, GemmSpec,
104+ 32 , BLOCKSIZE,
104105 MPerBlock, NPerBlock, 128 ,
105106 AK1, BK1,
106107 MNPerXDL, MNPerXDL,
107108 MXDLPerWave, NXDLPerWave,
108109 S<K0_A, K0_M_A, 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , AK1, AK1, 1 ,
109110 S<K0_B, K0_N_B, 1 >, S<1 , 0 , 2 >, S<1 , 0 , 2 >, 2 , BK1, BK1, 1 ,
110- 2 , CShuffleNXDLPerWave, S<1 , 32 , 1 , 8 >, S<EVec, D0Vec, D1Vec>,
111+ 2 , CShuffleNXDLPerWave, S<1 , CShuffleNLane , 1 , CShuffleMLane >, S<EVec, D0Vec, D1Vec>,
111112 ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true , MulRoutedWeight, ck::index_t , A0DataType>; // clang-format on
112113 // clang-format on
113114
@@ -286,10 +287,10 @@ void ck_moe_stage2_gemm(const hipStream_t& stream,
286287// /#####| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
287288// /#####| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
288289// /##### RCR
289- < Row, Col, DsLayout, ELayout,
290+ < Row, Col, DsLayout, ELayout,
290291 A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
291292 AElementOp, BElementOp, CDEElementOp, GemmSpec,
292- 32 , BLOCKSIZE,
293+ 32 , BLOCKSIZE,
293294 MPerBlock, NPerBlock, 128 ,
294295 AK1, BK1,
295296 MNPerXDL, MNPerXDL,
@@ -365,4 +366,4 @@ void ck_moe_stage2_gemm(const hipStream_t& stream,
365366 void *&num_valid_ids, \
366367 void *&out, \
367368 std::optional<void *> w2_scale, \
368- std::optional<void *> a2_scale);
369+ std::optional<void *> a2_scale);
0 commit comments