-
Notifications
You must be signed in to change notification settings - Fork 50
Greedy third iteration #2140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Greedy third iteration #2140
Changes from all commits
5ad0967
8884d8b
21449a2
87ed630
3ae9033
414ef56
7e10468
0f7b9b8
e9088bc
4f4141c
b625975
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -142,6 +142,7 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> { | |
| }; | ||
|
|
||
| struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | ||
| // TODO: remove once we generate new quick tuning list | ||
| constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock, | ||
| int64_t kPerBlock, int64_t mPerWave, | ||
| int64_t nPerWave, int64_t mnPerXdl, int64_t kPack, | ||
|
|
@@ -152,12 +153,28 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl), | ||
| gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor), | ||
| gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle), | ||
| wavesPerEU(0), gridGroupSize(0), | ||
| gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK), | ||
| gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {} | ||
|
|
||
| constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock, | ||
| int64_t kPerBlock, int64_t mPerWave, | ||
| int64_t nPerWave, int64_t mnPerXdl, int64_t kPack, | ||
| int64_t splitKFactor, int64_t scheduleVersion, | ||
| int64_t outputSwizzle, int64_t wavesPerEU, | ||
| int64_t gridGroupSize, bool aThreadCopyMoreGemmK, | ||
| bool bThreadCopyMoreGemmKPack) | ||
| : InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerWave(mPerWave), | ||
| gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl), | ||
| gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor), | ||
| gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle), | ||
| wavesPerEU(wavesPerEU), gridGroupSize(gridGroupSize), | ||
| gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK), | ||
| gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {} | ||
|
|
||
| constexpr InitParamsAccel() | ||
| : InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, false, | ||
| false) {} | ||
| : InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, 0LL, | ||
| 0LL, false, false) {} | ||
|
|
||
| InitParamsAccel(MfmaGemmParamsAttr attr) | ||
| : InitParams{attr.getMPerBlock(), attr.getNPerBlock(), | ||
|
|
@@ -167,6 +184,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()), | ||
| gemmScheduleVersion(attr.getScheduleVersion()), | ||
| outputSwizzle(attr.getOutputSwizzle()), | ||
| wavesPerEU(attr.getWavesPerEU()), | ||
| gridGroupSize(attr.getGridGroupSize()), | ||
| gemmAThreadCopyMoreGemmK(attr.getForceUnroll()), | ||
| gemmBThreadCopyMoreGemmKPack(false) {}; | ||
|
|
||
|
|
@@ -178,6 +197,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()), | ||
| gemmScheduleVersion(attr.getScheduleVersion()), | ||
| outputSwizzle(attr.getOutputSwizzle()), | ||
| wavesPerEU(attr.getWavesPerEU()), | ||
| gridGroupSize(attr.getGridGroupSize()), | ||
| gemmAThreadCopyMoreGemmK(attr.getForceUnroll()), | ||
| gemmBThreadCopyMoreGemmKPack(false) {}; | ||
|
|
||
|
|
@@ -191,6 +212,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| int64_t splitKFactor; | ||
| int64_t gemmScheduleVersion; | ||
| int64_t outputSwizzle; | ||
| int64_t wavesPerEU; | ||
| int64_t gridGroupSize; | ||
| bool gemmAThreadCopyMoreGemmK; | ||
| bool gemmBThreadCopyMoreGemmKPack; | ||
|
|
||
|
|
@@ -214,6 +237,10 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> { | |
| f(self.gemmScheduleVersion); | ||
| f(self.outputSwizzle); | ||
| } | ||
| if (self.version >= Version::V4) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why add this here instead of adding it inside the previous check?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Because we want to visit the params in the right order.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. See serialize() in Serializable.h. |
||
| f(self.wavesPerEU); | ||
| f(self.gridGroupSize); | ||
| } | ||
| f(self.gemmAThreadCopyMoreGemmK); | ||
| f(self.gemmBThreadCopyMoreGemmKPack); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,7 @@ struct Serializable { | |
| } | ||
|
|
||
| bool checkVersionFormat(const std::string &s) { | ||
| const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 12}; | ||
| const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 14}; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we haven't done a release since v4 (I think), is it OK to keep the same version? |
||
| const int32_t versionIdx = static_cast<int32_t>(version); | ||
| if (versionIdx < 1 || versionIdx >= static_cast<int32_t>(Version::Count)) { | ||
| llvm_unreachable("Unknown version of the perfConfig"); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -136,6 +136,63 @@ struct WorkgroupIdRewritePattern | |
| }; | ||
| } // namespace | ||
|
|
||
| static void runWavesPerEUHeuristic(OpBuilder b, gpu::GPUFuncOp gpuFunc, | ||
| int64_t ldsUsage) { | ||
| LLVM_DEBUG(llvm::dbgs() << "Using heuristic to set wavesPerEU...\n"); | ||
| if (!gpuFunc->hasAttrOfType<IntegerAttr>("block_size")) { | ||
| LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n"); | ||
| return; | ||
| } | ||
| int64_t blockSize = | ||
| gpuFunc->getAttrOfType<IntegerAttr>("block_size").getInt(); | ||
| if (!gpuFunc->hasAttrOfType<IntegerAttr>("grid_size")) { | ||
| LLVM_DEBUG(llvm::dbgs() << "gridSize not found in gpuFunc.\n"); | ||
| return; | ||
| } | ||
| int64_t gridSize = gpuFunc->getAttrOfType<IntegerAttr>("grid_size").getInt(); | ||
| FailureOr<StringAttr> maybeArch = rock::getArch(gpuFunc); | ||
| if (succeeded(maybeArch)) { | ||
| StringAttr arch = maybeArch.value(); | ||
| rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch); | ||
| FailureOr<int64_t> maybeNumCU = rock::getNumCU(gpuFunc); | ||
| int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU); | ||
| int64_t totalEUs = archInfo.numEUPerCU * numCU; | ||
| int64_t wavesPerBlock = (blockSize / archInfo.waveSize); | ||
| int64_t totalWaves = wavesPerBlock * gridSize; | ||
| int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU; | ||
| int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs; | ||
| int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid); | ||
| LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " numEUPerCU:" << archInfo.numEUPerCU << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() | ||
| << "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n"); | ||
| // limit wavesPerEU based on lds usage | ||
| if (ldsUsage > 0) { | ||
| wavesPerEU = | ||
| std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage); | ||
| } | ||
| // Currently limiting wavesPerEU to be two | ||
| // it is a future to ticket to remove this constraint with further | ||
| // analysis | ||
| constexpr int64_t wavesPerEUUpperBound = 2; | ||
| wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound); | ||
| if (wavesPerEU > 1) { | ||
| LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n"); | ||
| gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU)); | ||
| } else { | ||
| LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set" | ||
| << "\n"); | ||
| } | ||
| } else { | ||
| LLVM_DEBUG(llvm::dbgs() << "arch not found.\n"); | ||
| } | ||
| } | ||
|
|
||
| void LowerRockOpsToGPUPass::runOnOperation() { | ||
| ModuleOp op = getOperation(); | ||
| MLIRContext *ctx = op.getContext(); | ||
|
|
@@ -204,6 +261,11 @@ void LowerRockOpsToGPUPass::runOnOperation() { | |
| gridSize = cast<IntegerAttr>(gridSizeAttr).getInt(); | ||
| gpuFunc.setKnownGridSizeAttr(b.getDenseI32ArrayAttr({gridSize, 1, 1})); | ||
|
|
||
| auto wavesPerEUAttr = theFunc->getAttr(rock::WavesPerEUAttr::getMnemonic()); | ||
| if (wavesPerEUAttr) { | ||
| gpuFunc->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr); | ||
| } | ||
|
|
||
| FailureOr<StringAttr> maybeArch = rock::getArch(theFunc); | ||
| if (succeeded(maybeArch)) { | ||
| gpuFunc->setAttr("arch", maybeArch.value()); | ||
|
|
@@ -391,61 +453,23 @@ void LowerRockOpsToGPUPass::runOnOperation() { | |
| gpuFunc->setAttr("rock.shared_buffer_size", | ||
| b.getI32IntegerAttr(ldsUsage)); | ||
| } | ||
| LLVM_DEBUG(llvm::dbgs() << "Attempting to set wavesPerEU...\n"); | ||
| if (!gpuFunc->hasAttrOfType<IntegerAttr>("block_size")) { | ||
| LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n"); | ||
| return; | ||
| } | ||
| int64_t blockSize = | ||
| gpuFunc->getAttrOfType<IntegerAttr>("block_size").getInt(); | ||
| if (!gpuFunc->hasAttrOfType<IntegerAttr>("grid_size")) { | ||
| LLVM_DEBUG(llvm::dbgs() << "gridSize not found in gpuFunc.\n"); | ||
| return; | ||
| } | ||
| int64_t gridSize = | ||
| gpuFunc->getAttrOfType<IntegerAttr>("grid_size").getInt(); | ||
| FailureOr<StringAttr> maybeArch = rock::getArch(gpuFunc); | ||
| if (succeeded(maybeArch)) { | ||
| StringAttr arch = maybeArch.value(); | ||
| rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch); | ||
| FailureOr<int64_t> maybeNumCU = rock::getNumCU(gpuFunc); | ||
| int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU); | ||
| int64_t totalEUs = archInfo.numEUPerCU * numCU; | ||
| int64_t wavesPerBlock = (blockSize / archInfo.waveSize); | ||
| int64_t totalWaves = wavesPerBlock * gridSize; | ||
| int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU; | ||
| int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs; | ||
| int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid); | ||
| LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() | ||
| << " numEUPerCU:" << archInfo.numEUPerCU << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() | ||
| << "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n"); | ||
| // limit wavesPerEU based on lds usage | ||
| if (ldsUsage > 0) { | ||
| wavesPerEU = | ||
| std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage); | ||
| } | ||
| // Currently limiting wavesPerEU to be two | ||
| // it is a future to ticket to remove this constraint with further | ||
| // analysis | ||
| constexpr int64_t wavesPerEUUpperBound = 2; | ||
| wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound); | ||
| if (wavesPerEU > 1) { | ||
| LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n"); | ||
| // if waves_per_eu is set, use it | ||
| if (gpuFunc->hasAttrOfType<IntegerAttr>( | ||
| rock::WavesPerEUAttr::getMnemonic())) { | ||
| int64_t wavesPerEU = | ||
| gpuFunc | ||
| ->getAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic()) | ||
| .getInt(); | ||
| // zero means, use heuristic | ||
| if (wavesPerEU != 0) { | ||
| gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU)); | ||
| } else { | ||
| LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set" | ||
| << "\n"); | ||
| LLVM_DEBUG(llvm::dbgs() << "Setting waves_per_eu using tuning param\n"); | ||
| // we are done | ||
| return; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is it returning from here? Doesn't it need to set some other attributes?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, we are inside
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I was also confused by this, but yes, it looks correct. However, to make it more readable, we could move all the code below into a function like "runWavesPerEUHeuristic"; that way it would be easier to understand why there is a return here. |
||
| } | ||
| } else { | ||
| LLVM_DEBUG(llvm::dbgs() << "arch not found.\n"); | ||
| } | ||
| // no "waves_per_eu" attribute, use heuristic | ||
| runWavesPerEUHeuristic(b, gpuFunc, ldsUsage); | ||
| }); | ||
|
|
||
| if (gpuModCount == 0) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not gridGroupSize?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gridGroupSize is only used for GEMMs/convs, not attention/g+g kernels.