@@ -142,6 +142,7 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
142142};
143143
144144struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
145+ // TODO: remove once we generate new quick tuning list
145146 constexpr InitParamsAccel (int64_t mPerBlock , int64_t nPerBlock,
146147 int64_t kPerBlock , int64_t mPerWave ,
147148 int64_t nPerWave, int64_t mnPerXdl, int64_t kPack ,
@@ -152,12 +153,28 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
152153 gemmNPerWave (nPerWave), gemmMnPerXdl(mnPerXdl),
153154 gemmNPerWaveOrMnPerXdl (0 ), gemmKPack(kPack ), splitKFactor(splitKFactor),
154155 gemmScheduleVersion (scheduleVersion), outputSwizzle(outputSwizzle),
156+ wavesPerEU (0 ), gridGroupSize(0 ),
157+ gemmAThreadCopyMoreGemmK (aThreadCopyMoreGemmK),
158+ gemmBThreadCopyMoreGemmKPack (bThreadCopyMoreGemmKPack) {}
159+
160+ constexpr InitParamsAccel (int64_t mPerBlock , int64_t nPerBlock,
161+ int64_t kPerBlock , int64_t mPerWave ,
162+ int64_t nPerWave, int64_t mnPerXdl, int64_t kPack ,
163+ int64_t splitKFactor, int64_t scheduleVersion,
164+ int64_t outputSwizzle, int64_t wavesPerEU,
165+ int64_t gridGroupSize, bool aThreadCopyMoreGemmK,
166+ bool bThreadCopyMoreGemmKPack)
167+ : InitParams{mPerBlock , nPerBlock, kPerBlock }, gemmMPerWave(mPerWave ),
168+ gemmNPerWave (nPerWave), gemmMnPerXdl(mnPerXdl),
169+ gemmNPerWaveOrMnPerXdl (0 ), gemmKPack(kPack ), splitKFactor(splitKFactor),
170+ gemmScheduleVersion (scheduleVersion), outputSwizzle(outputSwizzle),
171+ wavesPerEU (wavesPerEU), gridGroupSize(gridGroupSize),
155172 gemmAThreadCopyMoreGemmK (aThreadCopyMoreGemmK),
156173 gemmBThreadCopyMoreGemmKPack (bThreadCopyMoreGemmKPack) {}
157174
158175 constexpr InitParamsAccel ()
159- : InitParamsAccel(0LL , 0LL , 0LL , 0LL , 0LL , 0LL , 0LL , 1LL , 1LL , 2LL , false ,
160- false ) {}
176+ : InitParamsAccel(0LL , 0LL , 0LL , 0LL , 0LL , 0LL , 0LL , 1LL , 1LL , 2LL , 0LL ,
177+ 0LL , false , false ) {}
161178
162179 InitParamsAccel (MfmaGemmParamsAttr attr)
163180 : InitParams{attr.getMPerBlock (), attr.getNPerBlock (),
@@ -167,6 +184,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
167184 gemmKPack (attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
168185 gemmScheduleVersion (attr.getScheduleVersion()),
169186 outputSwizzle (attr.getOutputSwizzle()),
187+ wavesPerEU (attr.getWavesPerEU()),
188+ gridGroupSize (attr.getGridGroupSize()),
170189 gemmAThreadCopyMoreGemmK (attr.getForceUnroll()),
171190 gemmBThreadCopyMoreGemmKPack (false ) {};
172191
@@ -178,6 +197,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
178197 gemmKPack (attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
179198 gemmScheduleVersion (attr.getScheduleVersion()),
180199 outputSwizzle (attr.getOutputSwizzle()),
200+ wavesPerEU (attr.getWavesPerEU()),
201+ gridGroupSize (attr.getGridGroupSize()),
181202 gemmAThreadCopyMoreGemmK (attr.getForceUnroll()),
182203 gemmBThreadCopyMoreGemmKPack (false ) {};
183204
@@ -191,6 +212,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
191212 int64_t splitKFactor;
192213 int64_t gemmScheduleVersion;
193214 int64_t outputSwizzle;
215+ int64_t wavesPerEU;
216+ int64_t gridGroupSize;
194217 bool gemmAThreadCopyMoreGemmK;
195218 bool gemmBThreadCopyMoreGemmKPack;
196219
@@ -214,6 +237,10 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
214237 f (self.gemmScheduleVersion );
215238 f (self.outputSwizzle );
216239 }
240+ if (self.version >= Version::V4) {
241+ f (self.wavesPerEU );
242+ f (self.gridGroupSize );
243+ }
217244 f (self.gemmAThreadCopyMoreGemmK );
218245 f (self.gemmBThreadCopyMoreGemmKPack );
219246 }
0 commit comments