Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ def RockAccelTuningParamAttrInterface : AttrInterface<"RockAccelTuningParamAttrI
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return waves_per_eu, this is a hint to the backend compiler to tune the number of wavefronts that are capable of fitting within the resources of an EU.
}],
/*retType=*/"int64_t",
/*methodName=*/"getWavesPerEU",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Group size for layout on the distribution of the workgroups.
}],
/*retType=*/"int64_t",
/*methodName=*/"getGridGroupSize",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>

// TODO: more methods here as needed
Expand Down
31 changes: 26 additions & 5 deletions mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
"int64_t":$nPerBlockG0, "int64_t":$kpackPerBlock, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$kpack,
"int64_t":$splitKFactor, "int64_t":$scheduleVersion,
"int64_t":$outputSwizzle, "bool":$forceUnroll);
"int64_t":$outputSwizzle, "int64_t":$wavesPerEU, "bool":$forceUnroll);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not gridGroupSize?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gridGroupSize is only used for GEMMs/convs, not attention/g+g kernels.


let extraClassDeclaration = [{
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
Expand All @@ -296,13 +296,14 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getForceUnroll())).toVector(perfStr);
}
AttnPerfConfigAttr withScheduleVersion(int64_t newScheduleVersion) const {
return AttnPerfConfigAttr::get(
getContext(), getMPerBlockG0(), getMPerBlockG1(), getNPerBlockG0(),
getKpackPerBlock(), getMPerWave(), getNPerWave(), getMnPerXdl(), getKpack(),
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getForceUnroll());
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getWavesPerEU(), getForceUnroll());
}
}];

Expand All @@ -325,7 +326,7 @@ def Rock_MfmaGemmParamsAttr
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
"bool":$forceUnroll);
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);

let extraClassDeclaration = [{
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
Expand All @@ -339,6 +340,8 @@ def Rock_MfmaGemmParamsAttr
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getGridGroupSize()) + ","
+ Twine(getForceUnroll()) + ","
+ "1") /* *ThreadCopyMore* */
.toVector(perfStr);
Expand All @@ -359,7 +362,7 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
"bool":$forceUnroll);
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);

let extraClassDeclaration = [{
void getPerfConfigStr(SmallVectorImpl<char> &perfStr) {
Expand All @@ -373,6 +376,8 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getGridGroupSize()) + ","
+ Twine(getForceUnroll()) + ","
+ "1") /* *ThreadCopyMore* */
.toVector(perfStr);
Expand Down Expand Up @@ -459,9 +464,25 @@ def Rock_ScheduleVersionAttr : Rock_Attr<"ScheduleVersion"> {
}];
}

// It is a temporary attribute
def Rock_EnableSplitKForTuning : Rock_Attr<"EnableSplitKForTuning"> {
let mnemonic = "enable_splitk_for_tuning";
let description = [{
Whether we tune for split-k. If unset, split-k=1.
}];
}

def Rock_WavesPerEU : Rock_Attr<"WavesPerEU"> {
let mnemonic = "waves_per_eu";
let description = [{
This is a hint to the backend compiler to tune the number of wavefronts that are capable of fitting within the resources of an EU.
}];
}

def Rock_OutputSwizzle : Rock_Attr<"OutputSwizzle"> {
let mnemonic = "output_swizzle";
let description = [{
Whether we run the output swizzle pass. 0 -> disabled, 1 -> enabled, 2 -> heuristic.
}];
}

def Rock_PrefillAttr : Rock_Attr<"Prefill"> {
Expand Down
31 changes: 29 additions & 2 deletions mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
};

struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
// TODO: remove once we generate new quick tuning list
constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock,
int64_t kPerBlock, int64_t mPerWave,
int64_t nPerWave, int64_t mnPerXdl, int64_t kPack,
Expand All @@ -152,12 +153,28 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl),
gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor),
gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle),
wavesPerEU(0), gridGroupSize(0),
gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK),
gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {}

constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock,
int64_t kPerBlock, int64_t mPerWave,
int64_t nPerWave, int64_t mnPerXdl, int64_t kPack,
int64_t splitKFactor, int64_t scheduleVersion,
int64_t outputSwizzle, int64_t wavesPerEU,
int64_t gridGroupSize, bool aThreadCopyMoreGemmK,
bool bThreadCopyMoreGemmKPack)
: InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerWave(mPerWave),
gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl),
gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor),
gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle),
wavesPerEU(wavesPerEU), gridGroupSize(gridGroupSize),
gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK),
gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {}

constexpr InitParamsAccel()
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, false,
false) {}
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, 0LL,
0LL, false, false) {}

InitParamsAccel(MfmaGemmParamsAttr attr)
: InitParams{attr.getMPerBlock(), attr.getNPerBlock(),
Expand All @@ -167,6 +184,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
gemmScheduleVersion(attr.getScheduleVersion()),
outputSwizzle(attr.getOutputSwizzle()),
wavesPerEU(attr.getWavesPerEU()),
gridGroupSize(attr.getGridGroupSize()),
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
gemmBThreadCopyMoreGemmKPack(false) {};

Expand All @@ -178,6 +197,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
gemmScheduleVersion(attr.getScheduleVersion()),
outputSwizzle(attr.getOutputSwizzle()),
wavesPerEU(attr.getWavesPerEU()),
gridGroupSize(attr.getGridGroupSize()),
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
gemmBThreadCopyMoreGemmKPack(false) {};

Expand All @@ -191,6 +212,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
int64_t splitKFactor;
int64_t gemmScheduleVersion;
int64_t outputSwizzle;
int64_t wavesPerEU;
int64_t gridGroupSize;
bool gemmAThreadCopyMoreGemmK;
bool gemmBThreadCopyMoreGemmKPack;

Expand All @@ -214,6 +237,10 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
f(self.gemmScheduleVersion);
f(self.outputSwizzle);
}
if (self.version >= Version::V4) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding this here instead of adding it in the previous check if (self.version >= Version::V4)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because we want to visit the params in the right order.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see serialize() in Serializable.h

f(self.wavesPerEU);
f(self.gridGroupSize);
}
f(self.gemmAThreadCopyMoreGemmK);
f(self.gemmBThreadCopyMoreGemmKPack);
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct Serializable {
}

bool checkVersionFormat(const std::string &s) {
const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 12};
const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 14};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we haven't done a release since v4 (I think), is it ok to keep the same version?

const int32_t versionIdx = static_cast<int32_t>(version);
if (versionIdx < 1 || versionIdx >= static_cast<int32_t>(Version::Count)) {
llvm_unreachable("Unknown version of the perfConfig");
Expand Down
21 changes: 20 additions & 1 deletion mlir/lib/Conversion/RockToGPU/RockToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ void LowerRockOpsToGPUPass::runOnOperation() {
gridSize = cast<IntegerAttr>(gridSizeAttr).getInt();
gpuFunc.setKnownGridSizeAttr(b.getDenseI32ArrayAttr({gridSize, 1, 1}));

auto wavesPerEUAttr = theFunc->getAttr(rock::WavesPerEUAttr::getMnemonic());
if (wavesPerEUAttr) {
gpuFunc->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
}

FailureOr<StringAttr> maybeArch = rock::getArch(theFunc);
if (succeeded(maybeArch)) {
gpuFunc->setAttr("arch", maybeArch.value());
Expand Down Expand Up @@ -391,7 +396,21 @@ void LowerRockOpsToGPUPass::runOnOperation() {
gpuFunc->setAttr("rock.shared_buffer_size",
b.getI32IntegerAttr(ldsUsage));
}
LLVM_DEBUG(llvm::dbgs() << "Attempting to set wavesPerEU...\n");
// if waves_per_eu is set, use it
if (gpuFunc->hasAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic())) {
int64_t wavesPerEU =
gpuFunc->getAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic()).getInt();
// zero means, use heuristic
if (wavesPerEU != 0) {
gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU));
LLVM_DEBUG(llvm::dbgs() << "Setting waves_per_eu using tuning param\n");
// we are done
return;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it returning from here ? Doesn't it need to set some other attributes ?

Copy link
Contributor Author

@dhernandez0 dhernandez0 Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we are inside op.walk([](gpu::GPUFuncOp gpuFunc) {, only two attributes are set: rock.shared_buffer_size and rocdl.waves_per_eu. The code below is just the heuristic to set rocdl.waves_per_eu.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also was confused by this but yes, it looks correct. However, to make it more readable, we could move all the code below into a function like "runWavesPerEUHeuristic", that way it would be easier to understand why a return here

}
}

// no "waves_per_eu" attribute, use heuristic
LLVM_DEBUG(llvm::dbgs() << "Using heuristic to set wavesPerEU...\n");
if (!gpuFunc->hasAttrOfType<IntegerAttr>("block_size")) {
LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n");
return;
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3254,17 +3254,17 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr,
expectedNumTokens = 11;
break;
case 3:
expectedNumTokens = 12;
expectedNumTokens = 13;
break;
default:
llvm_unreachable("Unknown version of the perfConfig");
}
SmallVector<StringRef, 11> tokens;
SmallVector<StringRef, 13> tokens;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use expectedNumTokens?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expectedNumTokens is a variable, it wouldn't work here, we could change this to:

SmallVector<StringRef> tokens;
tokens.reserve(expectedNumTokens);

Happy to change it to this if you prefer it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it would look cleaner if we move SmallVector<int64_t, 13> params here and do:

SmallVector<StringRef> tokens; 
SmallVector<StringRef> params 
tokens.reserve(expectedNumTokens);
params .reserve(expectedNumTokens);

rest.split(tokens, ',');
if (tokens.size() != expectedNumTokens) {
return {};
}
SmallVector<int64_t, 11> params;
SmallVector<int64_t, 13> params;
llvm::transform(tokens, std::back_inserter(params), [](StringRef s) {
int param;
llvm::to_integer(s, param);
Expand Down Expand Up @@ -3298,11 +3298,12 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr,
int64_t splitKFactor = version > 1 ? params[lastIdx++] : 1;
int64_t scheduleVersion = version > 1 ? params[lastIdx++] : 1;
int64_t outputSwizzle = version > 1 ? params[lastIdx++] : 2;
int64_t wavesPerEU = isV3 ? params[lastIdx++] : 0; // 0 -> use heuristic
int64_t forceUnroll = params[expectedNumTokens - 1] == 1;
return AttnPerfConfigAttr::get(
perfConfigStrAttr.getContext(), mPerBlockG0, mPerBlockG1, nPerBlockG0,
kpackPerBlock, mPerWave, nPerWave, mnPerXdl, kpack, splitKFactor,
scheduleVersion, outputSwizzle, forceUnroll);
scheduleVersion, outputSwizzle, wavesPerEU, forceUnroll);
}

//===-----------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ void AffixTuningParameters::affixTuningParametersImpl(
Attribute params0 = op.getGemm0Params().value_or(nullptr);
// set a default one if params is not provided
StringAttr perfConfigStrAttr =
builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,1");
builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,0,1");
if (!params0) {
if (StringAttr mayBePerfConfigStrAttr =
dyn_cast_or_null<StringAttr>(op->getAttr("perf_config"))) {
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ GridCoordinates rock::layout::makeGroupedGridLayout(PatternRewriter &b,
int64_t bitWidthOut = info.outputType.getIntOrFloatBitWidth();
int64_t groupSize =
std::ceil(std::sqrt(info.numCU)) * (bitWidthOut / bitWidthIn);
// use gridGroupSize if it's not zero
if (info.gridGroupSize != 0) {
groupSize = info.gridGroupSize;
LLVM_DEBUG(llvm::dbgs() << "Setting groupSize by using tuning params to "
<< groupSize << "\n");
} else {
LLVM_DEBUG(llvm::dbgs()
<< "Using heuristic to set groupSize to " << groupSize << "\n");
}

Value mBlocksPerGroup = b.createOrFold<ConstantIndexOp>(loc, groupSize);
Value blocksPerGroup =
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct GridLayoutInfo {
int64_t numCU;
Type inputType;
Type outputType;
int64_t gridGroupSize;
};

/// This function emits the right triplet of <group,block_m,block_n>
Expand Down
31 changes: 29 additions & 2 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,12 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<GridwiseGemmOp> {
if (failed(maybeArch)) {
return op.emitError("arch needs to be set.");
}
// always use heuristic for non-accel path
int64_t gridGroupSize = 0;
auto gridCoords = layout::makeGroupedGridLayout(
b, loc, bid,
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType},
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType,
gridGroupSize},
maybeArch->getValue());

Value storeBufferA = GpuAllocOp::create(b, loc, loadBufferA.getType());
Expand Down Expand Up @@ -1978,6 +1981,19 @@ struct GridwiseAttentionAccelRewritePattern
rock::accel::AccelEmitterParams accelParamsGemm1 =
accelEmitterPtrGemm1->getParams();

// wavesPerEU is needed in RockToGPU pass and OutputSwizzle for the OutputSwizzle pass. We add them as func attributes.
assert(gemm0TuningParams.getWavesPerEU() ==
gemm1TuningParams.getWavesPerEU());
assert(gemm0TuningParams.getOutputSwizzle() ==
gemm1TuningParams.getOutputSwizzle());
IntegerAttr wavesPerEUAttr =
rewriter.getI64IntegerAttr(gemm0TuningParams.getWavesPerEU());
IntegerAttr outputSwizzleAttr =
rewriter.getI64IntegerAttr(gemm0TuningParams.getOutputSwizzle());
func::FuncOp funcOp = cast<func::FuncOp>(op->getParentOp());
funcOp->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
funcOp->setAttr(rock::OutputSwizzleAttr::getMnemonic(), outputSwizzleAttr);

// Get current workgroup ID.
auto bid = WorkgroupIdOp::create(rewriter, loc, rewriter.getIndexType());
// Get current workitem ID.
Expand Down Expand Up @@ -3014,11 +3030,22 @@ struct GridwiseGemmAccelRewritePattern
auto tid = WorkitemIdOp::create(b, loc, b.getIndexType());

// Compute grid coordinates
int64_t gridGroupSize = tuningParams.getGridGroupSize();
auto gridCoords = layout::makeGroupedGridLayout(
b, loc, bid,
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType},
{G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType,
gridGroupSize},
arch);

// wavesPerEU is needed in RockToGPU pass and OutputSwizzle for the OutputSwizzle pass. We add them as func attributes.
IntegerAttr wavesPerEUAttr =
b.getI64IntegerAttr(tuningParams.getWavesPerEU());
IntegerAttr outputSwizzleAttr =
b.getI64IntegerAttr(tuningParams.getOutputSwizzle());
func::FuncOp funcOp = cast<func::FuncOp>(op->getParentOp());
funcOp->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
funcOp->setAttr(rock::OutputSwizzleAttr::getMnemonic(), outputSwizzleAttr);

LDSLayoutConfigDim ldsLayoutConfigA = getLDSLayoutConfigDim(
elementTypeA, kpack, maybeVecDimInfoA.value(), directToLDS);
LDSLayoutConfigDim ldsLayoutConfigB = getLDSLayoutConfigDim(
Expand Down
Loading
Loading