-
Notifications
You must be signed in to change notification settings - Fork 50
Greedy third iteration #2140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Greedy third iteration #2140
Changes from 2 commits
5ad0967
8884d8b
21449a2
87ed630
3ae9033
414ef56
7e10468
0f7b9b8
e9088bc
4f4141c
b625975
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -142,6 +142,7 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> { | |
| }; | ||
|
|
||
| struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | ||
| // TODO: remove once we generate new quick tuning list | ||
| constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock, | ||
| int64_t kPerBlock, int64_t mPerWave, | ||
| int64_t nPerWave, int64_t mnPerXdl, int64_t kPack, | ||
|
|
@@ -152,12 +153,28 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl), | ||
| gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor), | ||
| gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle), | ||
| wavesPerEU(0), gridGroupSize(0), | ||
| gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK), | ||
| gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {} | ||
|
|
||
| constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock, | ||
| int64_t kPerBlock, int64_t mPerWave, | ||
| int64_t nPerWave, int64_t mnPerXdl, int64_t kPack, | ||
| int64_t splitKFactor, int64_t scheduleVersion, | ||
| int64_t outputSwizzle, int64_t wavesPerEU, | ||
| int64_t gridGroupSize, bool aThreadCopyMoreGemmK, | ||
| bool bThreadCopyMoreGemmKPack) | ||
| : InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerWave(mPerWave), | ||
| gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl), | ||
| gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor), | ||
| gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle), | ||
| wavesPerEU(wavesPerEU), gridGroupSize(gridGroupSize), | ||
| gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK), | ||
| gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {} | ||
|
|
||
| constexpr InitParamsAccel() | ||
| : InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, false, | ||
| false) {} | ||
| : InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, 0LL, | ||
| 0LL, false, false) {} | ||
|
|
||
| InitParamsAccel(MfmaGemmParamsAttr attr) | ||
| : InitParams{attr.getMPerBlock(), attr.getNPerBlock(), | ||
|
|
@@ -167,6 +184,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()), | ||
| gemmScheduleVersion(attr.getScheduleVersion()), | ||
| outputSwizzle(attr.getOutputSwizzle()), | ||
| wavesPerEU(attr.getWavesPerEU()), | ||
| gridGroupSize(attr.getGridGroupSize()), | ||
| gemmAThreadCopyMoreGemmK(attr.getForceUnroll()), | ||
| gemmBThreadCopyMoreGemmKPack(false) {}; | ||
|
|
||
|
|
@@ -178,6 +197,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()), | ||
| gemmScheduleVersion(attr.getScheduleVersion()), | ||
| outputSwizzle(attr.getOutputSwizzle()), | ||
| wavesPerEU(attr.getWavesPerEU()), | ||
| gridGroupSize(attr.getGridGroupSize()), | ||
| gemmAThreadCopyMoreGemmK(attr.getForceUnroll()), | ||
| gemmBThreadCopyMoreGemmKPack(false) {}; | ||
|
|
||
|
|
@@ -191,6 +212,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| int64_t splitKFactor; | ||
| int64_t gemmScheduleVersion; | ||
| int64_t outputSwizzle; | ||
| int64_t wavesPerEU; | ||
| int64_t gridGroupSize; | ||
| bool gemmAThreadCopyMoreGemmK; | ||
| bool gemmBThreadCopyMoreGemmKPack; | ||
|
|
||
|
|
@@ -214,6 +237,10 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| f(self.gemmScheduleVersion); | ||
| f(self.outputSwizzle); | ||
| } | ||
| if (self.version >= Version::V4) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why adding this here instead of adding it in the previous check
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because we want to visit the params in the right order.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see serialize() in Serializable.h |
||
| f(self.wavesPerEU); | ||
| f(self.gridGroupSize); | ||
| } | ||
| f(self.gemmAThreadCopyMoreGemmK); | ||
| f(self.gemmBThreadCopyMoreGemmKPack); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,7 @@ struct Serializable { | |
| } | ||
|
|
||
| bool checkVersionFormat(const std::string &s) { | ||
| const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 12}; | ||
| const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 14}; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since we haven't done a release since v4 (I think), is it ok to keep the same version? |
||
| const int32_t versionIdx = static_cast<int32_t>(version); | ||
| if (versionIdx < 1 || versionIdx >= static_cast<int32_t>(Version::Count)) { | ||
| llvm_unreachable("Unknown version of the perfConfig"); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -204,6 +204,11 @@ void LowerRockOpsToGPUPass::runOnOperation() { | |
| gridSize = cast<IntegerAttr>(gridSizeAttr).getInt(); | ||
| gpuFunc.setKnownGridSizeAttr(b.getDenseI32ArrayAttr({gridSize, 1, 1})); | ||
|
|
||
| auto wavesPerEUAttr = theFunc->getAttr(rock::WavesPerEUAttr::getMnemonic()); | ||
| if (wavesPerEUAttr) { | ||
| gpuFunc->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr); | ||
| } | ||
|
|
||
| FailureOr<StringAttr> maybeArch = rock::getArch(theFunc); | ||
| if (succeeded(maybeArch)) { | ||
| gpuFunc->setAttr("arch", maybeArch.value()); | ||
|
|
@@ -391,7 +396,21 @@ void LowerRockOpsToGPUPass::runOnOperation() { | |
| gpuFunc->setAttr("rock.shared_buffer_size", | ||
| b.getI32IntegerAttr(ldsUsage)); | ||
| } | ||
| LLVM_DEBUG(llvm::dbgs() << "Attempting to set wavesPerEU...\n"); | ||
| // if waves_per_eu is set, use it | ||
| if (gpuFunc->hasAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic())) { | ||
| int64_t wavesPerEU = | ||
| gpuFunc->getAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic()).getInt(); | ||
| // zero means, use heuristic | ||
| if (wavesPerEU != 0) { | ||
| gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU)); | ||
| LLVM_DEBUG(llvm::dbgs() << "Setting waves_per_eu using tuning param\n"); | ||
| // we are done | ||
| return; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is it returning from here ? Doesn't it need to set some other attributes ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, we are inside
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also was confused by this but yes, it looks correct. However, to make it more readable, we could move all the code below into a function like "runWavesPerEUHeuristic", that way it would be easier to understand why a return here |
||
| } | ||
| } | ||
|
|
||
| // no "waves_per_eu" attribute, use heuristic | ||
| LLVM_DEBUG(llvm::dbgs() << "Using heuristic to set wavesPerEU...\n"); | ||
| if (!gpuFunc->hasAttrOfType<IntegerAttr>("block_size")) { | ||
| LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n"); | ||
| return; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3254,17 +3254,17 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr, | |
| expectedNumTokens = 11; | ||
| break; | ||
| case 3: | ||
| expectedNumTokens = 12; | ||
| expectedNumTokens = 13; | ||
| break; | ||
| default: | ||
| llvm_unreachable("Unknown version of the perfConfig"); | ||
| } | ||
| SmallVector<StringRef, 11> tokens; | ||
| SmallVector<StringRef, 13> tokens; | ||
|
||
| rest.split(tokens, ','); | ||
| if (tokens.size() != expectedNumTokens) { | ||
| return {}; | ||
| } | ||
| SmallVector<int64_t, 11> params; | ||
| SmallVector<int64_t, 13> params; | ||
| llvm::transform(tokens, std::back_inserter(params), [](StringRef s) { | ||
| int param; | ||
| llvm::to_integer(s, param); | ||
|
|
@@ -3298,11 +3298,12 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr, | |
| int64_t splitKFactor = version > 1 ? params[lastIdx++] : 1; | ||
| int64_t scheduleVersion = version > 1 ? params[lastIdx++] : 1; | ||
| int64_t outputSwizzle = version > 1 ? params[lastIdx++] : 2; | ||
| int64_t wavesPerEU = isV3 ? params[lastIdx++] : 0; // 0 -> use heuristic | ||
| int64_t forceUnroll = params[expectedNumTokens - 1] == 1; | ||
| return AttnPerfConfigAttr::get( | ||
| perfConfigStrAttr.getContext(), mPerBlockG0, mPerBlockG1, nPerBlockG0, | ||
| kpackPerBlock, mPerWave, nPerWave, mnPerXdl, kpack, splitKFactor, | ||
| scheduleVersion, outputSwizzle, forceUnroll); | ||
| scheduleVersion, outputSwizzle, wavesPerEU, forceUnroll); | ||
| } | ||
|
|
||
| //===-----------------------------------------------------===// | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not gridGroupSize?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gridGroupSize is only used for GEMMs/convs, not attention/g+g kernels.