Skip to content

Commit 8884d8b

Browse files
committed
Addressing PR comments
1 parent 5ad0967 commit 8884d8b

File tree

5 files changed

+30
-14
lines changed

5 files changed

+30
-14
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def RockAccelTuningParamAttrInterface : AttrInterface<"RockAccelTuningParamAttrI
9393
>,
9494
InterfaceMethod<
9595
/*desc=*/[{
96-
Return param passed to the backend compiler: waves_per_eu.
96+
Return waves_per_eu, a hint to the backend compiler for tuning the number of wavefronts that can fit within the resources of an EU.
9797
}],
9898
/*retType=*/"int64_t",
9999
/*methodName=*/"getWavesPerEU",

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,25 @@ def Rock_ScheduleVersionAttr : Rock_Attr<"ScheduleVersion"> {
464464
}];
465465
}
466466

467-
// It is a temporary attribute
468467
def Rock_EnableSplitKForTuning : Rock_Attr<"EnableSplitKForTuning"> {
469468
let mnemonic = "enable_splitk_for_tuning";
469+
let description = [{
470+
Whether we tune for split-k. If unset, split-k=1.
471+
}];
472+
}
473+
474+
def Rock_WavesPerEU : Rock_Attr<"WavesPerEU"> {
475+
let mnemonic = "waves_per_eu";
476+
let description = [{
477+
This is a hint to the backend compiler to tune the number of wavefronts that are capable of fitting within the resources of an EU.
478+
}];
479+
}
480+
481+
def Rock_OutputSwizzle : Rock_Attr<"OutputSwizzle"> {
482+
let mnemonic = "output_swizzle";
483+
let description = [{
484+
Whether we run the output swizzle pass. 0 -> disabled, 1 -> enabled, 2 -> heuristic.
485+
}];
470486
}
471487

472488
def Rock_PrefillAttr : Rock_Attr<"Prefill"> {

mlir/lib/Conversion/RockToGPU/RockToGPU.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ void LowerRockOpsToGPUPass::runOnOperation() {
204204
gridSize = cast<IntegerAttr>(gridSizeAttr).getInt();
205205
gpuFunc.setKnownGridSizeAttr(b.getDenseI32ArrayAttr({gridSize, 1, 1}));
206206

207-
auto wavesPerEUAttr = theFunc->getAttr("waves_per_eu");
207+
auto wavesPerEUAttr = theFunc->getAttr(rock::WavesPerEUAttr::getMnemonic());
208208
if (wavesPerEUAttr) {
209-
gpuFunc->setAttr("waves_per_eu", wavesPerEUAttr);
209+
gpuFunc->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
210210
}
211211

212212
FailureOr<StringAttr> maybeArch = rock::getArch(theFunc);
@@ -397,9 +397,9 @@ void LowerRockOpsToGPUPass::runOnOperation() {
397397
b.getI32IntegerAttr(ldsUsage));
398398
}
399399
// if waves_per_eu is set, use it
400-
if (gpuFunc->hasAttrOfType<IntegerAttr>("waves_per_eu")) {
400+
if (gpuFunc->hasAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic())) {
401401
int64_t wavesPerEU =
402-
gpuFunc->getAttrOfType<IntegerAttr>("waves_per_eu").getInt();
402+
gpuFunc->getAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic()).getInt();
403403
// zero means: use the heuristic
404404
if (wavesPerEU != 0) {
405405
gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU));

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1981,7 +1981,7 @@ struct GridwiseAttentionAccelRewritePattern
19811981
rock::accel::AccelEmitterParams accelParamsGemm1 =
19821982
accelEmitterPtrGemm1->getParams();
19831983

1984-
// some params are needed in future passes, add them to func as attributes
1984+
// wavesPerEU is needed by the RockToGPU pass and outputSwizzle by the OutputSwizzle pass, so we attach both as func attributes.
19851985
assert(gemm0TuningParams.getWavesPerEU() ==
19861986
gemm1TuningParams.getWavesPerEU());
19871987
assert(gemm0TuningParams.getOutputSwizzle() ==
@@ -1991,8 +1991,8 @@ struct GridwiseAttentionAccelRewritePattern
19911991
IntegerAttr outputSwizzleAttr =
19921992
rewriter.getI64IntegerAttr(gemm0TuningParams.getOutputSwizzle());
19931993
func::FuncOp funcOp = cast<func::FuncOp>(op->getParentOp());
1994-
funcOp->setAttr("waves_per_eu", wavesPerEUAttr);
1995-
funcOp->setAttr("output_swizzle", outputSwizzleAttr);
1994+
funcOp->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
1995+
funcOp->setAttr(rock::OutputSwizzleAttr::getMnemonic(), outputSwizzleAttr);
19961996

19971997
// Get current workgroup ID.
19981998
auto bid = WorkgroupIdOp::create(rewriter, loc, rewriter.getIndexType());
@@ -3037,14 +3037,14 @@ struct GridwiseGemmAccelRewritePattern
30373037
gridGroupSize},
30383038
arch);
30393039

3040-
// some params are needed in future passes, add them to func as attributes
3040+
// wavesPerEU is needed by the RockToGPU pass and outputSwizzle by the OutputSwizzle pass, so we attach both as func attributes.
30413041
IntegerAttr wavesPerEUAttr =
30423042
b.getI64IntegerAttr(tuningParams.getWavesPerEU());
30433043
IntegerAttr outputSwizzleAttr =
30443044
b.getI64IntegerAttr(tuningParams.getOutputSwizzle());
30453045
func::FuncOp funcOp = cast<func::FuncOp>(op->getParentOp());
3046-
funcOp->setAttr("waves_per_eu", wavesPerEUAttr);
3047-
funcOp->setAttr("output_swizzle", outputSwizzleAttr);
3046+
funcOp->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
3047+
funcOp->setAttr(rock::OutputSwizzleAttr::getMnemonic(), outputSwizzleAttr);
30483048

30493049
LDSLayoutConfigDim ldsLayoutConfigA = getLDSLayoutConfigDim(
30503050
elementTypeA, kpack, maybeVecDimInfoA.value(), directToLDS);

mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,10 @@ void RockOutputSwizzlePass::runOnOperation() {
416416

417417
bool useHeuristic = true;
418418
bool enableOutputSwizzle = false;
419-
if (func->hasAttrOfType<IntegerAttr>("output_swizzle")) {
419+
if (func->hasAttrOfType<IntegerAttr>(rock::OutputSwizzleAttr::getMnemonic())) {
420420
// 0 -> disabled, 1 -> enabled, 2 -> heuristic
421421
int64_t outputSwizzle =
422-
func->getAttrOfType<IntegerAttr>("output_swizzle").getInt();
422+
func->getAttrOfType<IntegerAttr>(rock::OutputSwizzleAttr::getMnemonic()).getInt();
423423
useHeuristic = outputSwizzle == 2;
424424
enableOutputSwizzle = outputSwizzle == 1;
425425
}

0 commit comments

Comments
 (0)