Skip to content

Commit 31dc230

Browse files
committed
Add third phase to greedy tuning (outputSwizzle, wavesPerEU, gridGroupSize)
1 parent 31ea3ce commit 31dc230

File tree

15 files changed

+321
-60
lines changed

15 files changed

+321
-60
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,26 @@ def RockAccelTuningParamAttrInterface : AttrInterface<"RockAccelTuningParamAttrI
9090
/*args=*/(ins),
9191
/*methodBody=*/"",
9292
/*defaultImplementation=*/""
93+
>,
94+
InterfaceMethod<
95+
/*desc=*/[{
96+
Return param passed to the backend compiler: waves_per_eu.
97+
}],
98+
/*retType=*/"int64_t",
99+
/*methodName=*/"getWavesPerEU",
100+
/*args=*/(ins),
101+
/*methodBody=*/"",
102+
/*defaultImplementation=*/""
103+
>,
104+
InterfaceMethod<
105+
/*desc=*/[{
106+
Return the group size used when laying out the distribution of workgroups on the grid.
107+
}],
108+
/*retType=*/"int64_t",
109+
/*methodName=*/"getGridGroupSize",
110+
/*args=*/(ins),
111+
/*methodBody=*/"",
112+
/*defaultImplementation=*/""
93113
>
94114

95115
// TODO: more methods here as needed

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
280280
"int64_t":$nPerBlockG0, "int64_t":$kpackPerBlock, "int64_t":$mPerWave,
281281
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$kpack,
282282
"int64_t":$splitKFactor, "int64_t":$scheduleVersion,
283-
"int64_t":$outputSwizzle, "bool":$forceUnroll);
283+
"int64_t":$outputSwizzle, "int64_t":$wavesPerEU, "bool":$forceUnroll);
284284

285285
let extraClassDeclaration = [{
286286
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
@@ -296,13 +296,14 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
296296
+ Twine(getSplitKFactor()) + ","
297297
+ Twine(getScheduleVersion()) + ","
298298
+ Twine(getOutputSwizzle()) + ","
299+
+ Twine(getWavesPerEU()) + ","
299300
+ Twine(getForceUnroll())).toVector(perfStr);
300301
}
301302
AttnPerfConfigAttr withScheduleVersion(int64_t newScheduleVersion) const {
302303
return AttnPerfConfigAttr::get(
303304
getContext(), getMPerBlockG0(), getMPerBlockG1(), getNPerBlockG0(),
304305
getKpackPerBlock(), getMPerWave(), getNPerWave(), getMnPerXdl(), getKpack(),
305-
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getForceUnroll());
306+
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getWavesPerEU(), getForceUnroll());
306307
}
307308
}];
308309

@@ -321,19 +322,11 @@ def Rock_MfmaGemmParamsAttr
321322
let description = [{
322323
The tuning parameters for an mfma-based matrix multiplication.
323324
}];
324-
let parameters = (ins
325-
"int64_t":$kpackPerBlock,
326-
"int64_t":$mPerBlock,
327-
"int64_t":$nPerBlock,
328-
"int64_t":$kpack,
329-
"int64_t":$mPerWave,
330-
"int64_t":$nPerWave,
331-
"int64_t":$mnPerXdl,
332-
"int64_t":$splitKFactor,
333-
"int64_t":$scheduleVersion,
334-
"int64_t":$outputSwizzle,
335-
"bool":$forceUnroll
336-
);
325+
let parameters = (ins "int64_t":$kpackPerBlock, "int64_t":$mPerBlock,
326+
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
327+
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
328+
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
329+
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);
337330

338331
let extraClassDeclaration = [{
339332
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
@@ -347,6 +340,8 @@ def Rock_MfmaGemmParamsAttr
347340
+ Twine(getSplitKFactor()) + ","
348341
+ Twine(getScheduleVersion()) + ","
349342
+ Twine(getOutputSwizzle()) + ","
343+
+ Twine(getWavesPerEU()) + ","
344+
+ Twine(getGridGroupSize()) + ","
350345
+ Twine(getForceUnroll()) + ","
351346
+ "1") /* *ThreadCopyMore* */
352347
.toVector(perfStr);
@@ -367,7 +362,7 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn
367362
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
368363
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
369364
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
370-
"bool":$forceUnroll);
365+
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);
371366

372367
let extraClassDeclaration = [{
373368
void getPerfConfigStr(SmallVectorImpl<char> &perfStr) {
@@ -381,6 +376,8 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn
381376
+ Twine(getSplitKFactor()) + ","
382377
+ Twine(getScheduleVersion()) + ","
383378
+ Twine(getOutputSwizzle()) + ","
379+
+ Twine(getWavesPerEU()) + ","
380+
+ Twine(getGridGroupSize()) + ","
384381
+ Twine(getForceUnroll()) + ","
385382
+ "1") /* *ThreadCopyMore* */
386383
.toVector(perfStr);

mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
142142
};
143143

144144
struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
145+
// TODO: remove this legacy constructor once a new quick tuning list has been generated
145146
constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock,
146147
int64_t kPerBlock, int64_t mPerWave,
147148
int64_t nPerWave, int64_t mnPerXdl, int64_t kPack,
@@ -152,12 +153,28 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
152153
gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl),
153154
gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor),
154155
gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle),
156+
wavesPerEU(0), gridGroupSize(0),
157+
gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK),
158+
gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {}
159+
160+
constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock,
161+
int64_t kPerBlock, int64_t mPerWave,
162+
int64_t nPerWave, int64_t mnPerXdl, int64_t kPack,
163+
int64_t splitKFactor, int64_t scheduleVersion,
164+
int64_t outputSwizzle, int64_t wavesPerEU,
165+
int64_t gridGroupSize, bool aThreadCopyMoreGemmK,
166+
bool bThreadCopyMoreGemmKPack)
167+
: InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerWave(mPerWave),
168+
gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl),
169+
gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor),
170+
gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle),
171+
wavesPerEU(wavesPerEU), gridGroupSize(gridGroupSize),
155172
gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK),
156173
gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {}
157174

158175
constexpr InitParamsAccel()
159-
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, false,
160-
false) {}
176+
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, 0LL,
177+
0LL, false, false) {}
161178

162179
InitParamsAccel(MfmaGemmParamsAttr attr)
163180
: InitParams{attr.getMPerBlock(), attr.getNPerBlock(),
@@ -167,6 +184,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
167184
gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
168185
gemmScheduleVersion(attr.getScheduleVersion()),
169186
outputSwizzle(attr.getOutputSwizzle()),
187+
wavesPerEU(attr.getWavesPerEU()),
188+
gridGroupSize(attr.getGridGroupSize()),
170189
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
171190
gemmBThreadCopyMoreGemmKPack(false) {};
172191

@@ -178,6 +197,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
178197
gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
179198
gemmScheduleVersion(attr.getScheduleVersion()),
180199
outputSwizzle(attr.getOutputSwizzle()),
200+
wavesPerEU(attr.getWavesPerEU()),
201+
gridGroupSize(attr.getGridGroupSize()),
181202
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
182203
gemmBThreadCopyMoreGemmKPack(false) {};
183204

@@ -191,6 +212,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
191212
int64_t splitKFactor;
192213
int64_t gemmScheduleVersion;
193214
int64_t outputSwizzle;
215+
int64_t wavesPerEU;
216+
int64_t gridGroupSize;
194217
bool gemmAThreadCopyMoreGemmK;
195218
bool gemmBThreadCopyMoreGemmKPack;
196219

@@ -214,6 +237,10 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
214237
f(self.gemmScheduleVersion);
215238
f(self.outputSwizzle);
216239
}
240+
if (self.version >= Version::V4) {
241+
f(self.wavesPerEU);
242+
f(self.gridGroupSize);
243+
}
217244
f(self.gemmAThreadCopyMoreGemmK);
218245
f(self.gemmBThreadCopyMoreGemmKPack);
219246
}

mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct Serializable {
7171
}
7272

7373
bool checkVersionFormat(const std::string &s) {
74-
const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 12};
74+
const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 14};
7575
const int32_t versionIdx = static_cast<int32_t>(version);
7676
if (versionIdx < 1 || versionIdx >= static_cast<int32_t>(Version::Count)) {
7777
llvm_unreachable("Unknown version of the perfConfig");

mlir/lib/Conversion/RockToGPU/RockToGPU.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ void LowerRockOpsToGPUPass::runOnOperation() {
204204
gridSize = cast<IntegerAttr>(gridSizeAttr).getInt();
205205
gpuFunc.setKnownGridSizeAttr(b.getDenseI32ArrayAttr({gridSize, 1, 1}));
206206

207+
auto wavesPerEUAttr = theFunc->getAttr("waves_per_eu");
208+
if (wavesPerEUAttr) {
209+
gpuFunc->setAttr("waves_per_eu", wavesPerEUAttr);
210+
}
211+
207212
FailureOr<StringAttr> maybeArch = rock::getArch(theFunc);
208213
if (succeeded(maybeArch)) {
209214
gpuFunc->setAttr("arch", maybeArch.value());
@@ -391,7 +396,21 @@ void LowerRockOpsToGPUPass::runOnOperation() {
391396
gpuFunc->setAttr("rock.shared_buffer_size",
392397
b.getI32IntegerAttr(ldsUsage));
393398
}
394-
LLVM_DEBUG(llvm::dbgs() << "Attempting to set wavesPerEU...\n");
399+
// if waves_per_eu is set, use it
400+
if (gpuFunc->hasAttrOfType<IntegerAttr>("waves_per_eu")) {
401+
int64_t wavesPerEU =
402+
gpuFunc->getAttrOfType<IntegerAttr>("waves_per_eu").getInt();
403+
// zero means: use the heuristic instead
404+
if (wavesPerEU != 0) {
405+
gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU));
406+
LLVM_DEBUG(llvm::dbgs() << "Setting waves_per_eu using tuning param\n");
407+
// we are done
408+
return;
409+
}
410+
}
411+
412+
// "waves_per_eu" attribute is absent or zero, use heuristic
413+
LLVM_DEBUG(llvm::dbgs() << "Using heuristic to set wavesPerEU...\n");
395414
if (!gpuFunc->hasAttrOfType<IntegerAttr>("block_size")) {
396415
LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n");
397416
return;

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3187,11 +3187,12 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr,
31873187
int64_t splitKFactor = version > 1 ? params[lastIdx++] : 1;
31883188
int64_t scheduleVersion = version > 1 ? params[lastIdx++] : 1;
31893189
int64_t outputSwizzle = version > 1 ? params[lastIdx++] : 2;
3190+
int64_t wavesPerEU = isV3 ? params[lastIdx++] : 0; // 0 -> use heuristic
31903191
int64_t forceUnroll = params[expectedNumTokens - 1] == 1;
31913192
return AttnPerfConfigAttr::get(
31923193
perfConfigStrAttr.getContext(), mPerBlockG0, mPerBlockG1, nPerBlockG0,
31933194
kpackPerBlock, mPerWave, nPerWave, mnPerXdl, kpack, splitKFactor,
3194-
scheduleVersion, outputSwizzle, forceUnroll);
3195+
scheduleVersion, outputSwizzle, wavesPerEU, forceUnroll);
31953196
}
31963197

31973198
//===-----------------------------------------------------===//

mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ void AffixTuningParameters::affixTuningParametersImpl(
309309
Attribute params0 = op.getGemm0Params().value_or(nullptr);
310310
// set a default one if params is not provided
311311
StringAttr perfConfigStrAttr =
312-
builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,1");
312+
builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,0,1");
313313
if (!params0) {
314314
if (StringAttr mayBePerfConfigStrAttr =
315315
dyn_cast_or_null<StringAttr>(op->getAttr("perf_config"))) {

mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ GridCoordinates rock::layout::makeGroupedGridLayout(PatternRewriter &b,
8484
int64_t bitWidthOut = info.outputType.getIntOrFloatBitWidth();
8585
int64_t groupSize =
8686
std::ceil(std::sqrt(info.numCU)) * (bitWidthOut / bitWidthIn);
87+
// use gridGroupSize if it's not zero
88+
if (info.gridGroupSize != 0) {
89+
groupSize = info.gridGroupSize;
90+
LLVM_DEBUG(llvm::dbgs() << "Setting groupSize by using tuning params to "
91+
<< groupSize << "\n");
92+
} else {
93+
LLVM_DEBUG(llvm::dbgs()
94+
<< "Using heuristic to set groupSize to " << groupSize << "\n");
95+
}
8796

8897
Value mBlocksPerGroup = b.createOrFold<ConstantIndexOp>(loc, groupSize);
8998
Value blocksPerGroup =

mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct GridLayoutInfo {
5454
int64_t numCU;
5555
Type inputType;
5656
Type outputType;
57+
int64_t gridGroupSize;
5758
};
5859

5960
/// This function emits the right triplet of <group,block_m,block_n>

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,12 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<GridwiseGemmOp> {
532532
if (failed(maybeArch)) {
533533
return op.emitError("arch needs to be set.");
534534
}
535+
// always use heuristic for non-accel path
536+
int64_t gridGroupSize = 0;
535537
auto gridCoords = layout::makeGroupedGridLayout(
536538
b, loc, bid,
537-
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType},
539+
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType,
540+
gridGroupSize},
538541
maybeArch->getValue());
539542

540543
Value storeBufferA = GpuAllocOp::create(b, loc, loadBufferA.getType());
@@ -1978,6 +1981,19 @@ struct GridwiseAttentionAccelRewritePattern
19781981
rock::accel::AccelEmitterParams accelParamsGemm1 =
19791982
accelEmitterPtrGemm1->getParams();
19801983

1984+
// Some params are needed by later passes; attach them to the func as attributes.
1985+
assert(gemm0TuningParams.getWavesPerEU() ==
1986+
gemm1TuningParams.getWavesPerEU());
1987+
assert(gemm0TuningParams.getOutputSwizzle() ==
1988+
gemm1TuningParams.getOutputSwizzle());
1989+
IntegerAttr wavesPerEUAttr =
1990+
rewriter.getI64IntegerAttr(gemm0TuningParams.getWavesPerEU());
1991+
IntegerAttr outputSwizzleAttr =
1992+
rewriter.getI64IntegerAttr(gemm0TuningParams.getOutputSwizzle());
1993+
func::FuncOp funcOp = cast<func::FuncOp>(op->getParentOp());
1994+
funcOp->setAttr("waves_per_eu", wavesPerEUAttr);
1995+
funcOp->setAttr("output_swizzle", outputSwizzleAttr);
1996+
19811997
// Get current workgroup ID.
19821998
auto bid = WorkgroupIdOp::create(rewriter, loc, rewriter.getIndexType());
19831999
// Get current workitem ID.
@@ -3010,11 +3026,22 @@ struct GridwiseGemmAccelRewritePattern
30103026
auto tid = WorkitemIdOp::create(b, loc, b.getIndexType());
30113027

30123028
// Compute grid coordinates
3029+
int64_t gridGroupSize = tuningParams.getGridGroupSize();
30133030
auto gridCoords = layout::makeGroupedGridLayout(
30143031
b, loc, bid,
3015-
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType},
3032+
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType,
3033+
gridGroupSize},
30163034
arch);
30173035

3036+
// Some params are needed by later passes; attach them to the func as attributes.
3037+
IntegerAttr wavesPerEUAttr =
3038+
b.getI64IntegerAttr(tuningParams.getWavesPerEU());
3039+
IntegerAttr outputSwizzleAttr =
3040+
b.getI64IntegerAttr(tuningParams.getOutputSwizzle());
3041+
func::FuncOp funcOp = cast<func::FuncOp>(op->getParentOp());
3042+
funcOp->setAttr("waves_per_eu", wavesPerEUAttr);
3043+
funcOp->setAttr("output_swizzle", outputSwizzleAttr);
3044+
30183045
LDSLayoutConfigDim ldsLayoutConfigA = getLDSLayoutConfigDim(
30193046
elementTypeA, kpack, maybeVecDimInfoA.value(), directToLDS);
30203047
LDSLayoutConfigDim ldsLayoutConfigB = getLDSLayoutConfigDim(

0 commit comments

Comments
 (0)