From 45292377a8f2a2cf19fc9a14e99a5306a391f833 Mon Sep 17 00:00:00 2001 From: Daniel Hernandez Date: Tue, 9 Dec 2025 11:17:27 +0100 Subject: [PATCH 1/2] Disable transpose loads for numWaves>4 --- .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index b8ebee983839..17fef5370d57 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -278,6 +278,13 @@ LDSTransposeDecision decideLDSTransposeForOperands( } // else - implicitly: neither operand usable, enableA/enableB remain false. + // TODO: adapt code to support numWaves = 8, 16 and 32 (only wmma). + int64_t numWaves = (mPerBlock * nPerBlock) / (mPerWave * nPerWave); + if (numWaves > 4) { + result.enableA = false; + result.enableB = false; + } + return result; } @@ -720,31 +727,31 @@ static std::pair computeLDSBaseOffsets(PatternRewriter &b, // // Parameters: // waveId - Runtime wave ID inside the workgroup. -// physicalWaves - Total number of waves (compile-time). // // Returns: // WaveGridLayout containing: // - wavesInM, wavesInN: Grid dimensions. // - waveM, waveN: This wave's assigned 2D grid coordinates. 
//===----------------------------------------------------------------------===// -static WaveGridLayout computeWaveGridLayout(PatternRewriter &b, Location loc, - Value waveId, int64_t physicalWaves, - int64_t mPerWave, int64_t nPerWave, - int64_t mPerBlock, - int64_t nPerBlock) { +static FailureOr +computeWaveGridLayout(PatternRewriter &b, Location loc, Value waveId, + int64_t mPerWave, int64_t nPerWave, int64_t mPerBlock, + int64_t nPerBlock) { // Calculate how many wave-sized tiles fit in the block dimensions // These determine the wave grid, not accounting for outer loop repeats int64_t waveTilesInM = mPerBlock / mPerWave; int64_t waveTilesInN = nPerBlock / nPerWave; + int64_t numWaves = waveTilesInM * waveTilesInN; // Determine wave grid layout based on physical waves and wave tiles // This distributes waves spatially across M and N dimensions - // Note: physicalWaves can only be 1, 2, 3, or 4 (for 64, 128, 192, 256 + // Note: numWaves can only be 1, 2, 3, or 4 (for 64, 128, 192, 256 // threads) + // TODO: numWaves can be 8 and 16 (and 32 for wmma) as well, update this code int64_t wavesInM = 1; int64_t wavesInN = 1; - switch (physicalWaves) { + switch (numWaves) { case 1: // Single wave: always 1×1 wavesInM = 1; @@ -805,7 +812,7 @@ static WaveGridLayout computeWaveGridLayout(PatternRewriter &b, Location loc, } break; default: - llvm_unreachable("Invalid physicalWaves: blockSize / waveSize"); + return failure(); } LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Wave grid layout: " << wavesInM @@ -818,7 +825,7 @@ static WaveGridLayout computeWaveGridLayout(PatternRewriter &b, Location loc, Value waveM = arith::DivUIOp::create(b, loc, waveId, wavesInNVal); Value waveN = arith::RemUIOp::create(b, loc, waveId, wavesInNVal); - return {wavesInM, wavesInN, waveM, waveN}; + return WaveGridLayout{wavesInM, wavesInN, waveM, waveN}; } //===----------------------------------------------------------------------===// @@ -1084,7 +1091,6 @@ LogicalResult 
emitThreadwiseHWTranspose(PatternRewriter &b, Value waveSizeVal = arith::ConstantIndexOp::create(b, loc, waveSize); Value lane = arith::RemUIOp::create(b, loc, tid, waveSizeVal); Value waveId = arith::DivUIOp::create(b, loc, tid, waveSizeVal); - int64_t physicalWaves = blockSize / waveSize; // Read parameters directly from config int64_t dDim = config.getMfmaDDim(); @@ -1099,8 +1105,11 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, config.getIsOperandA() ? OperandKind::A : OperandKind::B; // Compute wave grid layout and decompose wave ID into 2D position - WaveGridLayout waveGrid = computeWaveGridLayout( - b, loc, waveId, physicalWaves, mPerWave, nPerWave, mPerBlock, nPerBlock); + FailureOr maybeWaveGrid = computeWaveGridLayout( + b, loc, waveId, mPerWave, nPerWave, mPerBlock, nPerBlock); + assert(succeeded(maybeWaveGrid) && + "If we decided to use transpose load, this must work"); + WaveGridLayout waveGrid = maybeWaveGrid.value(); Value waveM = waveGrid.waveM; Value waveN = waveGrid.waveN; From f1db36da4d0df3e6498896ecb5e8c32c613100fe Mon Sep 17 00:00:00 2001 From: Daniel Hernandez Date: Tue, 2 Dec 2025 18:11:57 +0100 Subject: [PATCH 2/2] Add third phase to greedy tuning (outputSwizzle, wavesPerEU, gridGroupSize) --- .../IR/RockAccelTuningParamAttrInterface.td | 20 ++ .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 31 ++- .../Dialect/Rock/Tuning/GridwiseGemmParams.h | 31 ++- .../mlir/Dialect/Rock/Tuning/Serializable.h | 2 +- mlir/lib/Conversion/RockToGPU/RockToGPU.cpp | 128 +++++++----- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 11 +- .../Rock/Transforms/AffixTuningParameters.cpp | 2 +- .../Rock/Transforms/GridLayoutEmitter.cpp | 9 + .../Rock/Transforms/GridLayoutEmitter.h | 1 + .../Transforms/GridwiseGemmToBlockwise.cpp | 33 +++- .../Dialect/Rock/Transforms/OutputSwizzle.cpp | 34 +++- .../Rock/Tuning/GridwiseGemmParams.cpp | 15 +- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 183 +++++++++++++++--- .../Conversion/RockToGPU/waves_per_eu.mlir | 35 ++++ 
.../Dialect/Rock/affix_tuning_params.mlir | 104 +++++----- mlir/test/Dialect/Rock/async_wait_add.mlir | 8 +- mlir/test/Dialect/Rock/conv_to_gemm.mlir | 6 +- mlir/test/Dialect/Rock/effects.mlir | 46 +++-- mlir/test/Dialect/Rock/gemm_to_gridwise.mlir | 20 +- .../gridwise_attention_accel_lowering.mlir | 67 +++++-- ...ise_attention_accel_lowering_barriers.mlir | 6 +- ...gridwise_attention_accel_lowering_gqa.mlir | 2 +- .../Rock/gridwise_gemm_accel_lowering.mlir | 63 +++++- .../gridwise_gemm_accel_lowering_invalid.mlir | 8 +- ...idwise_gemm_conservative_lds_barriers.mlir | 2 +- .../multibuffer/test_multi_buffer_full.mlir | 2 +- .../test_multi_buffer_full_gfx950.mlir | 2 +- .../Rock/lds_transpose_attributes.mlir | 4 +- .../Rock/loadtile_to_threadwise_lowering.mlir | 6 +- .../Rock/lowering_blockwise_gemm_accel.mlir | 14 +- .../Rock/lowering_blockwise_gemm_wmma.mlir | 8 +- ...ring_gemm_linalg_splitk_normalization.mlir | 18 +- .../Dialect/Rock/lowering_output_swizzle.mlir | 4 +- .../Rock/lowering_output_swizzle_tuning.mlir | 155 +++++++++++++++ .../Rock/lowering_to_threadwise_accel.mlir | 20 +- .../test/Dialect/Rock/lowering_top_level.mlir | 4 +- .../test/Dialect/Rock/lowering_wmma_gemm.mlir | 10 +- .../Dialect/Rock/lowering_xdlops_gemm.mlir | 44 ++--- mlir/test/Dialect/Rock/ops.mlir | 12 +- mlir/test/Dialect/Rock/ops_2.mlir | 10 +- mlir/test/Dialect/Rock/ops_blockwise_f16.mlir | 16 +- mlir/test/Dialect/Rock/ops_error.mlir | 10 +- mlir/test/Dialect/Rock/regularize.mlir | 2 +- .../rock-shuffle-gemm-for-reductions.mlir | 2 +- .../Rock/test-fusion-and-pipeline.mlir | 2 +- mlir/test/Dialect/Rock/test_multi_buffer.mlir | 2 +- .../toblockwise_attention_accel_lowering.mlir | 8 +- .../Rock/toblockwise_gemm_accel_lowering.mlir | 4 +- .../e2e/AttentionNonPowerOfTwoTileSize.toml | 2 +- .../GemmVariantsNonPowerOfTwoTileSize.toml | 2 +- .../test/fusion/bug-1546-compile-failure.mlir | 2 +- ...1550-reduction-fusion-compile-failure.mlir | 2 +- .../linalg-generic-with-atomic-store.mlir | 2 
+- .../mixr-attention-padded-scale-cross.mlir | 2 +- .../populate_perf_config_gemm.mlir | 4 +- mlir/test/rocmlir-gen/emit-tuning-space.mlir | 16 +- mlir/tools/rocmlir-gen/rocmlir-gen.cpp | 2 +- .../Dialect/Rock/InitParamsAccelTests.cpp | 10 +- mlir/utils/performance/attentionSweeps.py | 6 +- mlir/utils/performance/parameterSweeps.py | 30 ++- 60 files changed, 958 insertions(+), 348 deletions(-) create mode 100644 mlir/test/Conversion/RockToGPU/waves_per_eu.mlir create mode 100644 mlir/test/Dialect/Rock/lowering_output_swizzle_tuning.mlir diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td b/mlir/include/mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td index 26613a499257..c458e68017f3 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td @@ -90,6 +90,26 @@ def RockAccelTuningParamAttrInterface : AttrInterface<"RockAccelTuningParamAttrI /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return waves_per_eu, this is a hint to the backend compiler to tune the number of wavefronts that are capable of fitting within the resources of an EU. + }], + /*retType=*/"int64_t", + /*methodName=*/"getWavesPerEU", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Group size for layout on the distribution of the workgroups. 
+ }], + /*retType=*/"int64_t", + /*methodName=*/"getGridGroupSize", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" > // TODO: more methods here as needed diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index e56fec5b95f5..35022d4835ae 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -280,7 +280,7 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf "int64_t":$nPerBlockG0, "int64_t":$kpackPerBlock, "int64_t":$mPerWave, "int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$kpack, "int64_t":$splitKFactor, "int64_t":$scheduleVersion, - "int64_t":$outputSwizzle, "bool":$forceUnroll); + "int64_t":$outputSwizzle, "int64_t":$wavesPerEU, "bool":$forceUnroll); let extraClassDeclaration = [{ void getPerfConfigStr(::llvm::SmallVectorImpl &perfStr) { @@ -296,13 +296,14 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf + Twine(getSplitKFactor()) + "," + Twine(getScheduleVersion()) + "," + Twine(getOutputSwizzle()) + "," + + Twine(getWavesPerEU()) + "," + Twine(getForceUnroll())).toVector(perfStr); } AttnPerfConfigAttr withScheduleVersion(int64_t newScheduleVersion) const { return AttnPerfConfigAttr::get( getContext(), getMPerBlockG0(), getMPerBlockG1(), getNPerBlockG0(), getKpackPerBlock(), getMPerWave(), getNPerWave(), getMnPerXdl(), getKpack(), - getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getForceUnroll()); + getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getWavesPerEU(), getForceUnroll()); } }]; @@ -325,7 +326,7 @@ def Rock_MfmaGemmParamsAttr "int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave, "int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor, "int64_t":$scheduleVersion, "int64_t":$outputSwizzle, - "bool":$forceUnroll); + "int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll); let 
extraClassDeclaration = [{ void getPerfConfigStr(::llvm::SmallVectorImpl &perfStr) { @@ -339,6 +340,8 @@ def Rock_MfmaGemmParamsAttr + Twine(getSplitKFactor()) + "," + Twine(getScheduleVersion()) + "," + Twine(getOutputSwizzle()) + "," + + Twine(getWavesPerEU()) + "," + + Twine(getGridGroupSize()) + "," + Twine(getForceUnroll()) + "," + "1") /* *ThreadCopyMore* */ .toVector(perfStr); @@ -359,7 +362,7 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn "int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave, "int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor, "int64_t":$scheduleVersion, "int64_t":$outputSwizzle, - "bool":$forceUnroll); + "int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll); let extraClassDeclaration = [{ void getPerfConfigStr(SmallVectorImpl &perfStr) { @@ -373,6 +376,8 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn + Twine(getSplitKFactor()) + "," + Twine(getScheduleVersion()) + "," + Twine(getOutputSwizzle()) + "," + + Twine(getWavesPerEU()) + "," + + Twine(getGridGroupSize()) + "," + Twine(getForceUnroll()) + "," + "1") /* *ThreadCopyMore* */ .toVector(perfStr); @@ -459,9 +464,25 @@ def Rock_ScheduleVersionAttr : Rock_Attr<"ScheduleVersion"> { }]; } -// It is a temporary attribute def Rock_EnableSplitKForTuning : Rock_Attr<"EnableSplitKForTuning"> { let mnemonic = "enable_splitk_for_tuning"; + let description = [{ + Whether we tune for split-k. If unset, split-k=1. + }]; +} + +def Rock_WavesPerEU : Rock_Attr<"WavesPerEU"> { + let mnemonic = "waves_per_eu"; + let description = [{ + This is a hint to the backend compiler to tune the number of wavefronts that are capable of fitting within the resources of an EU. + }]; +} + +def Rock_OutputSwizzle : Rock_Attr<"OutputSwizzle"> { + let mnemonic = "output_swizzle"; + let description = [{ + Whether we run the output swizzle pass. 0 -> disabled, 1 -> enabled, 2 -> heuristic. 
+ }]; } def Rock_PrefillAttr : Rock_Attr<"Prefill"> { diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h index 0ee32255d874..b65c4f89d27c 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h +++ b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h @@ -142,6 +142,7 @@ struct InitParamsNonAccel : InitParams, Serializable { }; struct InitParamsAccel : InitParams, Serializable { + // TODO: remove once we generate new quick tuning list constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave, int64_t nPerWave, int64_t mnPerXdl, int64_t kPack, @@ -152,12 +153,28 @@ struct InitParamsAccel : InitParams, Serializable { gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl), gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor), gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle), + wavesPerEU(0), gridGroupSize(0), + gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK), + gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {} + + constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock, + int64_t kPerBlock, int64_t mPerWave, + int64_t nPerWave, int64_t mnPerXdl, int64_t kPack, + int64_t splitKFactor, int64_t scheduleVersion, + int64_t outputSwizzle, int64_t wavesPerEU, + int64_t gridGroupSize, bool aThreadCopyMoreGemmK, + bool bThreadCopyMoreGemmKPack) + : InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerWave(mPerWave), + gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl), + gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor), + gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle), + wavesPerEU(wavesPerEU), gridGroupSize(gridGroupSize), gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK), gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {} constexpr InitParamsAccel() - : InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, false, - false) {} + : 
InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, 0LL, + 0LL, false, false) {} InitParamsAccel(MfmaGemmParamsAttr attr) : InitParams{attr.getMPerBlock(), attr.getNPerBlock(), @@ -167,6 +184,8 @@ struct InitParamsAccel : InitParams, Serializable { gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()), gemmScheduleVersion(attr.getScheduleVersion()), outputSwizzle(attr.getOutputSwizzle()), + wavesPerEU(attr.getWavesPerEU()), + gridGroupSize(attr.getGridGroupSize()), gemmAThreadCopyMoreGemmK(attr.getForceUnroll()), gemmBThreadCopyMoreGemmKPack(false) {}; @@ -178,6 +197,8 @@ struct InitParamsAccel : InitParams, Serializable { gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()), gemmScheduleVersion(attr.getScheduleVersion()), outputSwizzle(attr.getOutputSwizzle()), + wavesPerEU(attr.getWavesPerEU()), + gridGroupSize(attr.getGridGroupSize()), gemmAThreadCopyMoreGemmK(attr.getForceUnroll()), gemmBThreadCopyMoreGemmKPack(false) {}; @@ -191,6 +212,8 @@ struct InitParamsAccel : InitParams, Serializable { int64_t splitKFactor; int64_t gemmScheduleVersion; int64_t outputSwizzle; + int64_t wavesPerEU; + int64_t gridGroupSize; bool gemmAThreadCopyMoreGemmK; bool gemmBThreadCopyMoreGemmKPack; @@ -214,6 +237,10 @@ struct InitParamsAccel : InitParams, Serializable { f(self.gemmScheduleVersion); f(self.outputSwizzle); } + if (self.version >= Version::V4) { + f(self.wavesPerEU); + f(self.gridGroupSize); + } f(self.gemmAThreadCopyMoreGemmK); f(self.gemmBThreadCopyMoreGemmKPack); } diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h b/mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h index db0b182cbd5c..0756ec51cfa9 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h +++ b/mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h @@ -71,7 +71,7 @@ struct Serializable { } bool checkVersionFormat(const std::string &s) { - const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 12}; + const int32_t maxNumTokensArray[] = {0, 8, 9, 
11, 14}; const int32_t versionIdx = static_cast(version); if (versionIdx < 1 || versionIdx >= static_cast(Version::Count)) { llvm_unreachable("Unknown version of the perfConfig"); diff --git a/mlir/lib/Conversion/RockToGPU/RockToGPU.cpp b/mlir/lib/Conversion/RockToGPU/RockToGPU.cpp index edd9f89b7f02..e21cadb4c278 100644 --- a/mlir/lib/Conversion/RockToGPU/RockToGPU.cpp +++ b/mlir/lib/Conversion/RockToGPU/RockToGPU.cpp @@ -136,6 +136,63 @@ struct WorkgroupIdRewritePattern }; } // namespace +static void runWavesPerEUHeuristic(OpBuilder b, gpu::GPUFuncOp gpuFunc, + int64_t ldsUsage) { + LLVM_DEBUG(llvm::dbgs() << "Using heuristic to set wavesPerEU...\n"); + if (!gpuFunc->hasAttrOfType("block_size")) { + LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n"); + return; + } + int64_t blockSize = + gpuFunc->getAttrOfType("block_size").getInt(); + if (!gpuFunc->hasAttrOfType("grid_size")) { + LLVM_DEBUG(llvm::dbgs() << "gridSize not found in gpuFunc.\n"); + return; + } + int64_t gridSize = gpuFunc->getAttrOfType("grid_size").getInt(); + FailureOr maybeArch = rock::getArch(gpuFunc); + if (succeeded(maybeArch)) { + StringAttr arch = maybeArch.value(); + rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch); + FailureOr maybeNumCU = rock::getNumCU(gpuFunc); + int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU); + int64_t totalEUs = archInfo.numEUPerCU * numCU; + int64_t wavesPerBlock = (blockSize / archInfo.waveSize); + int64_t totalWaves = wavesPerBlock * gridSize; + int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU; + int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs; + int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid); + LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n"); + LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n"); + LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n"); + LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n"); + 
LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n"); + LLVM_DEBUG(llvm::dbgs() << " numEUPerCU:" << archInfo.numEUPerCU << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n"); + LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n"); + // limit wavesPerEU based on lds usage + if (ldsUsage > 0) { + wavesPerEU = + std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage); + } + // Currently limiting wavesPerEU to be two + // it is a future to ticket to remove this constraint with further + // analysis + constexpr int64_t wavesPerEUUpperBound = 2; + wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound); + if (wavesPerEU > 1) { + LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n"); + gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU)); + } else { + LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set" + << "\n"); + } + } else { + LLVM_DEBUG(llvm::dbgs() << "arch not found.\n"); + } +} + void LowerRockOpsToGPUPass::runOnOperation() { ModuleOp op = getOperation(); MLIRContext *ctx = op.getContext(); @@ -204,6 +261,11 @@ void LowerRockOpsToGPUPass::runOnOperation() { gridSize = cast(gridSizeAttr).getInt(); gpuFunc.setKnownGridSizeAttr(b.getDenseI32ArrayAttr({gridSize, 1, 1})); + auto wavesPerEUAttr = theFunc->getAttr(rock::WavesPerEUAttr::getMnemonic()); + if (wavesPerEUAttr) { + gpuFunc->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr); + } + FailureOr maybeArch = rock::getArch(theFunc); if (succeeded(maybeArch)) { gpuFunc->setAttr("arch", maybeArch.value()); @@ -391,61 +453,23 @@ void LowerRockOpsToGPUPass::runOnOperation() { gpuFunc->setAttr("rock.shared_buffer_size", b.getI32IntegerAttr(ldsUsage)); } - LLVM_DEBUG(llvm::dbgs() << "Attempting to set wavesPerEU...\n"); - if (!gpuFunc->hasAttrOfType("block_size")) { - LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n"); - return; - } - int64_t blockSize = - gpuFunc->getAttrOfType("block_size").getInt(); - if 
(!gpuFunc->hasAttrOfType("grid_size")) { - LLVM_DEBUG(llvm::dbgs() << "gridSize not found in gpuFunc.\n"); - return; - } - int64_t gridSize = - gpuFunc->getAttrOfType("grid_size").getInt(); - FailureOr maybeArch = rock::getArch(gpuFunc); - if (succeeded(maybeArch)) { - StringAttr arch = maybeArch.value(); - rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch); - FailureOr maybeNumCU = rock::getNumCU(gpuFunc); - int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU); - int64_t totalEUs = archInfo.numEUPerCU * numCU; - int64_t wavesPerBlock = (blockSize / archInfo.waveSize); - int64_t totalWaves = wavesPerBlock * gridSize; - int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU; - int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs; - int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid); - LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n"); - LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n"); - LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n"); - LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n"); - LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n"); - LLVM_DEBUG(llvm::dbgs() - << " numEUPerCU:" << archInfo.numEUPerCU << "\n"); - LLVM_DEBUG(llvm::dbgs() - << "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n"); - LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n"); - // limit wavesPerEU based on lds usage - if (ldsUsage > 0) { - wavesPerEU = - std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage); - } - // Currently limiting wavesPerEU to be two - // it is a future to ticket to remove this constraint with further - // analysis - constexpr int64_t wavesPerEUUpperBound = 2; - wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound); - if (wavesPerEU > 1) { - LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n"); + // if waves_per_eu is set, use it + if (gpuFunc->hasAttrOfType( + rock::WavesPerEUAttr::getMnemonic())) { + int64_t 
wavesPerEU = + gpuFunc + ->getAttrOfType(rock::WavesPerEUAttr::getMnemonic()) + .getInt(); + // zero means, use heuristic + if (wavesPerEU != 0) { gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU)); - } else { - LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set" - << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Setting waves_per_eu using tuning param\n"); + // we are done + return; } - } else { - LLVM_DEBUG(llvm::dbgs() << "arch not found.\n"); } + // no "waves_per_eu" attribute, use heuristic + runWavesPerEUHeuristic(b, gpuFunc, ldsUsage); }); if (gpuModCount == 0) { diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 33df6ca3de85..63e319b5d2d0 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -3254,17 +3254,19 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr, expectedNumTokens = 11; break; case 3: - expectedNumTokens = 12; + expectedNumTokens = 13; break; default: llvm_unreachable("Unknown version of the perfConfig"); } - SmallVector tokens; + SmallVector tokens; + SmallVector params; + tokens.reserve(expectedNumTokens); + params.reserve(expectedNumTokens); rest.split(tokens, ','); if (tokens.size() != expectedNumTokens) { return {}; } - SmallVector params; llvm::transform(tokens, std::back_inserter(params), [](StringRef s) { int param; llvm::to_integer(s, param); @@ -3298,11 +3300,12 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr, int64_t splitKFactor = version > 1 ? params[lastIdx++] : 1; int64_t scheduleVersion = version > 1 ? params[lastIdx++] : 1; int64_t outputSwizzle = version > 1 ? params[lastIdx++] : 2; + int64_t wavesPerEU = isV3 ? 
params[lastIdx++] : 0; // 0 -> use heuristic int64_t forceUnroll = params[expectedNumTokens - 1] == 1; return AttnPerfConfigAttr::get( perfConfigStrAttr.getContext(), mPerBlockG0, mPerBlockG1, nPerBlockG0, kpackPerBlock, mPerWave, nPerWave, mnPerXdl, kpack, splitKFactor, - scheduleVersion, outputSwizzle, forceUnroll); + scheduleVersion, outputSwizzle, wavesPerEU, forceUnroll); } //===-----------------------------------------------------===// diff --git a/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp b/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp index 2ac7afe6ff39..6baa09d58323 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp @@ -309,7 +309,7 @@ void AffixTuningParameters::affixTuningParametersImpl( Attribute params0 = op.getGemm0Params().value_or(nullptr); // set a default one if params is not provided StringAttr perfConfigStrAttr = - builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,1"); + builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,0,1"); if (!params0) { if (StringAttr mayBePerfConfigStrAttr = dyn_cast_or_null(op->getAttr("perf_config"))) { diff --git a/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.cpp b/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.cpp index a684e126eca6..7d4f4485e081 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.cpp @@ -93,6 +93,15 @@ GridCoordinates rock::layout::makeGroupedGridLayout(PatternRewriter &b, int64_t bitWidthOut = info.outputType.getIntOrFloatBitWidth(); int64_t groupSize = std::ceil(std::sqrt(info.numCU)) * (bitWidthOut / bitWidthIn); + // use gridGroupSize if it's not zero + if (info.gridGroupSize != 0) { + groupSize = info.gridGroupSize; + LLVM_DEBUG(llvm::dbgs() << "Setting groupSize by using tuning params to " + << groupSize << "\n"); + } else { + LLVM_DEBUG(llvm::dbgs() + << "Using heuristic to set 
groupSize to " << groupSize << "\n"); + } Value mBlocksPerGroup = b.createOrFold(loc, groupSize); Value blocksPerGroup = diff --git a/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.h b/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.h index 02b6f4245c30..5985e014c3ad 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.h +++ b/mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.h @@ -54,6 +54,7 @@ struct GridLayoutInfo { int64_t numCU; Type inputType; Type outputType; + int64_t gridGroupSize; }; /// This function emits the right triplet of diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index f2fd9009f89f..49b1c74f076e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -532,9 +532,12 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern { if (failed(maybeArch)) { return op.emitError("arch needs to be set."); } + // always use heuristic for non-accel path + int64_t gridGroupSize = 0; auto gridCoords = layout::makeGroupedGridLayout( b, loc, bid, - {G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType}, + {G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType, + gridGroupSize}, maybeArch->getValue()); Value storeBufferA = GpuAllocOp::create(b, loc, loadBufferA.getType()); @@ -1980,6 +1983,20 @@ struct GridwiseAttentionAccelRewritePattern rock::accel::AccelEmitterParams accelParamsGemm1 = accelEmitterPtrGemm1->getParams(); + // wavesPerEU is needed in RockToGPU pass and OutputSwizzle for the + // OutputSwizzle pass. We add them as func attributes. 
+ assert(gemm0TuningParams.getWavesPerEU() == + gemm1TuningParams.getWavesPerEU()); + assert(gemm0TuningParams.getOutputSwizzle() == + gemm1TuningParams.getOutputSwizzle()); + IntegerAttr wavesPerEUAttr = + rewriter.getI64IntegerAttr(gemm0TuningParams.getWavesPerEU()); + IntegerAttr outputSwizzleAttr = + rewriter.getI64IntegerAttr(gemm0TuningParams.getOutputSwizzle()); + func::FuncOp funcOp = cast(op->getParentOp()); + funcOp->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr); + funcOp->setAttr(rock::OutputSwizzleAttr::getMnemonic(), outputSwizzleAttr); + // Get current workgroup ID. auto bid = WorkgroupIdOp::create(rewriter, loc, rewriter.getIndexType()); // Get current workitem ID. @@ -3037,11 +3054,23 @@ struct GridwiseGemmAccelRewritePattern auto tid = WorkitemIdOp::create(b, loc, b.getIndexType()); // Compute grid coordinates + int64_t gridGroupSize = tuningParams.getGridGroupSize(); auto gridCoords = layout::makeGroupedGridLayout( b, loc, bid, - {G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType}, + {G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType, + gridGroupSize}, arch); + // wavesPerEU is needed in RockToGPU pass and OutputSwizzle for the + // OutputSwizzle pass. We add them as func attributes. 
+ IntegerAttr wavesPerEUAttr = + b.getI64IntegerAttr(tuningParams.getWavesPerEU()); + IntegerAttr outputSwizzleAttr = + b.getI64IntegerAttr(tuningParams.getOutputSwizzle()); + func::FuncOp funcOp = cast(op->getParentOp()); + funcOp->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr); + funcOp->setAttr(rock::OutputSwizzleAttr::getMnemonic(), outputSwizzleAttr); + LDSLayoutConfigDim ldsLayoutConfigA = getLDSLayoutConfigDim( elementTypeA, kpack, maybeVecDimInfoA.value(), directToLDS); LDSLayoutConfigDim ldsLayoutConfigB = getLDSLayoutConfigDim( diff --git a/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp b/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp index bbba4d5cfeb7..8c541e2f4859 100644 --- a/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp @@ -58,6 +58,8 @@ using namespace mlir::arith; using namespace mlir::rock; using mlir::gpu::AddressSpace; +enum OutputSwizzleTuningParam { DISABLED = 0, ENABLED = 1, HEURISTIC = 2 }; + namespace { struct RockOutputSwizzlePass : public rock::impl::RockOutputSwizzlePassBase { @@ -414,9 +416,19 @@ void RockOutputSwizzlePass::runOnOperation() { // Get total LDS memory allocated int64_t ldsAllocated = getLDSTotalSize(func); + OutputSwizzleTuningParam tuning = OutputSwizzleTuningParam::HEURISTIC; + if (func->hasAttrOfType( + rock::OutputSwizzleAttr::getMnemonic())) { + // 0 -> disabled, 1 -> enabled, 2 -> heuristic + int64_t outputSwizzleTuning = + func->getAttrOfType(rock::OutputSwizzleAttr::getMnemonic()) + .getInt(); + tuning = static_cast(outputSwizzleTuning); + } + SmallVector writes; - func.walk([&writes, &rewriter, - ldsAllocated](ThreadwiseWriteAllOp threadwiseWriteAll) { + func.walk([&writes, &rewriter, ldsAllocated, + tuning](ThreadwiseWriteAllOp threadwiseWriteAll) { MemRefType destMemRefType = cast(threadwiseWriteAll.getDest().getType()); @@ -444,12 +456,20 @@ void RockOutputSwizzlePass::runOnOperation() { << ldsRequiredBytes << " bytes, skipping 
pass\n"); return; } - // heuristic: if we need more LDS, skip this pass - if (ldsRequiredBytes > ldsAllocated) { + if (tuning == OutputSwizzleTuningParam::HEURISTIC) { + // heuristic: if we need more LDS, skip this pass + LLVM_DEBUG(llvm::dbgs() << "Using heuristic\n"); + if (ldsRequiredBytes > ldsAllocated) { + LLVM_DEBUG( + llvm::dbgs() + << "OutputSwizzle requires more LDS memory, current usage: " + << ldsAllocated << " bytes, required: " << ldsRequiredBytes + << " bytes, skipping pass\n"); + return; + } + } else if (tuning == OutputSwizzleTuningParam::DISABLED) { LLVM_DEBUG(llvm::dbgs() - << "OutputSwizzle requires more LDS memory, current usage: " - << ldsAllocated << " bytes, required: " << ldsRequiredBytes - << " bytes, skipping pass\n"); + << "OutputSwizzle disabled using tuning params\n"); return; } writes.push_back(threadwiseWriteAll); diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index be136a4775d2..c40caff307e7 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -577,6 +577,7 @@ PopulateParamsXDL::getGemmParamsAttr(OpBuilder &builder, validParams.gemmMPerWave, validParams.gemmNPerWave, validParams.gemmMnPerXdl, validParams.splitKFactor, validParams.gemmScheduleVersion, validParams.outputSwizzle, + validParams.wavesPerEU, validParams.gridGroupSize, validParams.gemmAThreadCopyMoreGemmK); } else { // V3 and older @@ -590,11 +591,13 @@ PopulateParamsXDL::getGemmParamsAttr(OpBuilder &builder, mPerWave = mPerBlock / mWaves; int64_t nPerWave = std::max(nPerBlock / nWaves, mnPerXdl); + return builder.getAttr( validParams.gemmKPerBlock, validParams.gemmMPerBlock, validParams.gemmNPerBlock, validParams.gemmKPack, mPerWave, nPerWave, mnPerXdl, validParams.splitKFactor, validParams.gemmScheduleVersion, - validParams.outputSwizzle, validParams.gemmAThreadCopyMoreGemmK); + validParams.outputSwizzle, validParams.wavesPerEU, 
+ validParams.gridGroupSize, validParams.gemmAThreadCopyMoreGemmK); } } @@ -748,6 +751,7 @@ Attribute PopulateParamsWmma::getGemmParamsAttr( validParams.gemmNPerBlock, validParams.gemmKPack, validParams.gemmMPerWave, nPerWave, mnPerXdl, validParams.splitKFactor, validParams.gemmScheduleVersion, validParams.outputSwizzle, + validParams.wavesPerEU, validParams.gridGroupSize, validParams.gemmAThreadCopyMoreGemmK); } @@ -768,6 +772,8 @@ deriveGemm1TuningParams(OpBuilder &b, gemm0XdlDerivedParams.getMnPerXdl(), attnPerfConfig.getSplitKFactor(), gemm0XdlDerivedParams.getScheduleVersion(), gemm0XdlDerivedParams.getOutputSwizzle(), + gemm0XdlDerivedParams.getWavesPerEU(), + gemm0XdlDerivedParams.getGridGroupSize(), gemm0XdlDerivedParams.getForceUnroll()); } return WmmaGemmParamsAttr::get( @@ -778,7 +784,8 @@ deriveGemm1TuningParams(OpBuilder &b, (attnPerfConfig.getMPerBlockG1() / gemm0TuningParams.getMPerBlock()), gemm0TuningParams.getNPerWave(), gemm0TuningParams.getMnPerXdl(), attnPerfConfig.getSplitKFactor(), gemm0TuningParams.getScheduleVersion(), - gemm0TuningParams.getOutputSwizzle(), gemm0TuningParams.getForceUnroll()); + gemm0TuningParams.getOutputSwizzle(), gemm0TuningParams.getWavesPerEU(), + gemm0TuningParams.getGridGroupSize(), gemm0TuningParams.getForceUnroll()); } FailureOr> +getFinetuningParams(int64_t maxWavesPerEU) { + std::vector wavesPerEUList; + wavesPerEUList.push_back(0); // use heuristic + for (int64_t wavesPerEU = 1; wavesPerEU <= maxWavesPerEU; wavesPerEU *= 2) { + wavesPerEUList.push_back(wavesPerEU); + } + std::vector> finetuningParams = { + {0, 1}, // outputSwizzle + wavesPerEUList, // wavesPerEU + {0, 4, 8, 16, 32, 64}}; // gridGroupSize + return finetuningParams; +} + static std::vector> getAccelRangeGemm(RockGemmWrapperInterface gemmOp, TuningParamSetKind kind) { auto dPerBlock = computeDPerBlock(kind); @@ -264,7 +278,7 @@ static void createAttnTuningRangeGreedyPhase1( return; } - int64_t outputSwizzle{2}; + int64_t outputSwizzle{2}, 
wavesPerEU{0}; for (uint32_t gemm0MPerBlock : params[0]) { SmallVector mPerWaveRange = computeDPerWave(TuningParamSetKind::Greedy, gemm0MPerBlock, waveSize); @@ -300,7 +314,7 @@ static void createAttnTuningRangeGreedyPhase1( gemmGemmOp.getContext(), gemm0MPerBlock, gemm1MPerBlock, gemm0NPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave, gemmMnPerXdl, gemmKPack, splitKFactor, gemmSchedule, - outputSwizzle, true); + outputSwizzle, wavesPerEU, true); if (succeeded(paramsAttnProbablyValid(b, gemmGemmOp, attnParams))) { newSpace->tuningRange.push_back( cast(attnParams)); @@ -326,7 +340,7 @@ static void createAttnTuningRangeGreedyPhase2( } int64_t waveSize = rock::lookupArchInfo(rock::getArchValue(gemmGemmOp)).waveSize; - int64_t outputSwizzle{2}; + int64_t outputSwizzle{2}, wavesPerEU{0}; OpBuilder b(gemmGemmOp.getContext()); auto attnPerfConfig = @@ -354,7 +368,7 @@ static void createAttnTuningRangeGreedyPhase2( gemmGemmOp.getContext(), gemm0MPerBlock, gemm1MPerBlock, gemm0NPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave, gemmMnPerXdl, gemmKPack, splitKFactor, gemmSchedule, - outputSwizzle, true); + outputSwizzle, wavesPerEU, true); if (succeeded( paramsAttnProbablyValid(b, gemmGemmOp, attnParams))) { newSpace->tuningRange.push_back( @@ -369,6 +383,52 @@ static void createAttnTuningRangeGreedyPhase2( } } +// With almost all tuning params already set, tune for fine-tuning parameters by +// brute force (greedy tuning, phase 3) +static void createAttnTuningRangeGreedyPhase3( + TuningParamSet *newSpace, RockGemmGemmWrapperInterface gemmGemmOp, + bool isSplitKFusible, StringRef winningConfig) { + GemmFeatures features = rock::getFeatures(gemmGemmOp); + bool isWMMA = bitEnumContainsAny(features, GemmFeatures::wmma); + if (!bitEnumContainsAny(features, GemmFeatures::mfma) && !isWMMA) { + // We only support GPUs with matrix accelerator extensions + return; + } + OpBuilder b(gemmGemmOp.getContext()); + + auto attnPerfConfig = + 
AttnPerfConfigAttr::get(b.getStringAttr(winningConfig), isWMMA); + assert(attnPerfConfig && "Tile sizes must be extracted from winning config"); + uint32_t gemm0MPerBlock = attnPerfConfig.getMPerBlockG0(); + uint32_t gemm1MPerBlock = attnPerfConfig.getMPerBlockG1(); + uint32_t gemm0NPerBlock = attnPerfConfig.getNPerBlockG0(); + uint32_t gemmKPerBlock = attnPerfConfig.getKpackPerBlock(); + uint32_t gemmMPerWave = attnPerfConfig.getMPerWave(); + uint32_t gemmNPerWave = attnPerfConfig.getNPerWave(); + uint32_t gemmMnPerXdl = attnPerfConfig.getMnPerXdl(); + uint32_t gemmKPack = attnPerfConfig.getKpack(); + uint32_t splitKFactor = attnPerfConfig.getSplitKFactor(); + uint32_t gemmSchedule = attnPerfConfig.getScheduleVersion(); + + int64_t maxWavesPerEU = + rock::lookupArchInfo(rock::getArchValue(gemmGemmOp)).maxWavesPerEU; + auto finetuningParams = getFinetuningParams(maxWavesPerEU); + + for (int64_t outputSwizzle : finetuningParams[0]) { + for (uint32_t wavesPerEU : finetuningParams[1]) { + auto attnParams = AttnPerfConfigAttr::get( + gemmGemmOp.getContext(), gemm0MPerBlock, gemm1MPerBlock, + gemm0NPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave, + gemmMnPerXdl, gemmKPack, splitKFactor, gemmSchedule, outputSwizzle, + wavesPerEU, true); + if (succeeded(paramsAttnProbablyValid(b, gemmGemmOp, attnParams))) { + newSpace->tuningRange.push_back( + cast(attnParams)); + } + } + } +} + // Keep in sync with attentionSweeps.py // The full space is a brute-force search for attention kernels static void createAttnTuningRangeBF(TuningParamSet *newSpace, @@ -387,7 +447,7 @@ static void createAttnTuningRangeBF(TuningParamSet *newSpace, } int64_t waveSize = rock::lookupArchInfo(rock::getArchValue(gemmGemmOp)).waveSize; - int64_t outputSwizzle{2}; + int64_t outputSwizzle{2}, wavesPerEU{0}; OpBuilder b(gemmGemmOp.getContext()); for (uint32_t gemm0MPerBlock : validRangeAttnParams[0]) { SmallVector mPerWaveRange = @@ -419,7 +479,8 @@ static void createAttnTuningRangeBF(TuningParamSet 
*newSpace, gemmGemmOp.getContext(), gemm0MPerBlock, gemm1MPerBlock, gemm0NPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave, gemmMnPerXdl, gemmKPack, - splitKFactor, gemmSchedule, outputSwizzle, true); + splitKFactor, gemmSchedule, outputSwizzle, wavesPerEU, + true); if (succeeded(paramsAttnProbablyValid(b, gemmGemmOp, attnParams))) { newSpace->tuningRange.push_back( @@ -555,6 +616,8 @@ static void createGemmTuningRangeBF(TuningParamSet *newSpace, tuningInfo = std::make_unique(); int64_t waveSize = rock::lookupArchInfo(rock::getArchValue(gemmOp)).waveSize; + // hardcode to use heuristics + int64_t outputSwizzle{2}, wavesPerEU{0}, gridGroupSize{0}; OpBuilder b(gemmOp.getContext()); if (bitEnumContainsAll(currentFeatures, GemmFeatures::mfma) || bitEnumContainsAll(currentFeatures, GemmFeatures::wmma)) { @@ -575,13 +638,13 @@ static void createGemmTuningRangeBF(TuningParamSet *newSpace, for (int64_t splitKFactor : optimalSplitKFactors) { for (int64_t gemmSchedule : accelParams[5]) { for (uint32_t forceUnroll : accelParams[6]) { - // hardcode outputSwizzle to heuristics = 2 InitParamsAccel gemmParams( gemmMPerBlock, gemmNPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave, gemmMnPerXdl, gemmKPack, - splitKFactor, gemmSchedule, 2, forceUnroll, true); + splitKFactor, gemmSchedule, outputSwizzle, + wavesPerEU, gridGroupSize, forceUnroll, true); if (gemmMPerBlock >= gemmMPerWave && - gemmNPerBlock >= gemmMnPerXdl) { + gemmNPerBlock >= gemmNPerWave) { if (succeeded(tuningInfo->paramsProbablyValid( b, info, gemmParams)) && (kind != TuningParamSetKind::Full || @@ -689,7 +752,7 @@ static void createAttnTuningRangeQuick(TuningParamSet *newSpace, Op gemmGemmOp, Type elemType) { OpBuilder b(gemmGemmOp.getContext()); GemmFeatures currentFeatures = rock::getFeatures(gemmGemmOp); - int64_t splitKFactor{1}, gemmSchedule{1}, outputSwizzle{2}; + int64_t splitKFactor{1}, gemmSchedule{1}, outputSwizzle{2}, wavesPerEU{0}; // g0Mpb, g1Mpb, g0Npb, Kpb, mPw, mnPxdl, kpack using 
PerfConfigVals = std::tuple; @@ -719,7 +782,7 @@ static void createAttnTuningRangeQuick(TuningParamSet *newSpace, Op gemmGemmOp, auto attnParams = AttnPerfConfigAttr::get( gemmGemmOp.getContext(), mPerBlockG0, mPerBlockG1, nPerBlockG0, kPackBerBlock, mPerWave, nPerWave, mnPerXdl, kPack, splitKFactor, - gemmSchedule, outputSwizzle, true); + gemmSchedule, outputSwizzle, wavesPerEU, true); if (succeeded(paramsAttnProbablyValid(b, gemmGemmOp, attnParams))) { newSpace->tuningRange.push_back( cast(attnParams)); @@ -740,7 +803,7 @@ static void createAttnTuningRangeQuick(TuningParamSet *newSpace, Op gemmGemmOp, auto attnParams = AttnPerfConfigAttr::get( gemmGemmOp.getContext(), mPerBlockG0, mPerBlockG1, nPerBlockG0, kPackBerBlock, mPerWave, nPerWave, mnPerXdl, kPack, splitKFactor, - gemmSchedule, outputSwizzle, true); + gemmSchedule, outputSwizzle, wavesPerEU, true); if (succeeded(paramsAttnProbablyValid(b, gemmGemmOp, attnParams))) { newSpace->tuningRange.push_back( cast(attnParams)); @@ -769,7 +832,7 @@ unsigned getNumberOfIterations(TuningParamSetKind kind) { case TuningParamSetKind::Exhaustive: return 1; case TuningParamSetKind::Greedy: - return 2; + return 3; } llvm_unreachable("invalid tuning kind"); } @@ -794,6 +857,8 @@ static void createGemmTuningRangeGreedyPhase1(TuningParamSet *newSpace, tuningInfo = std::make_unique(); int64_t waveSize = rock::lookupArchInfo(rock::getArchValue(gemmOp)).waveSize; + // hardcode to use heuristics + int64_t outputSwizzle{2}, wavesPerEU{0}, gridGroupSize{0}; for (uint32_t gemmMPerBlock : params[0]) { SmallVector mPerWaveRange = computeDPerWave(TuningParamSetKind::Greedy, gemmMPerBlock, waveSize); @@ -822,11 +887,10 @@ static void createGemmTuningRangeGreedyPhase1(TuningParamSet *newSpace, isSplitKFusible); uint32_t splitKFactor = optimalSplitKFactors[rng() % optimalSplitKFactors.size()]; - // hardcode outputSwizzle to heuristics = 2 - InitParamsAccel gemmParams(gemmMPerBlock, gemmNPerBlock, gemmKPerBlock, - gemmMPerWave, 
gemmNPerWave, gemmMnPerXdl, - gemmKPack, splitKFactor, gemmSchedule, 2, - forceUnroll, true); + InitParamsAccel gemmParams( + gemmMPerBlock, gemmNPerBlock, gemmKPerBlock, gemmMPerWave, + gemmNPerWave, gemmMnPerXdl, gemmKPack, splitKFactor, gemmSchedule, + outputSwizzle, wavesPerEU, gridGroupSize, forceUnroll, true); if (succeeded(tuningInfo->paramsProbablyValid(b, info, gemmParams))) { newSpace->tuningRange.push_back(cast( tuningInfo->getGemmParamsAttr(b, gemmParams))); @@ -870,6 +934,8 @@ static void createGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace, SmallVector nPerWaveRange = computeDPerWave(TuningParamSetKind::Greedy, winningNPerBlock, waveSize); + // hardcode to use heuristics + int64_t outputSwizzle{2}, wavesPerEU{0}, gridGroupSize{0}; for (uint32_t gemmKPerBlock : params[2]) { for (uint32_t gemmMPerWave : mPerWaveRange) { for (uint32_t gemmNPerWave : nPerWaveRange) { @@ -881,13 +947,13 @@ static void createGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace, for (int64_t splitKFactor : optimalSplitKFactors) { for (int64_t gemmSchedule : params[5]) { for (uint32_t forceUnroll : params[6]) { - // hardcode outputSwizzle to heuristics = 2 InitParamsAccel gemmParams( winningMPerBlock, winningNPerBlock, gemmKPerBlock, gemmMPerWave, gemmNPerWave, gemmMnPerXdl, gemmKPack, - splitKFactor, gemmSchedule, 2, forceUnroll, true); + splitKFactor, gemmSchedule, outputSwizzle, wavesPerEU, + gridGroupSize, forceUnroll, true); if (winningMPerBlock >= gemmMPerWave && - winningNPerBlock >= gemmMnPerXdl) { + winningNPerBlock >= gemmNPerWave) { if (succeeded(tuningInfo->paramsProbablyValid(b, info, gemmParams))) newSpace->tuningRange.push_back( @@ -904,6 +970,58 @@ static void createGemmTuningRangeGreedyPhase2(TuningParamSet *newSpace, } } +// With almost all tuning params already set, tune for fine-tuning parameters by +// brute force (greedy tuning, phase 3) +static void createGemmTuningRangeGreedyPhase3(TuningParamSet *newSpace, + RockGemmWrapperInterface gemmOp, + 
bool isSplitKFusible, + StringRef winningConfig) { + auto info = PopulateParamsInfo::fromOp(gemmOp); + OpBuilder b(gemmOp.getContext()); + GemmFeatures currentFeatures = rock::getFeatures(gemmOp); + + InitParamsAccel validParams; + auto populateParamsAccelPtr = PopulateParamsAccel::select(currentFeatures); + LogicalResult status = populateParamsAccelPtr->obtainTuningParameters( + gemmOp, winningConfig, validParams); + assert(llvm::succeeded(status) && + "Tile sizes must be extracted from winning config"); + uint32_t mPerBlock = validParams.gemmMPerBlock; + uint32_t nPerBlock = validParams.gemmNPerBlock; + uint32_t kpackPerBlock = validParams.gemmKPerBlock; + uint32_t mPerWave = validParams.gemmMPerWave; + uint32_t nPerWave = validParams.gemmNPerWave; + uint32_t mnPerXdl = validParams.gemmMnPerXdl; + uint32_t kpack = validParams.gemmKPack; + uint32_t splitKFactor = validParams.splitKFactor; + uint32_t scheduleVersion = validParams.gemmScheduleVersion; + uint32_t forceUnroll = validParams.gemmAThreadCopyMoreGemmK; + + std::unique_ptr tuningInfo; + if (bitEnumContainsAll(currentFeatures, GemmFeatures::mfma)) + tuningInfo = std::make_unique(); + else + tuningInfo = std::make_unique(); + + int64_t maxWavesPerEU = + rock::lookupArchInfo(rock::getArchValue(gemmOp)).maxWavesPerEU; + auto finetuningParams = getFinetuningParams(maxWavesPerEU); + + for (int64_t outputSwizzle : finetuningParams[0]) { + for (int64_t wavesPerEU : finetuningParams[1]) { + for (int64_t gridGroupSize : finetuningParams[2]) { + InitParamsAccel gemmParams( + mPerBlock, nPerBlock, kpackPerBlock, mPerWave, nPerWave, mnPerXdl, + kpack, splitKFactor, scheduleVersion, outputSwizzle, wavesPerEU, + gridGroupSize, forceUnroll, true); + if (succeeded(tuningInfo->paramsProbablyValid(b, info, gemmParams))) + newSpace->tuningRange.push_back(cast( + tuningInfo->getGemmParamsAttr(b, gemmParams))); + } + } + } +} + TuningParamSet * createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind, 
rock::TuningParamSpaceSettings &settings) { @@ -935,10 +1053,17 @@ createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind, createGemmTuningRangeGreedyPhase1( newSpace, op, isSplitKFusible, NUM_RANDOM_PERFCONFIGS_PER_TILE_SIZE, RND_SEED); - } else { - // Second iteration: brute force with winning tile sizes + } else if (settings.iteration == 1) { + // Second iteration: brute force (except waves_per_eu, + // output_swizzle and grid_group_size, which we hardcode to use the + // heuristic) with winning tile sizes createGemmTuningRangeGreedyPhase2(newSpace, op, isSplitKFusible, settings.winningConfig); + } else { + // Third iteration: brute force the remaining configs (waves_per_eu, + // output_swizzle and grid_group_size) + createGemmTuningRangeGreedyPhase3(newSpace, op, isSplitKFusible, + settings.winningConfig); } break; case TuningParamSetKind::Quick: @@ -962,10 +1087,18 @@ createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind, createAttnTuningRangeGreedyPhase1( newSpace, op, isSplitKFusible, NUM_RANDOM_PERFCONFIGS_PER_TILE_SIZE, RND_SEED); - } else { - // Second iteration: brute force with winning tile sizes + } else if (settings.iteration == 1) { + // Second iteration: brute force (except waves_per_eu and + // output_swizzle, which we hardcode to use the heuristic) with + // winning tile sizes + createAttnTuningRangeGreedyPhase2(newSpace, op, isSplitKFusible, settings.winningConfig); + } else { + // Third iteration: brute force the remaining configs (waves_per_eu + // and output_swizzle) + createAttnTuningRangeGreedyPhase3(newSpace, op, isSplitKFusible, + settings.winningConfig); } break; case TuningParamSetKind::Quick: diff --git a/mlir/test/Conversion/RockToGPU/waves_per_eu.mlir b/mlir/test/Conversion/RockToGPU/waves_per_eu.mlir new file mode 100644 index 000000000000..6ff4b4866997 --- /dev/null +++ b/mlir/test/Conversion/RockToGPU/waves_per_eu.mlir @@ -0,0 +1,35 @@ +// RUN: rocmlir-opt -convert-rock-to-gpu -split-input-file %s | FileCheck %s + 
+// CHECK: module attributes {gpu.container_module} +// CHECK-NEXT: gpu.module @misckernel_module +// CHECK-NEXT: gpu.func @misckernel(%{{.*}}: memref, %{{.*}}: memref) +// CHECK-SAME: workgroup(%arg2 : memref<64xf32, #gpu.address_space> {llvm.align = 64 : i64}) +// CHECK-SAME: kernel +// CHECK-SAME: arch = "amdgcn-amd-amdhsa:gfx1100" +// CHECK-SAME: block_size = 128 : i32 +// CHECK-SAME: grid_size = 256 : i32 +// CHECK-SAME: known_block_size = array +// CHECK-SAME: known_grid_size = array +// CHECK-SAME: num_cu = 96 : i64 +// CHECK-SAME: rocdl.unsafe_fp_atomics = true +// CHECK-SAME: rocdl.waves_per_eu = 8 : i32 +// CHECK-SAME: rock.shared_buffer_size = 256 : i32 +module { + func.func @misckernel(%arg0: memref, %arg1: memref) attributes {block_size = 128 : i32, enable_splitk_for_tuning, features = #rock, grid_size = 256 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100", num_cu = 96 : i64, waves_per_eu = 8 : i64} { + %lds = rock.alloc() : memref<64xf32, #gpu.address_space> + + rock.workgroup_barrier + rock.lds_barrier + + %bid = rock.workgroup_id : index + %tid = rock.workitem_id : index + %idx = arith.muli %bid, %tid : index + + %val = memref.load %arg0[%idx] : memref + %val_lds = memref.load %lds[%idx] : memref<64xf32, #gpu.address_space> + + memref.store %val, %arg1[%idx] : memref + memref.store %val_lds, %arg1[%idx] : memref + return + } +} diff --git a/mlir/test/Dialect/Rock/affix_tuning_params.mlir b/mlir/test/Dialect/Rock/affix_tuning_params.mlir index 22db85f5ebfd..470faca2a161 100644 --- a/mlir/test/Dialect/Rock/affix_tuning_params.mlir +++ b/mlir/test/Dialect/Rock/affix_tuning_params.mlir @@ -65,7 +65,7 @@ func.func @rock_conv_f16(%filter : memref<1x128x8x3x3xf16>, %input : memref<128x func.func @rock_conv_i8(%filter : memref<1x128x8x3x3xi8>, %input : memref<128x1x8x32x32xi8>, %output : memref<128x1x128x30x30xi32>) attributes {arch = "amdgcn-amd-amdhsa:gfx908"} { // CHECK: rock.conv // CHECK-SAME: derivedBlockSize = 64 - // CHECK-SAME: params = 
#rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 3600 rock.conv(%filter, %input, %output) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -84,7 +84,7 @@ func.func @rock_conv_i8(%filter : memref<1x128x8x3x3xi8>, %input : memref<128x1x func.func @rock_conv_bwd_data(%filter: memref<1x1024x1024x1x1xf32>, %input: memref<128x1x1024x14x14xf32>, %output: memref<128x1x1024x14x14xf32>) attributes {kernel = 0 : i32, arch = "amdgcn-amd-amdhsa:gfx908"} { // CHECK: rock.conv_bwd_data // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 25088 rock.conv_bwd_data(%filter, %input, %output) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -105,7 +105,7 @@ func.func @rock_conv_bwd_data(%filter: memref<1x1024x1024x1x1xf32>, %input: memr func.func @rock_conv_bwd_data_f16(%filter: memref<1x1024x1024x1x1xf16>, %input: memref<128x1x1024x14x14xf16>, %output: memref<128x1x1024x14x14xf16>) attributes {kernel = 0 : i32, arch = "amdgcn-amd-amdhsa:gfx908"} { // CHECK: rock.conv_bwd_data // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 25088 rock.conv_bwd_data(%filter, %input, %output) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -238,7 +238,7 @@ func.func @rock_conv_bwd_weight_padALL_f16(%filter : memref<1x20x8x3x3xf16>, %in func.func @rock_conv_7x7_tuning(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) attributes {arch = "amdgcn-amd-amdhsa:gfx906"} { // CHECK: rock.conv // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 12544 rock.conv(%arg0, 
%arg1, %arg2) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -260,7 +260,7 @@ func.func @rock_conv_7x7_tuning(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256 func.func @rock_conv_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) attributes {arch = "amdgcn-amd-amdhsa:gfx906"} { // CHECK: rock.conv // CHECK-SAME: derivedBlockSize = 64 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 100352 rock.conv(%arg0, %arg1, %arg2) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -279,7 +279,7 @@ func.func @rock_conv_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x23 func.func @rock_conv_bwd_weight_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) attributes {kernel = 0 : i32, arch = "amdgcn-amd-amdhsa:gfx906", numCU = 120 : i32} { // CHECK: rock.conv_bwd_weight // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 10 rock.conv_bwd_weight(%arg0, %arg1, %arg2) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -298,7 +298,7 @@ func.func @rock_conv_bwd_weight_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref func.func @rock_conv_bwd_data_7x7_tuning(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) attributes {kernel = 1 : i32, arch = "amdgcn-amd-amdhsa:gfx906"} { // CHECK: rock.conv_bwd_data // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 26450 rock.conv_bwd_data(%arg0, %arg1, %arg2) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -320,7 +320,7 @@ func.func @rock_conv_bwd_data_7x7_tuning(%arg0: memref<1x64x3x7x7xf32>, %arg1: 
m func.func @rock_conv_bwd_data_7x7(%arg0: memref<1x64x3x7x7xf32>, %arg1: memref<256x1x3x230x230xf32>, %arg2: memref<256x1x64x112x112xf32>) attributes {kernel = 1 : i32, arch = "amdgcn-amd-amdhsa:gfx908"} { // CHECK: rock.conv_bwd_data // CHECK-SAME: derivedBlockSize = 64 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 211600 rock.conv_bwd_data(%arg0, %arg1, %arg2) features = mfma|dot|atomic_add|atomic_add_f16 { @@ -353,7 +353,7 @@ func.func @rock_gemm_from_conv(%a : memref<1x72x128xf32>, %b : memref<1x72x11520 func.func @rock_gemm_from_i8_conv(%a : memref<1x72x128xi8>, %b : memref<1x72x115200xi8>, %c : memref<1x128x115200xi32>) attributes {arch = "amdgcn-amd-amdhsa:gfx908", numCU = 120 : i32} { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 7200 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add|atomic_add_f16 storeMethod = set @@ -366,7 +366,7 @@ func.func @rock_gemm_from_i8_conv(%a : memref<1x72x128xi8>, %b : memref<1x72x115 func.func @rock_gemm_from_i8_conv_schedule_v2(%a : memref<1x72x128xi8>, %b : memref<1x72x115200xi8>, %c : memref<1x128x115200xi32>) attributes {schedule_version = #rock.schedule_version<2>, arch = "amdgcn-amd-amdhsa:gfx908", numCU = 120 : i32} { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 7200 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add|atomic_add_f16 storeMethod = set @@ -382,7 +382,7 @@ func.func @rock_gemm_from_i8_conv_schedule_v2(%a : memref<1x72x128xi8>, %b : mem func.func @rock_gemm_from_i8_conv_gfx942(%a : memref<1x72x128xi8>, %b : memref<1x72x115200xi8>, %c : memref<1x128x115200xi32>) 
attributes {arch = "amdgcn-amd-amdhsa:gfx942", numCU = 120 : i32} { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 14400 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add|atomic_add_f16 storeMethod = set @@ -396,7 +396,7 @@ func.func @rock_gemm_from_i8_conv_gfx942(%a : memref<1x72x128xi8>, %b : memref<1 func.func @rock_gemm_xdlops_fp8_bf8(%a : memref<1x72x128xf8E4M3FNUZ>, %b : memref<1x72x115200xf8E5M2FNUZ>, %c : memref<1x128x115200xf32>) attributes {arch = "amdgcn-amd-amdhsa:gfx942", numCU = 120 : i32} { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 1800 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add|atomic_add_f16 storeMethod = set @@ -410,7 +410,7 @@ func.func @rock_gemm_xdlops_fp8_bf8(%a : memref<1x72x128xf8E4M3FNUZ>, %b : memre func.func @rock_gemm_xdlops_fp8_bf8_ocp(%a : memref<1x72x128xf8E4M3FN>, %b : memref<1x72x115200xf8E5M2>, %c : memref<1x128x115200xf32>) attributes {arch = "amdgcn-amd-amdhsa:gfx950", numCU = 120 : i32} { // CHECK: rock.gemm // CHECK-SAME: derivedBlockSize = 256 - // CHECK-SAME: params = #rock.mfma_gemm_params + // CHECK-SAME: params = #rock.mfma_gemm_params // GRID: rock.gridwise_gemm // GRID-SAME: gridSize = 1800 rock.gemm %c = tr %a * %b features = mfma|dot|atomic_add|atomic_add_f16|atomic_add_bf16 storeMethod = set @@ -424,7 +424,7 @@ func.func @rock_gemm_xdlops_fp8_bf8_ocp(%a : memref<1x72x128xf8E4M3FN>, %b : mem // GRID-SAME: grid_size = 12 func.func @rock_attention_default(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.attention - // CHECK: 
#rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params rock.attention{ qk = %arg0 * tr %arg1 : memref<1x384x64xf16>, memref<1x384x64xf16> %arg3 = softmax(qk) * %arg2 : memref<1x384x64xf16> -> memref<1x384x64xf16> @@ -439,8 +439,8 @@ func.func @rock_attention_default(%arg0: memref<1x384x64xf16>, %arg1: memref<1x3 func.func @rock_attention_large(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x512x16384xf32>, %arg2: memref<1x16384x512xf32>, %arg3: memref<1x16384x512xf32>) attributes {arch = "gfx942:sramecc+:xnack-"} { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x16384x512xf32> // CHECK: rock.attention - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK: params1 = #rock.mfma_gemm_params rock.attention{ qk = %arg0 * %arg1 : memref<1x16384x512xf32>, memref<1x512x16384xf32> %arg3 = softmax(qk) * %arg2 : memref<1x16384x512xf32> -> memref<1x16384x512xf32> @@ -454,8 +454,8 @@ func.func @rock_attention_large(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x // GRID-SAME: grid_size = 3 func.func @rock_attention_mperblockg1_wmma(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.attention - // CHECK: #rock.wmma_gemm_params - // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params rock.attention{ qk = %arg0 * tr %arg1 : memref<1x384x64xf16>, memref<1x384x64xf16> %arg3 = softmax(qk) * %arg2 : memref<1x384x64xf16> -> memref<1x384x64xf16> @@ -469,8 +469,8 @@ func.func @rock_attention_mperblockg1_wmma(%arg0: memref<1x384x64xf16>, %arg1: m // GRID-SAME: grid_size = 3 func.func @rock_attention_mperblockg1_mfma(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {kernel, mhal.arch = "gfx942:sramecc+:xnack-"} { // 
CHECK: rock.attention - // CHECK: #rock.mfma_gemm_params - // CHECK: #rock.mfma_gemm_params + // CHECK: #rock.mfma_gemm_params + // CHECK: #rock.mfma_gemm_params rock.attention{ qk = %arg0 * tr %arg1 : memref<1x384x64xf16>, memref<1x384x64xf16> %arg3 = softmax(qk) * %arg2 : memref<1x384x64xf16> -> memref<1x384x64xf16> @@ -484,7 +484,7 @@ func.func @rock_attention_mperblockg1_mfma(%arg0: memref<1x384x64xf16>, %arg1: m // GRID-SAME: grid_size = 12 func.func @rock_gemm_gemm_default(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.gemm_elementwise_gemm - // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params rock.gemm_elementwise_gemm{ ab = %arg0 * tr %arg1 : memref<1x384x64xf16>, memref<1x384x64xf16> %arg3 = ab * %arg2 : memref<1x384x64xf16> -> memref<1x384x64xf16> @@ -499,8 +499,8 @@ func.func @rock_gemm_gemm_default(%arg0: memref<1x384x64xf16>, %arg1: memref<1x3 func.func @rock_gemm_gemm_v1(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x512x16384xf32>, %arg2: memref<1x16384x512xf32>, %arg3: memref<1x16384x512xf32>) attributes {arch = "gfx942:sramecc+:xnack-"} { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x16384x512xf32> // CHECK: rock.gemm_elementwise_gemm - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK: params1 = #rock.mfma_gemm_params rock.gemm_elementwise_gemm{ ab = %arg0 * %arg1 : memref<1x16384x512xf32>, memref<1x512x16384xf32> %arg3 = ab * %arg2 : memref<1x16384x512xf32> -> memref<1x16384x512xf32> @@ -515,8 +515,8 @@ func.func @rock_gemm_gemm_v1(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x512 func.func @rock_gemm_gemm_large(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x512x16384xf32>, %arg2: memref<1x16384x512xf32>, %arg3: memref<1x16384x512xf32>) attributes {arch = 
"gfx942:sramecc+:xnack-"} { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x16384x512xf32> // CHECK: rock.gemm_elementwise_gemm - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK: params1 = #rock.mfma_gemm_params rock.gemm_elementwise_gemm{ ab = %arg0 * %arg1 : memref<1x16384x512xf32>, memref<1x512x16384xf32> %arg3 = ab * %arg2 : memref<1x16384x512xf32> -> memref<1x16384x512xf32> @@ -530,8 +530,8 @@ func.func @rock_gemm_gemm_large(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x // GRID-SAME: grid_size = 3 func.func @rock_gemm_gemm_mperblockg1_wmma(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.gemm_elementwise_gemm - // CHECK: #rock.wmma_gemm_params - // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params rock.gemm_elementwise_gemm{ ab = %arg0 * tr %arg1 : memref<1x384x64xf16>, memref<1x384x64xf16> %arg3 = ab * %arg2 : memref<1x384x64xf16> -> memref<1x384x64xf16> @@ -545,8 +545,8 @@ func.func @rock_gemm_gemm_mperblockg1_wmma(%arg0: memref<1x384x64xf16>, %arg1: m // GRID-SAME: grid_size = 3 func.func @rock_gemm_gemm_mperblockg1_mfma(%arg0: memref<1x384x64xf32>, %arg1: memref<1x384x64xf32>, %arg2: memref<1x384x64xf32>, %arg3: memref<1x384x64xf32>) attributes {kernel, mhal.arch = "gfx942:sramecc+:xnack-"} { // CHECK: rock.gemm_elementwise_gemm - // CHECK: #rock.mfma_gemm_params - // CHECK: #rock.mfma_gemm_params + // CHECK: #rock.mfma_gemm_params + // CHECK: #rock.mfma_gemm_params rock.gemm_elementwise_gemm{ ab = %arg0 * tr %arg1 : memref<1x384x64xf32>, memref<1x384x64xf32> %arg3 = ab * %arg2 : memref<1x384x64xf32> -> memref<1x384x64xf32> @@ -560,7 +560,7 @@ func.func @rock_gemm_gemm_mperblockg1_mfma(%arg0: memref<1x384x64xf32>, %arg1: m // GRID-SAME: grid_size = 64 
func.func @rock_conv_gemm_default(%arg0: memref<1x128x256x1x1xf16>, %arg1: memref<2x1x256x32x32xf16>, %arg2: memref<1x128x64xf16>, %arg3: memref<1x2048x64xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.conv_elementwise_gemm - // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params rock.conv_elementwise_gemm{ ab = conv(%arg0, %arg1) : memref<1x128x256x1x1xf16>, memref<2x1x256x32x32xf16> %arg3 = ab * %arg2 : memref<1x128x64xf16> -> memref<1x2048x64xf16> @@ -574,8 +574,8 @@ func.func @rock_conv_gemm_default(%arg0: memref<1x128x256x1x1xf16>, %arg1: memre // GRID-SAME: grid_size = 2048 func.func @rock_conv_gemm_splitk(%arg0: memref<1x128x256x3x3xf32>, %arg1: memref<2x1x256x128x128xf32>, %arg2: memref<1x128x128xf32>, %arg3: memref<1x32768x128xf32>) attributes {arch = "amdgcn-amd-amdhsa:gfx942:sramecc+:xnack-"} { // CHECK: rock.conv_elementwise_gemm - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK: params1 = #rock.mfma_gemm_params rock.conv_elementwise_gemm{ ab = conv(%arg0, %arg1) : memref<1x128x256x3x3xf32>, memref<2x1x256x128x128xf32> %arg3 = ab * %arg2 : memref<1x128x128xf32> -> memref<1x32768x128xf32> @@ -589,8 +589,8 @@ func.func @rock_conv_gemm_splitk(%arg0: memref<1x128x256x3x3xf32>, %arg1: memref // GRID-SAME: grid_size = 256 func.func @rock_conv_gemm_large(%arg0: memref<1x128x256x3x3xf32>, %arg1: memref<2x1x256x128x128xf32>, %arg2: memref<1x128x128xf32>, %arg3: memref<1x32768x128xf32>) attributes {arch = "amdgcn-amd-amdhsa:gfx942:sramecc+:xnack-"} { // CHECK: rock.conv_elementwise_gemm - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK: params1 = #rock.mfma_gemm_params rock.conv_elementwise_gemm{ ab = conv(%arg0, %arg1) : memref<1x128x256x3x3xf32>, memref<2x1x256x128x128xf32> %arg3 = ab * %arg2 : memref<1x128x128xf32> -> 
memref<1x32768x128xf32> @@ -604,8 +604,8 @@ func.func @rock_conv_gemm_large(%arg0: memref<1x128x256x3x3xf32>, %arg1: memref< // GRID-SAME: grid_size = 256 func.func @rock_conv_gemm_mperblockg1_wmma(%arg0: memref<1x128x256x1x1xf16>, %arg1: memref<2x1x256x128x128xf16>, %arg2: memref<1x128x128xf16>, %arg3: memref<1x32768x128xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.conv_elementwise_gemm - // CHECK: #rock.wmma_gemm_params - // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params + // CHECK: #rock.wmma_gemm_params rock.conv_elementwise_gemm{ ab = conv(%arg0, %arg1) : memref<1x128x256x1x1xf16>, memref<2x1x256x128x128xf16> @@ -620,8 +620,8 @@ func.func @rock_conv_gemm_mperblockg1_wmma(%arg0: memref<1x128x256x1x1xf16>, %ar // GRID-SAME: grid_size = 256 func.func @rock_conv_gemm_mperblockg1_mfma(%arg0: memref<1x128x256x1x1xf32>, %arg1: memref<2x1x256x128x128xf32>, %arg2: memref<1x128x128xf32>, %arg3: memref<1x32768x128xf32>) attributes {kernel, mhal.arch = "gfx942:sramecc+:xnack-"} { // CHECK: rock.conv_elementwise_gemm - // CHECK: #rock.mfma_gemm_params - // CHECK: #rock.mfma_gemm_params + // CHECK: #rock.mfma_gemm_params + // CHECK: #rock.mfma_gemm_params rock.conv_elementwise_gemm{ ab = conv(%arg0, %arg1) : memref<1x128x256x1x1xf32>, memref<2x1x256x128x128xf32> %arg3 = ab * %arg2 : memref<1x128x128xf32> -> memref<1x32768x128xf32> @@ -649,8 +649,8 @@ func.func @rock_conv_tuning(%arg0: memref<1x1x1x3x3xf32>, %arg1: memref<64x1x1x1 // GRID-LABEL: @rock_attn_schedulev2 func.func @rock_attn_schedulev2(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {schedule_version = #rock.schedule_version<2>, arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.attention - // CHECK: params0 = #rock.wmma_gemm_params - // CHECK-SAME: params1 = #rock.wmma_gemm_params + // CHECK: params0 = #rock.wmma_gemm_params + // CHECK-SAME: params1 = #rock.wmma_gemm_params 
// GRID: rock.gridwise_attention_accel // GRID: gridSize = 12 rock.attention{ @@ -664,8 +664,8 @@ func.func @rock_attn_schedulev2(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384 // GRID-LABEL: @rock_attn_schedulev3 func.func @rock_attn_schedulev3(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {schedule_version = #rock.schedule_version<3>, arch = "amdgcn-amd-amdhsa:gfx950:sramecc+:xnack-"} { // CHECK: rock.attention - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK-SAME: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK-SAME: params1 = #rock.mfma_gemm_params // GRID: rock.gridwise_attention_accel // GRID: gridSize = 12 rock.attention{ @@ -679,8 +679,8 @@ func.func @rock_attn_schedulev3(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384 // GRID-LABEL: @rock_attn_schedulev4 func.func @rock_attn_schedulev4(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {schedule_version = #rock.schedule_version<4>, arch = "amdgcn-amd-amdhsa:gfx950:sramecc+:xnack-"} { // CHECK: rock.attention - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK-SAME: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK-SAME: params1 = #rock.mfma_gemm_params // GRID: rock.gridwise_attention_accel // GRID: gridSize = 12 rock.attention{ @@ -694,8 +694,8 @@ func.func @rock_attn_schedulev4(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384 // GRID-LABEL: @rock_attn_perfconfig_schedulev2 func.func @rock_attn_perfconfig_schedulev2(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.attention - // CHECK: params0 = #rock.wmma_gemm_params - // CHECK-SAME: params1 = #rock.wmma_gemm_params + // CHECK: params0 = #rock.wmma_gemm_params + // CHECK-SAME: 
params1 = #rock.wmma_gemm_params // GRID: rock.gridwise_attention_accel // GRID: gridSize = 12 rock.attention{ @@ -709,8 +709,8 @@ func.func @rock_attn_perfconfig_schedulev2(%arg0: memref<1x384x64xf16>, %arg1: m // GRID-LABEL: @rock_attn_perfconfig_schedulev3 func.func @rock_attn_perfconfig_schedulev3(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {arch = "amdgcn-amd-amdhsa:gfx950:sramecc+:xnack-"} { // CHECK: rock.attention - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK-SAME: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK-SAME: params1 = #rock.mfma_gemm_params // GRID: rock.gridwise_attention_accel // GRID: gridSize = 12 rock.attention{ @@ -724,8 +724,8 @@ func.func @rock_attn_perfconfig_schedulev3(%arg0: memref<1x384x64xf16>, %arg1: m // GRID-LABEL: @rock_attn_perfconfig_schedulev4 func.func @rock_attn_perfconfig_schedulev4(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {arch = "amdgcn-amd-amdhsa:gfx950:sramecc+:xnack-"} { // CHECK: rock.attention - // CHECK: params0 = #rock.mfma_gemm_params - // CHECK-SAME: params1 = #rock.mfma_gemm_params + // CHECK: params0 = #rock.mfma_gemm_params + // CHECK-SAME: params1 = #rock.mfma_gemm_params // GRID: rock.gridwise_attention_accel // GRID: gridSize = 12 rock.attention{ @@ -739,8 +739,8 @@ func.func @rock_attn_perfconfig_schedulev4(%arg0: memref<1x384x64xf16>, %arg1: m // GRID-LABEL: @rock_attn_schedule_default func.func @rock_attn_schedule_default(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.attention - // CHECK: params0 = #rock.wmma_gemm_params - // CHECK-SAME: params1 = #rock.wmma_gemm_params + // CHECK: params0 = #rock.wmma_gemm_params + // CHECK-SAME: params1 = 
#rock.wmma_gemm_params // GRID: rock.gridwise_attention_accel // GRID: gridSize = 12 rock.attention{ diff --git a/mlir/test/Dialect/Rock/async_wait_add.mlir b/mlir/test/Dialect/Rock/async_wait_add.mlir index ce7cbad8b0ec..46e701e26d90 100644 --- a/mlir/test/Dialect/Rock/async_wait_add.mlir +++ b/mlir/test/Dialect/Rock/async_wait_add.mlir @@ -121,7 +121,7 @@ func.func @gemm_pipelining(%arg0: memref<2359296xbf16>, %arg1: memref<2359296xbf %115 = rock.transform %subview_19 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 4] -> [4]> : memref<4xvector<8xbf16>, strided<[1], offset: ?>, #gpu.address_space> to memref<1x4xvector<8xbf16>, #gpu.address_space> affine.for %arg6 = 0 to 4 { %116 = rock.transform %27 by (d0 + d1)> by [ ["offset"] at [0]>] bounds = [1, 1] -> [1]> : memref<1xvector<16xf32>, #gpu.address_space> to memref<1x1xvector<16xf32>, #gpu.address_space> - rock.threadwise_gemm_accel %116 += %112 * %115 at[%arg4, %arg5, %arg6] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> + rock.threadwise_gemm_accel %116 += %112 * %115 at[%arg4, %arg5, %arg6] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> } } } @@ -169,7 +169,7 @@ func.func @gemm_pipelining(%arg0: memref<2359296xbf16>, %arg1: memref<2359296xbf %83 = rock.transform %subview_15 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 4] -> [4]> : memref<4xvector<8xbf16>, strided<[1], offset: ?>, #gpu.address_space> to memref<1x4xvector<8xbf16>, #gpu.address_space> affine.for %arg5 = 0 to 4 { %84 = rock.transform %27 by (d0 + d1)> by [ ["offset"] at 
[0]>] bounds = [1, 1] -> [1]> : memref<1xvector<16xf32>, #gpu.address_space> to memref<1x1xvector<16xf32>, #gpu.address_space> - rock.threadwise_gemm_accel %84 += %80 * %83 at[%arg3, %arg4, %arg5] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> + rock.threadwise_gemm_accel %84 += %80 * %83 at[%arg3, %arg4, %arg5] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> } } } @@ -211,7 +211,7 @@ func.func @gemm_pipelining(%arg0: memref<2359296xbf16>, %arg1: memref<2359296xbf %83 = rock.transform %subview_15 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 4] -> [4]> : memref<4xvector<8xbf16>, strided<[1], offset: ?>, #gpu.address_space> to memref<1x4xvector<8xbf16>, #gpu.address_space> affine.for %arg5 = 0 to 4 { %84 = rock.transform %27 by (d0 + d1)> by [ ["offset"] at [0]>] bounds = [1, 1] -> [1]> : memref<1xvector<16xf32>, #gpu.address_space> to memref<1x1xvector<16xf32>, #gpu.address_space> - rock.threadwise_gemm_accel %84 += %80 * %83 at[%arg3, %arg4, %arg5] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> + rock.threadwise_gemm_accel %84 += %80 * %83 at[%arg3, %arg4, %arg5] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += 
memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> } } } @@ -306,7 +306,7 @@ func.func @gemm_no_pipelining(%arg0: memref<2359296xbf16>, %arg1: memref<2359296 %51 = rock.transform %50 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 4] -> [4]> : memref<4xvector<8xbf16>, #gpu.address_space> to memref<1x4xvector<8xbf16>, #gpu.address_space> %52 = rock.extract_multibuffer(%view_5) [%arg5](memref<4xvector<8xbf16>, #gpu.address_space>) : memref<4xvector<8xbf16>, #gpu.address_space> %53 = rock.transform %52 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 4] -> [4]> : memref<4xvector<8xbf16>, #gpu.address_space> to memref<1x4xvector<8xbf16>, #gpu.address_space> - rock.threadwise_gemm_accel %49 += %51 * %53 at[%arg4, %arg5, %arg6] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> + rock.threadwise_gemm_accel %49 += %51 * %53 at[%arg4, %arg5, %arg6] features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xbf16>, #gpu.address_space> * memref<1x4xvector<8xbf16>, #gpu.address_space> } } } diff --git a/mlir/test/Dialect/Rock/conv_to_gemm.mlir b/mlir/test/Dialect/Rock/conv_to_gemm.mlir index b8627f8ee404..d32d46997b9d 100644 --- a/mlir/test/Dialect/Rock/conv_to_gemm.mlir +++ b/mlir/test/Dialect/Rock/conv_to_gemm.mlir @@ -23,7 +23,7 @@ func.func @nhwc_1x1(%arg0: memref<16384xf16>, %arg1: memref<802816xf16>, %arg2: %0 = rock.transform %arg0 by ((d0 * 256 + d1 + d2 + d3) * 64 + d4)> by [ ["raw"] at [0]>] bounds = [1, 256, 1, 1, 64] -> [16384]> : memref<16384xf16> to memref<1x256x1x1x64xf16> %1 = rock.transform %arg1 by (((d0 * 14 + d1) * 14 + d2 + d3) * 64 + 
d4)> by [ ["raw"] at [0]>] bounds = [64, 14, 14, 1, 64] -> [802816]> : memref<802816xf16> to memref<64x14x14x1x64xf16> %2 = rock.transform %arg2 by (((d0 * 14 + d1) * 14 + d2 + d3) * 256 + d4)> by [ ["raw"] at [0]>] bounds = [64, 14, 14, 1, 256] -> [3211264]> : memref<3211264xf16> to memref<64x14x14x1x256xf16> - rock.conv(%0, %1, %2) {derivedBlockSize = 128 : i32, dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "0", "1", "c"], input_layout = ["ni", "0i", "1i", "gi", "ci"], output_layout = ["no", "0o", "1o", "go", "ko"], padding = [0 : index, 0 : index, 0 : index, 0 : index], params = #rock.wmma_gemm_params, strides = [1 : index, 1 : index]} : memref<1x256x1x1x64xf16>, memref<64x14x14x1x64xf16>, memref<64x14x14x1x256xf16> + rock.conv(%0, %1, %2) {derivedBlockSize = 128 : i32, dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "0", "1", "c"], input_layout = ["ni", "0i", "1i", "gi", "ci"], output_layout = ["no", "0o", "1o", "go", "ko"], padding = [0 : index, 0 : index, 0 : index, 0 : index], params = #rock.wmma_gemm_params, strides = [1 : index, 1 : index]} : memref<1x256x1x1x64xf16>, memref<64x14x14x1x64xf16>, memref<64x14x14x1x256xf16> return } @@ -34,7 +34,7 @@ func.func @nhwc_1x1_stride_2(%arg0: memref<16384xf16>, %arg1: memref<802816xf16> %0 = rock.transform %arg0 by ((d0 * 256 + d1 + d2 + d3) * 64 + d4)> by [ ["raw"] at [0]>] bounds = [1, 256, 1, 1, 64] -> [16384]> : memref<16384xf16> to memref<1x256x1x1x64xf16> %1 = rock.transform %arg1 by (((d0 * 14 + d1) * 14 + d2 + d3) * 64 + d4)> by [ ["raw"] at [0]>] bounds = [64, 14, 14, 1, 64] -> [802816]> : memref<802816xf16> to memref<64x14x14x1x64xf16> %2 = rock.transform %arg2 by (((d0 * 7 + d1) * 7 + d2 + d3) * 256 + d4)> by [ ["raw"] at [0]>] bounds = [64, 7, 7, 1, 256] -> [802816]> : memref<802816xf16> to memref<64x7x7x1x256xf16> - rock.conv(%0, %1, %2) {derivedBlockSize = 128 : i32, dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "0", "1", "c"], input_layout = ["ni", 
"0i", "1i", "gi", "ci"], numCU = 96 : i32, output_layout = ["no", "0o", "1o", "go", "ko"], padding = [0 : index, 0 : index, 0 : index, 0 : index], params = #rock.wmma_gemm_params, strides = [2 : index, 2 : index]} : memref<1x256x1x1x64xf16>, memref<64x14x14x1x64xf16>, memref<64x7x7x1x256xf16> + rock.conv(%0, %1, %2) {derivedBlockSize = 128 : i32, dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "0", "1", "c"], input_layout = ["ni", "0i", "1i", "gi", "ci"], numCU = 96 : i32, output_layout = ["no", "0o", "1o", "go", "ko"], padding = [0 : index, 0 : index, 0 : index, 0 : index], params = #rock.wmma_gemm_params, strides = [2 : index, 2 : index]} : memref<1x256x1x1x64xf16>, memref<64x14x14x1x64xf16>, memref<64x7x7x1x256xf16> return } @@ -45,7 +45,7 @@ func.func @nhwc_3x3(%arg0: memref<147456xf16>, %arg1: memref<802816xf16>, %arg2: %0 = rock.transform %arg0 by ((((d0 * 256 + d1) * 3 + d2) * 3 + d3) * 64 + d4)> by [ ["raw"] at [0]>] bounds = [1, 256, 3, 3, 64] -> [147456]> : memref<147456xf16> to memref<1x256x3x3x64xf16> %1 = rock.transform %arg1 by (((d0 * 14 + d1) * 14 + d2 + d3) * 64 + d4)> by [ ["raw"] at [0]>] bounds = [64, 14, 14, 1, 64] -> [802816]> : memref<802816xf16> to memref<64x14x14x1x64xf16> %2 = rock.transform %arg2 by (((d0 * 12 + d1) * 12 + d2 + d3) * 256 + d4)> by [ ["raw"] at [0]>] bounds = [64, 12, 12, 1, 256] -> [2359296]> : memref<2359296xf16> to memref<64x12x12x1x256xf16> - rock.conv(%0, %1, %2) {derivedBlockSize = 128 : i32, dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "0", "1", "c"], input_layout = ["ni", "0i", "1i", "gi", "ci"], output_layout = ["no", "0o", "1o", "go", "ko"], padding = [0 : index, 0 : index, 0 : index, 0 : index], params = #rock.wmma_gemm_params, strides = [1 : index, 1 : index]} : memref<1x256x3x3x64xf16>, memref<64x14x14x1x64xf16>, memref<64x12x12x1x256xf16> + rock.conv(%0, %1, %2) {derivedBlockSize = 128 : i32, dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "0", "1", "c"], 
input_layout = ["ni", "0i", "1i", "gi", "ci"], output_layout = ["no", "0o", "1o", "go", "ko"], padding = [0 : index, 0 : index, 0 : index, 0 : index], params = #rock.wmma_gemm_params, strides = [1 : index, 1 : index]} : memref<1x256x3x3x64xf16>, memref<64x14x14x1x64xf16>, memref<64x12x12x1x256xf16> return } diff --git a/mlir/test/Dialect/Rock/effects.mlir b/mlir/test/Dialect/Rock/effects.mlir index 70d261be4d06..1cba9a2a2903 100644 --- a/mlir/test/Dialect/Rock/effects.mlir +++ b/mlir/test/Dialect/Rock/effects.mlir @@ -128,7 +128,9 @@ func.func @rock_gridwise_gemm_accel(%A : memref<2x1024x1024xf32>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<2x1024x1024xf32>, memref<2x1024x2048xf32>, memref<2x1024x2048xf32> return @@ -154,7 +156,9 @@ func.func @rock_gridwise_scaled_gemm_accel(%A : memref<2x1024x1024xf4E2M1FN>, %B mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<2x1024x1024xf4E2M1FN>, memref<2x1024x2048xf4E2M1FN>, memref<2x1024x2048xf32>, memref<2x1024x1024xf8E8M0FNU>, memref<2x1024x2048xf8E8M0FNU> return @@ -281,7 +285,9 @@ func.func @rock_blockwise_gemm_accel(%bufferA : memref<4xvector<8xf8E4M3FN>, #gp mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #gpu.address_space> += memref<4xvector<8xf8E4M3FN>, #gpu.address_space> * memref<4xvector<8xf8E5M2>, #gpu.address_space> return @@ -315,7 +321,9 @@ func.func @rock_blockwise_gemm_accel_lds(%matrixA : memref<1024xvector<8xf8E4M3F mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #gpu.address_space> += 
memref<4xvector<8xf8E4M3FN>, #gpu.address_space> from memref<1024xvector<8xf8E4M3FN>, #gpu.address_space> * memref<4xvector<8xf8E5M2>, #gpu.address_space> from memref<1024xvector<8xf8E5M2>, #gpu.address_space> return @@ -360,7 +368,9 @@ func.func @rock_blockwise_gemm_accel_two_results(%matrixA : memref<256xvector<2x mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #gpu.address_space> += memref<4xf4E2M1FN, #gpu.address_space> from memref<256xvector<2xf4E2M1FN>, #gpu.address_space> scaled by memref<4xf8E8M0FNU, #gpu.address_space> from memref<256xvector<2xf8E8M0FNU>, #gpu.address_space> * memref<4xf4E2M1FN, #gpu.address_space> from memref<256xvector<2xf4E2M1FN>, #gpu.address_space> scaled by memref<4xf8E8M0FNU, #gpu.address_space> from memref<256xvector<2xf8E8M0FNU>, #gpu.address_space> return @@ -398,7 +408,9 @@ func.func @rock_threadwise_gemm_accel(%matrixA : memref<1x16xf32, 5>, kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<32xf32>, 5> += memref<1x16xf32, 5> * memref<1x16xf32, 5> return @@ -427,7 +439,9 @@ func.func @rock_threadwise_gemm_accel_scaled(%matrixA : memref<1x4xvector<4xf4E2 kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<32xf32>, 5> += memref<1x4xvector<4xf4E2M1FN>, 5> scaled by memref<1x4xvector<4xf8E8M0FNU>, 5> * memref<1x4xvector<4xf4E2M1FN>, 5> scaled by memref<1x4xvector<4xf8E8M0FNU>, 5> return @@ -447,8 +461,8 @@ func.func @rock_gridwise_attn(%arg0: memref<1x384x64xf32>, rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg3) preSoftmaxOps = {} { blockSize = 64 : i32, gridSize = 24 : i32, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + 
params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, operand_segment_sizes = array, splitKV = 1 : i32, @@ -466,8 +480,8 @@ func.func @rock_reduce(%arg0: memref<2x12x12xf32>, func.return } -#xldops_attn_params_g0 = #rock.mfma_gemm_params -#xldops_attn_params_g1 = #rock.mfma_gemm_params +#xldops_attn_params_g0 = #rock.mfma_gemm_params +#xldops_attn_params_g1 = #rock.mfma_gemm_params func.func @rock_gemmelementwisegemm_simple(%arg0: memref<1x64x1024xf32>, %arg1: memref<1x64x1024xf32>, %arg2: memref<1x1024x64xf32>, @@ -540,7 +554,7 @@ func.func @loadtile_doublebuffer(%arg0: memref<1x384x64xf32>, %lds: memref<4096x // expected-remark @below {{found an instance of 'write' on op operand 1, on resource ''}} // expected-remark @below {{found an instance of 'read' on op operand 1, on resource ''}} // expected-remark @below {{found an instance of 'write' on op operand 2, on resource ''}} - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> -> memref<16xf32, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> -> memref<16xf32, #gpu.address_space> } return } @@ -554,7 +568,7 @@ func.func @loadtile_default(%arg0: memref<1x384x64xf32>, %lds: memref<4096xi8, # affine.for %arg1 = 0 to 2 { // expected-remark @below {{found an instance of 'read' on op operand 0, on resource ''}} // expected-remark @below {{found an instance of 
'write' on op operand 1, on resource ''}} - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> } return } @@ -568,7 +582,7 @@ func.func @loadtile_bypasslds(%arg0: memref<1x384x64xf32>, %reg: memref<16xf32, affine.for %arg1 = 0 to 2 { // expected-remark @below {{found an instance of 'read' on op operand 0, on resource ''}} // expected-remark @below {{found an instance of 'write' on op operand 1, on resource ''}} - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> -> memref<16xf32, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> -> memref<16xf32, #gpu.address_space> } return } @@ -584,7 +598,7 @@ func.func @loadtile_doublebuffer_directtolds(%arg0: memref<1x384x64xf32>, %lds: // expected-remark @below {{found an instance of 'write' on op operand 1, on resource ''}} // expected-remark @below {{found an instance of 'read' on op operand 1, 
on resource ''}} // expected-remark @below {{found an instance of 'write' on op operand 2, on resource ''}} - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> -> memref<16xf32, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> -> memref<16xf32, #gpu.address_space> } return } @@ -598,7 +612,7 @@ func.func @loadtile_default_directtolds(%arg0: memref<1x384x64xf32>, %lds: memre affine.for %arg1 = 0 to 2 { // expected-remark @below {{found an instance of 'read' on op operand 0, on resource ''}} // expected-remark @below {{found an instance of 'write' on op operand 1, on resource ''}} - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> } return } diff --git a/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir 
b/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir index 524669588f7b..1e91b3f727e9 100644 --- a/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir +++ b/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir @@ -6,12 +6,12 @@ #general_gemm_params0 = #rock.general_gemm_params #general_gemm_params_splitk = #rock.general_gemm_params #general_gemm_params1 = #rock.general_gemm_params -#xdlops_gemm_params0 = #rock.mfma_gemm_params -#xdlops_gemm_params1 = #rock.mfma_gemm_params -#xdlops_gemm_params3 = #rock.mfma_gemm_params -#xldops_attn_params_g0 = #rock.mfma_gemm_params -#xldops_attn_params_g1 = #rock.mfma_gemm_params -#xldops_attn_params_g1_splitk = #rock.mfma_gemm_params +#xdlops_gemm_params0 = #rock.mfma_gemm_params +#xdlops_gemm_params1 = #rock.mfma_gemm_params +#xdlops_gemm_params3 = #rock.mfma_gemm_params +#xldops_attn_params_g0 = #rock.mfma_gemm_params +#xldops_attn_params_g1 = #rock.mfma_gemm_params +#xldops_attn_params_g1_splitk = #rock.mfma_gemm_params // CHECK-LABEL: func.func @gemm_easy_case_from_conv // CHECK-SAME: (%[[a:.*]]: memref<1x72x128xf32>, %[[b:.*]]: memref<1x72x512xf32>, %[[c:.*]]: memref<1x128x512xf32>) @@ -505,7 +505,7 @@ func.func @rock_gemmelementwisegemm_splitk_two_outputs(%arg0: memref<4096xf32>, rock.yield } %alloc = ab * %0 : memref<1x64x64xf32> -> memref<1x64x64xf32> - } {firstGemmIndices = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, perf_config = "attn:v2:64,128,32,16,32,16,4,4,1,2,1", storeMethod = #rock} + } {firstGemmIndices = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, perf_config = "attn:v2:64,128,32,16,32,16,4,4,1,2,1", storeMethod = #rock} %3 = rock.transform %alloc by (0, d0 floordiv 64, d0 mod 64)> by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [4096] -> [1, 64, 64]> : memref<1x64x64xf32> to memref<4096xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x64x1xf32> @@ -552,7 +552,7 @@ func.func @rock_attention_gqa(%arg0: memref<64x1x128xf16>, %arg1: memref<8x128x8 
rock.yield } %arg4 = softmax(qk) * %arg2 : memref<8x8192x128xf16> -> memref<256x1x128xf16> - } {features = #rock, firstGemmIndices = array, numHeadsKV = 8 : i32, numHeadsQ = 64 : i32, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params, softmaxType = f32, splitKV = 4 : i32, storeMethod = #rock} + } {features = #rock, firstGemmIndices = array, numHeadsKV = 8 : i32, numHeadsQ = 64 : i32, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params, softmaxType = f32, splitKV = 4 : i32, storeMethod = #rock} return } @@ -659,7 +659,7 @@ func.func @gemm_scaled_fp4_splitk(%a: memref<1x72x128xf4E2M1FN>, %b: memref<1x72 rock.gemm %c = tr %a scaled by %scaleA * %b scaled by %scaleB features = mfma storeMethod = set { derivedBlockSize = 256 : i32, gridSize = 16 : i32, - params = #rock.mfma_gemm_params + params = #rock.mfma_gemm_params } : memref<1x128x512xf32> = memref<1x72x128xf4E2M1FN> scaled by memref<1x128x72xf8E8M0FNU> * memref<1x72x512xf4E2M1FN> scaled by memref<1x72x512xf8E8M0FNU> func.return } @@ -727,7 +727,7 @@ func.func @gemm_scaled_fp4_splitk_odd(%arg0: memref<589824xf4E2M1FN>, %arg1: mem rock.gemm %2 = %0 scaled by %7 * %1 scaled by %10 features = mfma storeMethod = set { derivedBlockSize = 256 : i32, gridSize = 12 : i32, - params = #rock.mfma_gemm_params + params = #rock.mfma_gemm_params } : memref<3x256x256xf32> = memref<3x256x768xf4E2M1FN> scaled by memref<3x256x768xf8E8M0FNU> * memref<3x768x256xf4E2M1FN> scaled by memref<3x768x256xf8E8M0FNU> func.return } diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index d6f532885751..f06c3c572b0c 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -1,6 +1,5 @@ // RUN: rocmlir-opt -split-input-file -rock-gridwise-gemm-to-blockwise -rock-blockwise-load-tile-to-threadwise -canonicalize -verify-diagnostics %s | FileCheck 
%s -#xdlops_gemm_params = #rock.mfma_gemm_params // CHECK-LABEL: @gridwise_attn_simple // CHECK-SAME: (%[[Q:.+]]: memref<1x384x64xf32>, %[[K:.+]]: memref<1x64x384xf32>, %[[V:.+]]: memref<1x384x64xf32>, %[[O:.+]]: memref<1x384x64xf32>) // CHECK-DAG: %[[ln2Recip:.+]] = arith.constant 1.44269502 : f32 @@ -230,8 +229,8 @@ func.func @gridwise_attn_simple(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64x rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg3) preSoftmaxOps = {} { blockSize = 64 : i32, gridSize = 24 : i32, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, @@ -267,8 +266,8 @@ func.func @gridwise_attn_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64 blockSize = 64 : i32, gridSize = 24 : i32, operandSegmentSizes = array, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock @@ -315,8 +314,8 @@ func.func @gridwise_attn_causal_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memr causal, gridSize = 24 : i32, operandSegmentSizes = array, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock @@ -362,8 +361,8 @@ func.func @gridwise_attn_lse_kvcache(%arg0: memref<1x384x64xf32>, %arg1: memref< blockSize = 64 : i32, gridSize = 24 : i32, operandSegmentSizes = array, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock @@ -399,8 +398,8 @@ func.func @gridwise_attn_softmaxtype(%arg0: memref<1x384x64xf16>, %arg1: memref< blockSize = 64 : i32, gridSize = 24 
: i32, operandSegmentSizes = array, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, @@ -457,8 +456,8 @@ func.func @gridwise_attn_softmaxtype_with_scaling(%arg0: memref<1x384x64xf16>, % blockSize = 64 : i32, gridSize = 24 : i32, operandSegmentSizes = array, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, @@ -524,8 +523,8 @@ func.func @gridwise_attn_splitkv_lse_kvcache(%arg0: memref<1x384x64xf32>, %arg1: blockSize = 64 : i32, gridSize = 192 : i32, operandSegmentSizes = array, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 8 : i32, storeMethod = #rock @@ -601,7 +600,7 @@ func.func @multiple_linalg_generics_in_presoftmax_ops(%arg0: memref<59136xf16>, } memref.copy %alloc_1, %arg7 : memref<12x77x77xf32> to memref<12x77x77xf32> rock.yield - } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<12x96x64xf16> + } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<12x96x64xf16> memref.copy %alloc, %arg4 : 
memref<59136xf16> to memref<59136xf16> return } @@ -684,7 +683,7 @@ func.func @multiple_linalg_generics_in_presoftmax_ops_with_transforms_inbetween( %35 = rock.transform %alloc_1 by (d0 floordiv 6, d0 mod 6, d1, d2)> by [ ["dim0", "dim1"] at [0, 1]>, ["dim2"] at [2]>, ["dim3"] at [3]>] bounds = [12, 77, 77] -> [2, 6, 77, 77]> : memref<2x6x77x77xf32> to memref<12x77x77xf32> memref.copy %35, %arg9 : memref<12x77x77xf32> to memref<12x77x77xf32> rock.yield - } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<5929xf32>, memref<12x96x64xf16> + } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<5929xf32>, memref<12x96x64xf16> memref.copy %alloc, %arg5 : memref<59136xf16> to memref<59136xf16> return } @@ -745,7 +744,7 @@ func.func @non_invertible_transformations_while_regularizing(%arg0: memref<59136 %32 = rock.transform %31 by (d0 floordiv 6, d0 mod 6, d1, d2)> by [ ["dim0", "dim1"] at [0, 1]>, ["dim2"] at [2]>, ["dim3"] at [3]>] bounds = [12, 77, 77] -> [2, 6, 77, 77]> : memref<2x6x77x77xf32> to memref<12x77x77xf32> memref.copy %32, %arg7 : memref<12x77x77xf32> to memref<12x77x77xf32> rock.yield - } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : 
index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<12x96x64xf16> + } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<12x96x64xf16> memref.copy %alloc, %arg4 : memref<59136xf16> to memref<59136xf16> return } @@ -808,7 +807,7 @@ func.func @multiple_outputs_linalg_while_regularizing(%arg0: memref<59136xf16>, } memref.copy %alloc_2, %arg7 : memref<12x77x77xf32> to memref<12x77x77xf32> rock.yield - } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<12x96x64xf16> + } {blockSize = 64 : i32, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, gridSize = 36 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, prePadG0M = 77 : index, prePadG0N = 77 : index, softmaxType = f32} : memref<12x64x96xf16>, memref<12x64x96xf16>, memref<12x96x64xf16>, memref<5929xf16>, memref<12x96x64xf16> memref.copy %alloc, %arg4 : memref<59136xf16> to memref<59136xf16> return } @@ -837,6 +836,34 @@ func.func @gridwise_attn_splitk(%arg0: memref<1474560xf16>, %arg1: memref<147456 // CHECK: rock.threadwise_write_all {{.*}} atomic_add : memref<96xf16, #gpu.address_space> -> memref<4x384x4096xf16> rock.gridwise_attention_accel(%13, %14, %15, %16) features = mfma|dot|atomic_add|atomic_add_f16|direct_to_lds_32b preSoftmaxOps = { - } {blockSize = 128 
: i32, enableSoftmax = false, firstGemmIndices = array, splitKV = 1 : i32, gridSize = 512 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, storeMethod = #rock} : memref<4x512x4096xf16>, memref<4x512x1024xf16>, memref<4x1024x384xf16>, memref<4x4096x384xf16> + } {blockSize = 128 : i32, enableSoftmax = false, firstGemmIndices = array, splitKV = 1 : i32, gridSize = 512 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, storeMethod = #rock} : memref<4x512x4096xf16>, memref<4x512x1024xf16>, memref<4x1024x384xf16>, memref<4x4096x384xf16> + return +} + +// ----- + +// CHECK-LABEL: @gridwise_attn_wavespereu_outputswizzle +// CHECK-SAME: output_swizzle = 1 : i64, waves_per_eu = 4 : i64 +func.func @gridwise_attn_wavespereu_outputswizzle(%arg0: memref<1474560xf16>, %arg1: memref<1474560xf16>, %arg2: memref<1474560xf16>, %arg3: memref<1474560xf16> {rock.prefill = 0.000000e+00 : f16}) attributes {block_size = 128 : i32, enable_splitk_for_tuning, features = #rock, grid_size = 512 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx942:sramecc+:xnack-", num_cu = 304 : i32} { + %0 = rock.transform %arg0 by (d1 * 360 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 4096, 360] -> [1474560]> : memref<1474560xf16> to memref<1x4096x360xf16> + %1 = rock.transform %arg1 by (d1 * 4096 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 360, 4096] -> [1474560]> : memref<1474560xf16> to memref<1x360x4096xf16> + %2 = rock.transform %arg2 by (d1 * 360 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 4096, 360] -> [1474560]> : memref<1474560xf16> to memref<1x4096x360xf16> + %3 = rock.transform %arg3 by (d1 * 360 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 4096, 360] -> [1474560]> : memref<1474560xf16> to memref<1x4096x360xf16> + %4 = rock.transform %0 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [1, 360, 4096] -> [1, 4096, 360]> : 
memref<1x4096x360xf16> to memref<1x360x4096xf16> + %5 = rock.transform %1 by (d0, d3, d1 * 1024 + d2)> by [ ["gemmG", "gemmK"] at [0, 1]>, ["gemmN"] at [2]>] bounds = [1, 4, 1024, 360] -> [1, 360, 4096]> : memref<1x360x4096xf16> to memref<1x4x1024x360xf16> + %6 = rock.transform %5 by (0, d0, d2, d1)> by [ ["gemmG", "gemmNSplit"] at [0, 1]>, ["gemmN", "gemmK"] at [2, 3]>] bounds = [4, 360, 1024] -> [1, 4, 1024, 360]> : memref<1x4x1024x360xf16> to memref<4x360x1024xf16> + %7 = rock.transform %2 by (d0, d1 * 1024 + d2, d3)> by [ ["gemmG", "gemmO"] at [0, 2]>, ["gemmN"] at [1]>] bounds = [1, 4, 1024, 360] -> [1, 4096, 360]> : memref<1x4096x360xf16> to memref<1x4x1024x360xf16> + %8 = rock.transform %7 by (0, d0, d1, d2)> by [ ["gemmG", "gemmNSplit"] at [0, 1]>, ["gemmN", "gemmO"] at [2, 3]>] bounds = [4, 1024, 360] -> [1, 4, 1024, 360]> : memref<1x4x1024x360xf16> to memref<4x1024x360xf16> + %9 = rock.transform %4 by (d0, d1, d2)> by [ ["gemmG", "gemmK", "gemmM"] at [0, 1, 2]>, [] at []>] bounds = [1, 360, 4096, 4] -> [1, 360, 4096]> : memref<1x360x4096xf16> to memref<1x360x4096x4xf16> + %10 = rock.transform %9 by (0, d1, d2, d0)> by [ ["gemmG", "gemmNSplit"] at [0, 3]>, ["gemmK", "gemmM"] at [1, 2]>] bounds = [4, 360, 4096] -> [1, 360, 4096, 4]> : memref<1x360x4096x4xf16> to memref<4x360x4096xf16> + %11 = rock.transform %3 by (d0, d2, d3)> by [ [] at []>, ["gemmG", "gemmM", "gemmO"] at [0, 1, 2]>] bounds = [1, 4, 4096, 360] -> [1, 4096, 360]> : memref<1x4096x360xf16> to memref<1x4x4096x360xf16> + %12 = rock.transform %11 by (0, d0, d1, d2)> by [ ["gemmG", "gemmNSplit"] at [0, 1]>, ["gemmM", "gemmO"] at [2, 3]>] bounds = [4, 4096, 360] -> [1, 4, 4096, 360]> : memref<1x4x4096x360xf16> to memref<4x4096x360xf16> + %13 = rock.transform %10 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemm0K"] at [1]>, ["gemm0N"] at [2]>] bounds = [4, 512, 4096] -> [4, 360, 4096]> : memref<4x360x4096xf16> to memref<4x512x4096xf16> + %14 = rock.transform %6 by (d0, d1, d2)> by [ ["gemmG"] at 
[0]>, ["gemm0K"] at [1]>, ["gemm0M"] at [2]>] bounds = [4, 512, 1024] -> [4, 360, 1024]> : memref<4x360x1024xf16> to memref<4x512x1024xf16> + %15 = rock.transform %8 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemm1K"] at [1]>, ["gemm1M"] at [2]>] bounds = [4, 1024, 384] -> [4, 1024, 360]> : memref<4x1024x360xf16> to memref<4x1024x384xf16> + %16 = rock.transform %12 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemm1N"] at [1]>, ["gemm1M"] at [2]>] bounds = [4, 4096, 384] -> [4, 4096, 360]> : memref<4x4096x360xf16> to memref<4x4096x384xf16> + + rock.gridwise_attention_accel(%13, %14, %15, %16) features = mfma|dot|atomic_add|atomic_add_f16|direct_to_lds_32b preSoftmaxOps = { + } {blockSize = 128 : i32, enableSoftmax = false, firstGemmIndices = array, splitKV = 1 : i32, gridSize = 512 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params, storeMethod = #rock} : memref<4x512x4096xf16>, memref<4x512x1024xf16>, memref<4x1024x384xf16>, memref<4x4096x384xf16> return } diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_barriers.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_barriers.mlir index b869d8565338..538c34f32166 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_barriers.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_barriers.mlir @@ -53,7 +53,7 @@ func.func @gridwise_attn_barriers_before_lds_write_issue_1811(%arg0: memref<4096 } memref.copy %alloc, %arg9 : memref<1x64x64xf16> to memref<1x64x64xf16> rock.yield - } {arch = "amdgcn-amd-amdhsa:gfx1100", blockSize = 64 : i32, firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, gridSize = 1 : i32, operandSegmentSizes = array, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params} : memref<1x128x64xi8>, memref<1x128x64xi8>, memref<1x64x64xf16>, memref<1x1x1xi8>, memref<1x1x1xf16>, memref<1x64x64xf16> + } {arch = "amdgcn-amd-amdhsa:gfx1100", blockSize = 64 : i32, 
firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, gridSize = 1 : i32, operandSegmentSizes = array, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params} : memref<1x128x64xi8>, memref<1x128x64xi8>, memref<1x64x64xf16>, memref<1x1x1xi8>, memref<1x1x1xf16>, memref<1x64x64xf16> return } @@ -81,7 +81,7 @@ func.func @gridwise_attn_barriers_before_lds_write_issue_1844(%arg0: memref<3276 %4 = rock.transform %0 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [1, 128, 256] -> [1, 256, 128]> : memref<1x256x128xf16> to memref<1x128x256xf16> %5 = rock.transform %1 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0N"] at [2, 1]>] bounds = [1, 128, 256] -> [1, 256, 128]> : memref<1x256x128xf16> to memref<1x128x256xf16> rock.gridwise_attention_accel(%4, %5, %2, %3) preSoftmaxOps = { - } {blockSize = 256 : i32, enableSoftmax = false, firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, gridSize = 2 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params} : memref<1x128x256xf16>, memref<1x128x256xf16>, memref<1x256x128xf16>, memref<1x256x128xf16> + } {blockSize = 256 : i32, enableSoftmax = false, firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, gridSize = 2 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params} : memref<1x128x256xf16>, memref<1x128x256xf16>, memref<1x256x128xf16>, memref<1x256x128xf16> return } @@ -106,6 +106,6 @@ func.func @gridwise_attn_barriers_before_lds_write_nobarriers(%arg0: memref<1638 %4 = rock.transform %0 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [1, 128, 128] -> [1, 128, 128]> : memref<1x128x128xf16> to memref<1x128x128xf16> %5 = rock.transform %1 by (d0, d2, d1)> by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0N"] at [2, 1]>] bounds = [1, 128, 128] -> [1, 128, 128]> : memref<1x128x128xf16> to memref<1x128x128xf16> 
rock.gridwise_attention_accel(%4, %5, %2, %3) preSoftmaxOps = { - } {blockSize = 256 : i32, enableSoftmax = false, firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, gridSize = 1 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params} : memref<1x128x128xf16>, memref<1x128x128xf16>, memref<1x128x128xf16>, memref<1x128x128xf16> + } {blockSize = 256 : i32, enableSoftmax = false, firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, gridSize = 1 : i32, operandSegmentSizes = array, params0 = #rock.mfma_gemm_params, params1 = #rock.mfma_gemm_params} : memref<1x128x128xf16>, memref<1x128x128xf16>, memref<1x128x128xf16>, memref<1x128x128xf16> return } diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir index 0fa8f53d5c4f..e3f37175f5e3 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir @@ -63,6 +63,6 @@ func.func @gridwise_attn_causal_scale_gqa(%arg0: memref<8192xf16>, %arg1: memref } memref.copy %alloc, %arg8 : memref<64x1x8192xf16> to memref<64x1x8192xf16> rock.yield - } {blockSize = 32 : i32, causal, firstGemmIndices = array, gridSize = 8 : i32, numRepeatsGQA = 8 : index, operandSegmentSizes = array, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params, prePadG0N = 8 : index, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock} : memref<8x128x32xf16>, memref<8x128x8192xf16>, memref<8x8192x128xf16>, memref<64x1x8192xf16>, memref<8x32x128xf16>, memref<8x32xf16> + } {blockSize = 32 : i32, causal, firstGemmIndices = array, gridSize = 8 : i32, numRepeatsGQA = 8 : index, operandSegmentSizes = array, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params, prePadG0N = 8 : index, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock} : memref<8x128x32xf16>, 
memref<8x128x8192xf16>, memref<8x8192x128xf16>, memref<64x1x8192xf16>, memref<8x32x128xf16>, memref<8x32xf16> return } diff --git a/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering.mlir index c8406dd52da7..a48a070d35d6 100644 --- a/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-opt -split-input-file -rock-gridwise-gemm-to-blockwise -rock-blockwise-load-tile-to-threadwise -rock-pipeline %s | FileCheck %s -#xdlops_gemm_params1 = #rock.mfma_gemm_params +#xdlops_gemm_params1 = #rock.mfma_gemm_params // CHECK-LABEL: @fp8_bf8_xdlops func.func @fp8_bf8_xdlops(%arg0: memref<1x128x128xf8E4M3FNUZ>, %arg1: memref<1x128x115200xf8E5M2FNUZ>, %arg2: memref<1x128x115200xf32>) attributes {block_size = 256 : i32, grid_size = 900 : i32, arch = "amdgcn-amd-amdhsa:gfx942", numCU = 228 : i32} { // The tuning testcase leads to padded buffers, we simplify here. @@ -41,7 +41,7 @@ func.func @fp8_bf8_xdlops(%arg0: memref<1x128x128xf8E4M3FNUZ>, %arg1: memref<1x1 // ----- -#xdlops_gemm_params1a = #rock.mfma_gemm_params +#xdlops_gemm_params1a = #rock.mfma_gemm_params // CHECK-LABEL: @fp8_bf8_xdlops_ocp func.func @fp8_bf8_xdlops_ocp(%arg0: memref<1x128x128xf8E4M3FN>, %arg1: memref<1x128x115200xf8E5M2>, %arg2: memref<1x128x115200xf32>) attributes {block_size = 256 : i32, grid_size = 900 : i32, arch = "amdgcn-amd-amdhsa:gfx950", numCU = 256 : i32} { // The tuning testcase leads to padded buffers, we simplify here. 
@@ -82,7 +82,7 @@ func.func @fp8_bf8_xdlops_ocp(%arg0: memref<1x128x128xf8E4M3FN>, %arg1: memref<1 // ----- -#xdlops_gemm_params2 = #rock.mfma_gemm_params +#xdlops_gemm_params2 = #rock.mfma_gemm_params // CHECK: @chiplet_grid func.func @chiplet_grid(%arg0: memref<1x32x128xf32>, %arg1: memref<1x32x256xf32>, %arg2: memref<1x128x256xf32>) attributes {block_size = 256 : i32, grid_size = 8 : i32, arch = "amdgcn-amd-amdhsa:gfx942", numCU = 228 : i32} { // CHECK: %[[BID:.+]] = rock.workgroup_id @@ -98,7 +98,7 @@ func.func @chiplet_grid(%arg0: memref<1x32x128xf32>, %arg1: memref<1x32x256xf32> // ----- -#xdlops_gemm_params2 = #rock.mfma_gemm_params +#xdlops_gemm_params2 = #rock.mfma_gemm_params // CHECK: @chiplet_grid_mi308 func.func @chiplet_grid_mi308(%arg0: memref<1x32x128xf32>, %arg1: memref<1x32x256xf32>, %arg2: memref<1x128x256xf32>) attributes {block_size = 256 : i32, grid_size = 8 : i32, arch = "amdgcn-amd-amdhsa:gfx942", numCU = 80 : i32} { // CHECK: %[[BID:.+]] = rock.workgroup_id @@ -114,7 +114,7 @@ func.func @chiplet_grid_mi308(%arg0: memref<1x32x128xf32>, %arg1: memref<1x32x25 // ----- -#xdlops_gemm_params_double_buffer = #rock.mfma_gemm_params +#xdlops_gemm_params_double_buffer = #rock.mfma_gemm_params // CHECK-LABEL: @fp8_bf8_xdlops_ocp_double_buffer func.func @fp8_bf8_xdlops_ocp_double_buffer(%arg0: memref<1x128x128xf8E4M3FN>, %arg1: memref<1x128x115200xf8E5M2>, %arg2: memref<1x128x115200xf32>) attributes {block_size = 256 : i32, grid_size = 900 : i32, arch = "amdgcn-amd-amdhsa:gfx950", numCU = 256 : i32} { // The tuning testcase leads to padded buffers, we simplify here. 
@@ -164,7 +164,7 @@ func.func @fp8_bf8_xdlops_ocp_double_buffer(%arg0: memref<1x128x128xf8E4M3FN>, % // ----- // Tests for scaled GEMM (FP4 with scales) -#xdlops_gemm_params_scaled = #rock.mfma_gemm_params +#xdlops_gemm_params_scaled = #rock.mfma_gemm_params // CHECK-LABEL: @scaled_gemm_fp4_basic func.func @scaled_gemm_fp4_basic(%arg0: memref<1x512x16xf4E2M1FN>, %arg1: memref<1x512x16xf4E2M1FN>, %arg2: memref<1x16x16xf32>, %scaleA: memref<1x512x16xf8E8M0FNU>, %scaleB: memref<1x512x16xf8E8M0FNU>) attributes {block_size = 64 : i32, grid_size = 1 : i32, kernel, arch = "amdgcn-amd-amdhsa:gfx950", num_cu = 256 : i64} { @@ -223,7 +223,7 @@ func.func @scaled_gemm_fp4_basic(%arg0: memref<1x512x16xf4E2M1FN>, %arg1: memref // ----- // Test scaled GEMM with different dimensions -#xdlops_gemm_params_scaled2 = #rock.mfma_gemm_params +#xdlops_gemm_params_scaled2 = #rock.mfma_gemm_params // CHECK-LABEL: @scaled_gemm_fp4_larger func.func @scaled_gemm_fp4_larger(%arg0: memref<1x512x32xf4E2M1FN>, %arg1: memref<1x512x32xf4E2M1FN>, %arg2: memref<1x32x32xf32>, %scaleA: memref<1x512x32xf8E8M0FNU>, %scaleB: memref<1x512x32xf8E8M0FNU>) attributes {block_size = 256 : i32, grid_size = 1 : i32, kernel, arch = "amdgcn-amd-amdhsa:gfx950", num_cu = 256 : i64} { @@ -274,3 +274,52 @@ func.func @scaled_gemm_fp4_larger(%arg0: memref<1x512x32xf4E2M1FN>, %arg1: memre rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %scaleA, %scaleB) storeMethod( set) features = mfma {blockSize = 256 : i32, gridSize = 1 : i32, params = #xdlops_gemm_params_scaled2} : memref<1x512x32xf4E2M1FN>, memref<1x512x32xf4E2M1FN>, memref<1x32x32xf32>, memref<1x512x32xf8E8M0FNU>, memref<1x512x32xf8E8M0FNU> return } + +// ----- + +#xdlops_gemm_params3 = #rock.mfma_gemm_params + +// CHECK: @gemm_wavespereu_outputswizzle +// CHECK-SAME: output_swizzle = 1 : i64, waves_per_eu = 4 : i64 +func.func @gemm_wavespereu_outputswizzle(%arg0: memref<1x32x128xf32>, %arg1: memref<1x32x256xf32>, %arg2: memref<1x128x256xf32>) attributes {block_size = 
256 : i32, grid_size = 8 : i32, arch = "amdgcn-amd-amdhsa:gfx942", numCU = 228 : i32} { + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 900 : i32, params = #xdlops_gemm_params3} : memref<1x32x128xf32>, memref<1x32x256xf32>, memref<1x128x256xf32> + return +} + +// ----- + +#xdlops_gemm_params_gridgroupsize = #rock.mfma_gemm_params +// CHECK: @grid_group_size +func.func @grid_group_size(%arg0: memref<1x32x128xf32>, %arg1: memref<1x32x256xf32>, %arg2: memref<1x128x256xf32>) attributes {block_size = 256 : i32, grid_size = 8 : i32, arch = "amdgcn-amd-amdhsa:gfx942", numCU = 228 : i32} { + // CHECK: %[[BID:.+]] = rock.workgroup_id + // CHECK-DAG: %[[CHIPLET_GRP_ID:.+]] = arith.remui %[[BID]], %c4 : index + // CHECK-DAG: %[[CHIPLET_BID:.+]] = arith.divui %[[BID]], %c4 : index + // CHECK-DAG: %[[CHIPLET_GRP_ID_LSHIFT:.+]] = arith.muli %[[CHIPLET_GRP_ID]], %c2 : index + // CHECK-DAG: %[[MAYBE_NEW_BID:.+]] = arith.addi %[[CHIPLET_BID]], %[[CHIPLET_GRP_ID_LSHIFT]] : index + // CHECK-DAG: %[[IS_TAIL_BID:.+]] = arith.cmpi sgt, %[[BID]], %c7 : index + // CHECK-DAG: %[[NEW_BID:.+]] = arith.select %[[IS_TAIL_BID]], %[[BID]], %[[MAYBE_NEW_BID]] : index + // CHECK-DAG: %[[NEW_BID2:.+]] = arith.remui %[[NEW_BID]], %c8 : index + // CHECK-DAG: %[[GROUD_ID:.+]] = arith.divui %[[NEW_BID2]], %c256 : index + // CHECK-DAG: %[[FIRST_BID_M:.+]] = arith.muli %[[GROUD_ID]], %c64 : index + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 900 : i32, params = #xdlops_gemm_params_gridgroupsize} : memref<1x32x128xf32>, memref<1x32x256xf32>, memref<1x128x256xf32> + return +} + +// ----- + +#xdlops_gemm_params_gridgroupsize_default = #rock.mfma_gemm_params +// CHECK: @grid_group_size_default +func.func @grid_group_size_default(%arg0: memref<1x32x128xf32>, %arg1: memref<1x32x256xf32>, %arg2: memref<1x128x256xf32>) attributes {block_size = 256 : i32, grid_size = 8 : i32, arch = 
"amdgcn-amd-amdhsa:gfx942", numCU = 228 : i32} { + // CHECK: %[[BID:.+]] = rock.workgroup_id + // CHECK-DAG: %[[CHIPLET_GRP_ID:.+]] = arith.remui %[[BID]], %c4 : index + // CHECK-DAG: %[[CHIPLET_BID:.+]] = arith.divui %[[BID]], %c4 : index + // CHECK-DAG: %[[CHIPLET_GRP_ID_LSHIFT:.+]] = arith.muli %[[CHIPLET_GRP_ID]], %c2 : index + // CHECK-DAG: %[[MAYBE_NEW_BID:.+]] = arith.addi %[[CHIPLET_BID]], %[[CHIPLET_GRP_ID_LSHIFT]] : index + // CHECK-DAG: %[[IS_TAIL_BID:.+]] = arith.cmpi sgt, %[[BID]], %c7 : index + // CHECK-DAG: %[[NEW_BID:.+]] = arith.select %[[IS_TAIL_BID]], %[[BID]], %[[MAYBE_NEW_BID]] : index + // CHECK-DAG: %[[NEW_BID2:.+]] = arith.remui %[[NEW_BID]], %c8 : index + // CHECK-DAG: %[[GROUD_ID:.+]] = arith.divui %[[NEW_BID2]], %c64 : index + // CHECK-DAG: %[[FIRST_BID_M:.+]] = arith.muli %[[GROUD_ID]], %c16 : index + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 900 : i32, params = #xdlops_gemm_params_gridgroupsize_default} : memref<1x32x128xf32>, memref<1x32x256xf32>, memref<1x128x256xf32> + return +} diff --git a/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering_invalid.mlir b/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering_invalid.mlir index ea4828dd297a..1234aca68cf0 100644 --- a/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering_invalid.mlir +++ b/mlir/test/Dialect/Rock/gridwise_gemm_accel_lowering_invalid.mlir @@ -8,7 +8,7 @@ // kpackPerBlock * mPerBlock * kpack * sizeof(f32) + kpackPerBlock * nPerBlock * kpack * sizeof(f32) // = 32 * 256 * 8 * 4 + 32 * 256 * 8 * 4 = 262144 + 262144 = 524288 bytes > 65536 // Format: A (G x K x M), B (G x K x N), C (G x M x N) -#xdlops_gemm_params_too_much_lds = #rock.mfma_gemm_params +#xdlops_gemm_params_too_much_lds = #rock.mfma_gemm_params func.func @excessive_lds_usage(%arg0: memref<1x256x256xf32>, %arg1: memref<1x256x256xf32>, %arg2: memref<1x256x256xf32>) attributes {block_size = 256 : i32, grid_size = 1 : i32, arch = "amdgcn-amd-amdhsa:gfx942", 
numCU = 304 : i32} { // expected-error @+2 {{requires too much LDS}} // expected-error @+1 {{failed to legalize operation 'rock.gridwise_gemm_accel'}} @@ -22,7 +22,7 @@ func.func @excessive_lds_usage(%arg0: memref<1x256x256xf32>, %arg1: memref<1x256 // kpackPerBlock * mPerBlock * kpack * sizeof(f32) + kpackPerBlock * nPerBlock * kpack * sizeof(f32) > 163840 // 16 * 512 * 8 * 4 + 16 * 512 * 8 * 4 = 262144 + 262144 = 524288 bytes > 163840 // Format: A (G x K x M), B (G x K x N), C (G x M x N) -#xdlops_gemm_params_gfx950_lds = #rock.mfma_gemm_params +#xdlops_gemm_params_gfx950_lds = #rock.mfma_gemm_params func.func @gfx950_lds_exceeded(%arg0: memref<1x128x512xf32>, %arg1: memref<1x128x512xf32>, %arg2: memref<1x512x512xf32>) attributes {block_size = 256 : i32, grid_size = 1 : i32, arch = "amdgcn-amd-amdhsa:gfx950", numCU = 256 : i32} { // expected-error @+2 {{requires too much LDS}} // expected-error @+1 {{failed to legalize operation 'rock.gridwise_gemm_accel'}} @@ -37,7 +37,7 @@ func.func @gfx950_lds_exceeded(%arg0: memref<1x128x512xf32>, %arg1: memref<1x128 // kpackPerBlock * mPerBlock * kpack * sizeof(f4E2M1FN) + kpackPerBlock * nPerBlock * kpack * sizeof(f4E2M1FN) // = 32 * 256 * 32 * 0.5 + 32 * 256 * 32 * 0.5 = 131072 + 131072 = 262144 bytes > 163840 // Format: A (G x K x M), B (G x K x N), C (G x M x N), scaleA (G x K x M), scaleB (G x K x N) -#xdlops_gemm_params_scaled_lds_exceeded = #rock.mfma_gemm_params +#xdlops_gemm_params_scaled_lds_exceeded = #rock.mfma_gemm_params func.func @scaled_gemm_lds_exceeded(%arg0: memref<1x1024x256xf4E2M1FN>, %arg1: memref<1x1024x256xf4E2M1FN>, %arg2: memref<1x256x256xf32>, %scaleA: memref<1x1024x256xf8E8M0FNU>, %scaleB: memref<1x1024x256xf8E8M0FNU>) attributes {block_size = 256 : i32, grid_size = 1 : i32, arch = "amdgcn-amd-amdhsa:gfx950", numCU = 256 : i32} { // expected-error @+2 {{requires too much LDS}} // expected-error @+1 {{failed to legalize operation 'rock.gridwise_gemm_accel'}} @@ -51,7 +51,7 @@ func.func 
@scaled_gemm_lds_exceeded(%arg0: memref<1x1024x256xf4E2M1FN>, %arg1: m // kpackPerBlock * mPerBlock * kpack * sizeof(f4E2M1FN) + kpackPerBlock * nPerBlock * kpack * sizeof(f4E2M1FN) // = 32 * 512 * 32 * 0.5 + 32 * 512 * 32 * 0.5 = 262144 + 262144 = 524288 bytes > 163840 // Format: A (G x K x M), B (G x K x N), C (G x M x N), scaleA (G x K x M), scaleB (G x K x N) -#xdlops_gemm_params_scaled_lds_exceeded2 = #rock.mfma_gemm_params +#xdlops_gemm_params_scaled_lds_exceeded2 = #rock.mfma_gemm_params func.func @scaled_gemm_lds_exceeded_alt(%arg0: memref<1x1024x512xf4E2M1FN>, %arg1: memref<1x1024x512xf4E2M1FN>, %arg2: memref<1x512x512xf32>, %scaleA: memref<1x1024x512xf8E8M0FNU>, %scaleB: memref<1x1024x512xf8E8M0FNU>) attributes {block_size = 256 : i32, grid_size = 1 : i32, arch = "amdgcn-amd-amdhsa:gfx950", numCU = 256 : i32} { // expected-error @+2 {{requires too much LDS}} // expected-error @+1 {{failed to legalize operation 'rock.gridwise_gemm_accel'}} diff --git a/mlir/test/Dialect/Rock/gridwise_gemm_conservative_lds_barriers.mlir b/mlir/test/Dialect/Rock/gridwise_gemm_conservative_lds_barriers.mlir index ff2c6c066f0f..ae908bbf3c16 100644 --- a/mlir/test/Dialect/Rock/gridwise_gemm_conservative_lds_barriers.mlir +++ b/mlir/test/Dialect/Rock/gridwise_gemm_conservative_lds_barriers.mlir @@ -90,7 +90,7 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx1201"} { } memref.copy %alloc, %arg11 : memref<12x1x384xf16> to memref<12x1x384xf16> rock.yield - } {blockSize = 32 : i32, firstGemmIndices = array, gridSize = 24 : i32, numRepeatsGQA = 2 : index, operandSegmentSizes = array, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params, prePadG0N = 2 : index, softmaxType = f32, splitKV = 4 : i32, storeMethod = #rock} : memref<6x64x32xf16>, memref<6x64x384xf16>, memref<6x384x64xf16>, memref<12x1x384xf16>, memref<12x1x384xf16>, memref<6xi32>, memref<24x32x64xf16>, memref<24x32xf16> + } {blockSize = 32 : i32, firstGemmIndices = array, gridSize = 24 : i32, 
numRepeatsGQA = 2 : index, operandSegmentSizes = array, params0 = #rock.wmma_gemm_params, params1 = #rock.wmma_gemm_params, prePadG0N = 2 : index, softmaxType = f32, splitKV = 4 : i32, storeMethod = #rock} : memref<6x64x32xf16>, memref<6x64x384xf16>, memref<6x384x64xf16>, memref<12x1x384xf16>, memref<12x1x384xf16>, memref<6xi32>, memref<24x32x64xf16>, memref<24x32xf16> // CHECK: scf.for // CHECK: rock.blockwise_load_tile // CHECK: rock.blockwise_load_tile diff --git a/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full.mlir b/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full.mlir index d984f35e387f..4432a12d429c 100644 --- a/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full.mlir +++ b/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full.mlir @@ -28,7 +28,7 @@ #map24 = affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 64, (d2 mod 64) floordiv 32, d2 mod 32)> #map25 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, (d2 * 32 + d4) * 2 + d3)> #map26 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -#xldops_gemm_params = #rock.mfma_gemm_params +#xldops_gemm_params = #rock.mfma_gemm_params #transform_map = #rock.transform_map<#map by [ ["gemmG"] at [0]>, ["gemmK", "gemmM"] at [2, 1]>] bounds = [1, 1024, 384] -> [1, 384, 1024]> #transform_map1 = #rock.transform_map<#map1 by [ ["g"] at [0]>, ["k"] at [1]>, ["m"] at [2]>, [] at []>] bounds = [16, 1, 6, 16, 32, 8, 2, 8] -> [1, 1024, 384]> #transform_map2 = #rock.transform_map<#map2 by [ ["k_loop", "g_block", "m_block", "n_block"] at [0, 1, 2, 3]>, ["m_thread", "k_thread"] at [4, 5]>, ["m_iter", "k_iter"] at [6, 7]>] bounds = [16, 1, 6, 16, 256, 16] -> [16, 1, 6, 16, 32, 8, 2, 8]> diff --git a/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full_gfx950.mlir b/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full_gfx950.mlir index 1a1a52bde58b..91790c6f1ee9 100644 --- 
a/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full_gfx950.mlir +++ b/mlir/test/Dialect/Rock/integration/multibuffer/test_multi_buffer_full_gfx950.mlir @@ -28,7 +28,7 @@ #map24 = affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 64, (d2 mod 64) floordiv 32, d2 mod 32)> #map25 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, (d2 * 32 + d4) * 2 + d3)> #map26 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -#xldops_gemm_params = #rock.mfma_gemm_params +#xldops_gemm_params = #rock.mfma_gemm_params #transform_map = #rock.transform_map<#map by [ ["gemmG"] at [0]>, ["gemmK", "gemmM"] at [2, 1]>] bounds = [1, 1024, 384] -> [1, 384, 1024]> #transform_map1 = #rock.transform_map<#map1 by [ ["g"] at [0]>, ["k"] at [1]>, ["m"] at [2]>, [] at []>] bounds = [16, 1, 6, 16, 32, 8, 2, 8] -> [1, 1024, 384]> #transform_map2 = #rock.transform_map<#map2 by [ ["k_loop", "g_block", "m_block", "n_block"] at [0, 1, 2, 3]>, ["m_thread", "k_thread"] at [4, 5]>, ["m_iter", "k_iter"] at [6, 7]>] bounds = [16, 1, 6, 16, 256, 16] -> [16, 1, 6, 16, 32, 8, 2, 8]> diff --git a/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir b/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir index 2638a095701f..66b488335afb 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir @@ -4,7 +4,7 @@ kpackPerBlock = 16, mPerBlock = 64, nPerBlock = 64, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 3, - outputSwizzle = 2, forceUnroll = true> + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { // CHECK-LABEL: func.func @test_lds_transpose_attributes @@ -37,7 +37,7 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { kpackPerBlock = 32, mPerBlock = 64, nPerBlock = 64, kpack = 1, mPerWave = 16, nPerWave = 64, mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 4, - outputSwizzle = 2, forceUnroll = true> 
+ outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { // CHECK-LABEL: func.func @test_lds_transpose_attributes_double_buffering diff --git a/mlir/test/Dialect/Rock/loadtile_to_threadwise_lowering.mlir b/mlir/test/Dialect/Rock/loadtile_to_threadwise_lowering.mlir index b645b970d9a7..6a191a9fbb72 100644 --- a/mlir/test/Dialect/Rock/loadtile_to_threadwise_lowering.mlir +++ b/mlir/test/Dialect/Rock/loadtile_to_threadwise_lowering.mlir @@ -26,7 +26,7 @@ func.func @doublebuffer(%arg0: memref<1x384x64xf32>) attributes {block_size = 25 // CHECK-NEXT: rock.yield // CHECK: {name = "LDSRead"} affine.for %arg1 = 0 to 2 { - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> -> memref<16xf32, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> -> memref<16xf32, #gpu.address_space> } return } @@ -49,7 +49,7 @@ func.func @default(%arg0: memref<1x384x64xf32>) attributes {block_size = 256 : i // CHECK-NEXT: rock.yield // CHECK: {name = "LDSWrite"} affine.for %arg1 = 0 to 2 { - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> 
memref<4096xi8, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space> } return } @@ -72,7 +72,7 @@ func.func @bypasslds(%arg0: memref<1x384x64xf32>) attributes {block_size = 256 : // CHECK: rock.yield // CHECK: {name = "RegTranspose"} affine.for %arg1 = 0 to 2 { - rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> -> memref<16xf32, #gpu.address_space> + rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] -> %reg {elementType = f32, elementLoadType = f32, matrixParamsA = #rock.blockwise_matrix_params, matrixParamsB = #rock.blockwise_matrix_params, blockSize = 64 : i32, loadType = #rock, params = #rock.mfma_gemm_params} : memref<1x64x384xf32> -> memref<16xf32, #gpu.address_space> } return } diff --git a/mlir/test/Dialect/Rock/lowering_blockwise_gemm_accel.mlir b/mlir/test/Dialect/Rock/lowering_blockwise_gemm_accel.mlir index 5521ea5560e0..fa61917e4336 100644 --- a/mlir/test/Dialect/Rock/lowering_blockwise_gemm_accel.mlir +++ b/mlir/test/Dialect/Rock/lowering_blockwise_gemm_accel.mlir @@ -23,7 +23,7 @@ func.func @rock_blockwise_gemm_accel_two_results(%matrixA : memref<256xvector<2x mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #priv> += memref<4xf32, #priv> from memref<256xvector<2xf32>, #wg> * memref<4xf32, #priv> from memref<256xvector<2xf32>, #wg> return @@ -49,7 +49,7 @@ func.func 
@rock_blockwise_gemm_accel_one_result(%matrixA : memref<128xvector<8xi mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1xvector<16xi32>, #priv> += memref<1xvector<4xi8>, #priv> from memref<128xvector<8xi8>, #wg> * memref<1xvector<4xi8>, #priv> from memref<128xvector<8xi8>, #wg> return @@ -77,7 +77,7 @@ func.func @rock_blockwise_gemm_accel_fp8_bf8(%matrixA : memref<1024xvector<8xf8E mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #gpu.address_space> += memref<4xvector<8xf8E4M3FNUZ>, #gpu.address_space> from memref<1024xvector<8xf8E4M3FNUZ>, #gpu.address_space> * memref<4xvector<8xf8E5M2FNUZ>, #gpu.address_space> from memref<1024xvector<8xf8E5M2FNUZ>, #gpu.address_space> return @@ -105,7 +105,7 @@ func.func @rock_blockwise_gemm_accel_fp8_bf8_ocp(%matrixA : memref<1024xvector<8 mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #gpu.address_space> += memref<4xvector<8xf8E4M3FN>, #gpu.address_space> from memref<1024xvector<8xf8E4M3FN>, #gpu.address_space> * memref<4xvector<8xf8E5M2>, #gpu.address_space> from memref<1024xvector<8xf8E5M2>, #gpu.address_space> return @@ -135,7 +135,7 @@ func.func @rock_blockwise_gemm_accel_fp8_bf8_ocp_double_buffer(%bufferA : memref mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #gpu.address_space> += memref<4xvector<8xf8E4M3FN>, #gpu.address_space> * memref<4xvector<8xf8E5M2>, #gpu.address_space> return @@ -175,7 +175,7 @@ func.func @rock_blockwise_gemm_accel_scaled_schedule_v2( mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 2, - 
outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1xvector<16xf32>, #priv> += memref<8xvector<32xf4E2M1FN>, #priv> scaled by memref<8xvector<32xf8E8M0FNU>, #priv> from memref<512xvector<32xf8E8M0FNU>, #wg> * memref<8xvector<32xf4E2M1FN>, #priv> scaled by memref<8xvector<32xf8E8M0FNU>, #priv> from memref<512xvector<32xf8E8M0FNU>, #wg> return @@ -203,7 +203,7 @@ func.func @rock_blockwise_gemm_accel_direct_to_lds(%matrixA : memref<256xvector< mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 4, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #priv> += memref<16xi8, #priv> from memref<256xvector<2xf32>, #wg> * memref<16xi8, #priv> from memref<256xvector<2xf32>, #wg> return diff --git a/mlir/test/Dialect/Rock/lowering_blockwise_gemm_wmma.mlir b/mlir/test/Dialect/Rock/lowering_blockwise_gemm_wmma.mlir index 36e16e736bf0..4eaa067160b7 100644 --- a/mlir/test/Dialect/Rock/lowering_blockwise_gemm_wmma.mlir +++ b/mlir/test/Dialect/Rock/lowering_blockwise_gemm_wmma.mlir @@ -25,7 +25,7 @@ func.func @rock_blockwise_gemm_accel_wmma(%matrixA : memref<16xvector<8xf16>, #w mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1xvector<8xf32>, #priv> += memref<1xvector<16xf16>, #priv> from memref<16xvector<8xf16>, #wg> * memref<1xvector<16xf16>, #priv> from memref<16xvector<8xf16>, #wg> return @@ -54,7 +54,7 @@ func.func @rock_blockwise_gemm_accel_wmma_largekpack(%matrixA : memref<32xvector kpack = 8, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1xvector<8xf32>, #priv> += memref<1xvector<16xf16>, #priv> from memref<32xvector<8xf16>, #wg> * memref<1xvector<16xf16>, #priv> from memref<32xvector<8xf16>, #wg> return @@ -83,7 +83,7 @@ func.func 
@rock_blockwise_gemm_accel_wmma_int8(%matrixA : memref<32xvector<16xi8 kpack = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<8xi32>, #priv> += memref<4xvector<16xi8>, #priv> from memref<32xvector<16xi8>, #wg> * memref<4xvector<16xi8>, #priv> from memref<32xvector<16xi8>, #wg> return @@ -111,7 +111,7 @@ func.func @rock_blockwise_gemm_accel_wmma_double_buffer(%bufferA : memref<1xvect mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 2, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1xvector<8xf32>, #priv> += memref<1xvector<16xf16>, #priv> * memref<1xvector<16xf16>, #priv> return diff --git a/mlir/test/Dialect/Rock/lowering_gemm_linalg_splitk_normalization.mlir b/mlir/test/Dialect/Rock/lowering_gemm_linalg_splitk_normalization.mlir index 49c50ef8cf04..528bb39a4362 100644 --- a/mlir/test/Dialect/Rock/lowering_gemm_linalg_splitk_normalization.mlir +++ b/mlir/test/Dialect/Rock/lowering_gemm_linalg_splitk_normalization.mlir @@ -13,7 +13,7 @@ func.func @matmul_splitk_addf_constant(%arg0: memref<32768xf32>, %arg1: memref<1 %3 = rock.transform %0 by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4096] -> [2, 4, 4096]> : memref<2x4x4096xf32> to memref<8x4096xf32> %4 = rock.transform %2 by (0, d0, d1)> by [ ["g"] at [0]>, ["d0", "d1"] at [1, 2]>] bounds = [4096, 4] -> [2, 4096, 4]> : memref<2x4096x4xf32> to memref<4096x4xf32> %5 = rock.transform %alloc by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4] -> [2, 4, 4]> : memref<2x4x4xf32> to memref<8x4xf32> - rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> + rock.gemm %5 = %3 * %4 storeMethod = set 
{derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x4xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%alloc : memref<2x4x4xf32>) outs(%alloc_0 : memref<2x4x4xf32>) { ^bb0(%in: f32, %out: f32): @@ -38,7 +38,7 @@ func.func @matmul_splitk_subf_constant(%arg0: memref<32768xf32>, %arg1: memref<1 %3 = rock.transform %0 by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4096] -> [2, 4, 4096]> : memref<2x4x4096xf32> to memref<8x4096xf32> %4 = rock.transform %2 by (0, d0, d1)> by [ ["g"] at [0]>, ["d0", "d1"] at [1, 2]>] bounds = [4096, 4] -> [2, 4096, 4]> : memref<2x4096x4xf32> to memref<4096x4xf32> %5 = rock.transform %alloc by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4] -> [2, 4, 4]> : memref<2x4x4xf32> to memref<8x4xf32> - rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> + rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x4xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%alloc : memref<2x4x4xf32>) outs(%alloc_0 : memref<2x4x4xf32>) { ^bb0(%in: f32, %out: f32): @@ -62,7 +62,7 @@ func.func @matmul_splitk_addf_same(%arg0: memref<32768xf32>, %arg1: memref<16384 %3 = 
rock.transform %0 by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4096] -> [2, 4, 4096]> : memref<2x4x4096xf32> to memref<8x4096xf32> %4 = rock.transform %2 by (0, d0, d1)> by [ ["g"] at [0]>, ["d0", "d1"] at [1, 2]>] bounds = [4096, 4] -> [2, 4096, 4]> : memref<2x4096x4xf32> to memref<4096x4xf32> %5 = rock.transform %alloc by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4] -> [2, 4, 4]> : memref<2x4x4xf32> to memref<8x4xf32> - rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> + rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x4xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%alloc : memref<2x4x4xf32>) outs(%alloc_0 : memref<2x4x4xf32>) { ^bb0(%in: f32, %out: f32): @@ -85,7 +85,7 @@ func.func @matmul_splitk_subf_same(%arg0: memref<32768xf32>, %arg1: memref<16384 %3 = rock.transform %0 by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4096] -> [2, 4, 4096]> : memref<2x4x4096xf32> to memref<8x4096xf32> %4 = rock.transform %2 by (0, d0, d1)> by [ ["g"] at [0]>, ["d0", "d1"] at [1, 2]>] bounds = [4096, 4] -> [2, 4096, 4]> : memref<2x4096x4xf32> to memref<4096x4xf32> %5 = rock.transform %alloc by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4] -> [2, 4, 4]> : memref<2x4x4xf32> to memref<8x4xf32> - rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = 
#rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> + rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x4xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%alloc : memref<2x4x4xf32>) outs(%alloc_0 : memref<2x4x4xf32>) { ^bb0(%in: f32, %out: f32): @@ -109,7 +109,7 @@ func.func @matmul_splitk_addf(%arg0: memref<2x4x4xf32>, %arg1: memref<32768xf32> %3 = rock.transform %0 by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4096] -> [2, 4, 4096]> : memref<2x4x4096xf32> to memref<8x4096xf32> %4 = rock.transform %2 by (0, d0, d1)> by [ ["g"] at [0]>, ["d0", "d1"] at [1, 2]>] bounds = [4096, 4] -> [2, 4096, 4]> : memref<2x4096x4xf32> to memref<4096x4xf32> %5 = rock.transform %alloc by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4] -> [2, 4, 4]> : memref<2x4x4xf32> to memref<8x4xf32> - rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> + rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x4xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", 
"parallel", "parallel"]} ins(%alloc, %arg0 : memref<2x4x4xf32>, memref<2x4x4xf32>) outs(%alloc_0 : memref<2x4x4xf32>) { ^bb0(%in: f32, %in_1: f32, %out: f32): @@ -135,7 +135,7 @@ func.func @matmul_splitk_subf(%arg0: memref<2x4x4xf32>, %arg1: memref<32768xf32> %3 = rock.transform %0 by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4096] -> [2, 4, 4096]> : memref<2x4x4096xf32> to memref<8x4096xf32> %4 = rock.transform %2 by (0, d0, d1)> by [ ["g"] at [0]>, ["d0", "d1"] at [1, 2]>] bounds = [4096, 4] -> [2, 4096, 4]> : memref<2x4096x4xf32> to memref<4096x4xf32> %5 = rock.transform %alloc by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4] -> [2, 4, 4]> : memref<2x4x4xf32> to memref<8x4xf32> - rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> + rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x4xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%alloc, %arg0 : memref<2x4x4xf32>, memref<2x4x4xf32>) outs(%alloc_0 : memref<2x4x4xf32>) { ^bb0(%in: f32, %in_1: f32, %out: f32): @@ -161,7 +161,7 @@ func.func @matmul_splitk_multiple(%arg0: memref<2x4x4xf32>, %arg1: memref<2x4x4x %3 = rock.transform %0 by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4096] -> [2, 4, 4096]> : memref<2x4x4096xf32> to memref<8x4096xf32> %4 = rock.transform %2 by (0, d0, d1)> by [ ["g"] at [0]>, ["d0", "d1"] at [1, 2]>] 
bounds = [4096, 4] -> [2, 4096, 4]> : memref<2x4096x4xf32> to memref<4096x4xf32> %5 = rock.transform %alloc by (d0 floordiv 4, d0 mod 4, d1)> by [ ["g", "d0"] at [0, 1]>, ["d1"] at [2]>] bounds = [8, 4] -> [2, 4, 4]> : memref<2x4x4xf32> to memref<8x4xf32> - rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> + rock.gemm %5 = %3 * %4 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params, perf_config = "v3:16,32,4,16,16,4,4,1,2,1,1"} : memref<8x4xf32> = memref<8x4096xf32> * memref<4096x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x4xf16> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%alloc, %arg0, %arg1 : memref<2x4x4xf32>, memref<2x4x4xf32>, memref<2x4x4xf16>) outs(%alloc_0 : memref<2x4x4xf16>) { ^bb0(%in: f32, %in_1: f32, %in_2: f16, %out: f16): @@ -201,7 +201,7 @@ func.func @convolution_multi_output_add(%arg0: memref<32768xf32>, %arg1: memref< %7 = rock.transform %6 by (d0, d1, d2, d3 + d4, d5 + d6)> by [ ["ni", "gi", "ci"] at [0, 1, 2]>, ["0ipad"] at [3]>, ["1ipad"] at [4]>] bounds = [2, 1, 4, 3, 64, 3, 64] -> [2, 1, 4, 66, 66]> : memref<2x1x4x66x66xf32> to memref<2x1x4x3x64x3x64xf32> %8 = rock.transform %7 by (d2 floordiv 4096, d0, d1 floordiv 9, (d1 mod 9) floordiv 3, (d2 mod 4096) floordiv 64, d1 mod 3, d2 mod 64)> by [ ["gi"] at [1]>, ["ci", "0", "1"] at [2, 3, 5]>, ["ni", "0o", "1o"] at [0, 4, 6]>] bounds = [1, 36, 8192] -> [2, 1, 4, 3, 64, 3, 64]> : memref<2x1x4x3x64x3x64xf32> to memref<1x36x8192xf32> %9 = rock.transform %4 by (d2 floordiv 4096, d0, d1, (d2 mod 4096) floordiv 64, d2 mod 64)> by [ ["go"] at [1]>, ["ko"] at [2]>, ["no", "0o", 
"1o"] at [0, 3, 4]>] bounds = [1, 320, 8192] -> [2, 1, 320, 64, 64]> : memref<2x1x320x64x64xf32> to memref<1x320x8192xf32> - rock.gemm %9 = tr %5 * %8 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params} : memref<1x320x8192xf32> = memref<1x36x320xf32> * memref<1x36x8192xf32> + rock.gemm %9 = tr %5 * %8 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params} : memref<1x320x8192xf32> = memref<1x36x320xf32> * memref<1x36x8192xf32> %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<2x320x64x64xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc : memref<2x320x64x64xf32>) outs(%alloc_2 : memref<2x320x64x64xf32>) { ^bb0(%in: f32, %out: f32): @@ -250,7 +250,7 @@ func.func @convolution_multi_output_add_twice(%arg0: memref<32768xf32> {mhal.rea %7 = rock.transform %6 by (d0, d1, d2, d3 + d4, d5 + d6)> by [ ["ni", "gi", "ci"] at [0, 1, 2]>, ["0ipad"] at [3]>, ["1ipad"] at [4]>] bounds = [2, 1, 4, 3, 64, 3, 64] -> [2, 1, 4, 66, 66]> : memref<2x1x4x66x66xf32> to memref<2x1x4x3x64x3x64xf32> %8 = rock.transform %7 by (d2 floordiv 4096, d0, d1 floordiv 9, (d1 mod 9) floordiv 3, (d2 mod 4096) floordiv 64, d1 mod 3, d2 mod 64)> by [ ["gi"] at [1]>, ["ci", "0", "1"] at [2, 3, 5]>, ["ni", "0o", "1o"] at [0, 4, 6]>] bounds = [1, 36, 8192] -> [2, 1, 4, 3, 64, 3, 64]> : memref<2x1x4x3x64x3x64xf32> to memref<1x36x8192xf32> %9 = rock.transform %4 by (d2 floordiv 4096, d0, d1, (d2 mod 4096) floordiv 64, d2 mod 64)> by [ ["go"] at [1]>, ["ko"] at [2]>, ["no", "0o", "1o"] at [0, 3, 4]>] bounds = [1, 320, 8192] -> [2, 1, 320, 64, 64]> : memref<2x1x320x64x64xf32> to memref<1x320x8192xf32> - rock.gemm %9 = tr %5 * %8 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params} : memref<1x320x8192xf32> = memref<1x36x320xf32> * memref<1x36x8192xf32> + 
rock.gemm %9 = tr %5 * %8 storeMethod = set {derivedBlockSize = 128 : i32, params = #rock.mfma_gemm_params} : memref<1x320x8192xf32> = memref<1x36x320xf32> * memref<1x36x8192xf32> %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<2x320x64x64xf32> linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc : memref<2x320x64x64xf32>) outs(%alloc_2 : memref<2x320x64x64xf32>) { ^bb0(%in: f32, %out: f32): diff --git a/mlir/test/Dialect/Rock/lowering_output_swizzle.mlir b/mlir/test/Dialect/Rock/lowering_output_swizzle.mlir index a6d665ab23d7..45d5eaea03f0 100644 --- a/mlir/test/Dialect/Rock/lowering_output_swizzle.mlir +++ b/mlir/test/Dialect/Rock/lowering_output_swizzle.mlir @@ -168,7 +168,7 @@ func.func @rock_output_swizzle_multiple_outputs(%arg0: memref<1280xf16> {mhal.re %110 = rock.transform %41 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 8] -> [8]> : memref<8xvector<4xf16>, #gpu.address_space> to memref<1x8xvector<4xf16>, #gpu.address_space> %111 = rock.transform %42 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 8] -> [8]> : memref<8xvector<4xf16>, #gpu.address_space> to memref<1x8xvector<4xf16>, #gpu.address_space> %112 = rock.transform %43 by (d0 + d1)> by [ ["offset"] at [0]>] bounds = [1, 1] -> [1]> : memref<1xvector<16xf32>, #gpu.address_space> to memref<1x1xvector<16xf32>, #gpu.address_space> - rock.threadwise_gemm_accel %112 += %110 * %111 at[%arg6, %arg7, %arg8] {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x8xvector<4xf16>, #gpu.address_space> * memref<1x8xvector<4xf16>, #gpu.address_space> + rock.threadwise_gemm_accel %112 += %110 * %111 at[%arg6, %arg7, %arg8] {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x8xvector<4xf16>, #gpu.address_space> * memref<1x8xvector<4xf16>, 
#gpu.address_space> } } } @@ -218,7 +218,7 @@ func.func @rock_output_swizzle_multiple_outputs(%arg0: memref<1280xf16> {mhal.re %89 = rock.transform %41 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 8] -> [8]> : memref<8xvector<4xf16>, #gpu.address_space> to memref<1x8xvector<4xf16>, #gpu.address_space> %90 = rock.transform %42 by (d1)> by [ [] at []>, ["k"] at [0]>] bounds = [1, 8] -> [8]> : memref<8xvector<4xf16>, #gpu.address_space> to memref<1x8xvector<4xf16>, #gpu.address_space> %91 = rock.transform %43 by (d0 + d1)> by [ ["offset"] at [0]>] bounds = [1, 1] -> [1]> : memref<1xvector<16xf32>, #gpu.address_space> to memref<1x1xvector<16xf32>, #gpu.address_space> - rock.threadwise_gemm_accel %91 += %89 * %90 at[%arg5, %arg6, %arg7] {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x8xvector<4xf16>, #gpu.address_space> * memref<1x8xvector<4xf16>, #gpu.address_space> + rock.threadwise_gemm_accel %91 += %89 * %90 at[%arg5, %arg6, %arg7] {params = #rock.mfma_gemm_params} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x8xvector<4xf16>, #gpu.address_space> * memref<1x8xvector<4xf16>, #gpu.address_space> } } } diff --git a/mlir/test/Dialect/Rock/lowering_output_swizzle_tuning.mlir b/mlir/test/Dialect/Rock/lowering_output_swizzle_tuning.mlir new file mode 100644 index 000000000000..dc97f8ffca0d --- /dev/null +++ b/mlir/test/Dialect/Rock/lowering_output_swizzle_tuning.mlir @@ -0,0 +1,155 @@ +// RUN: rocmlir-opt -rock-output-swizzle %s | FileCheck %s + +#wg = #gpu.address_space +#priv = #gpu.address_space + +// CHECK-LABEL: func.func @rock_output_swizzle_heuristic +func.func @rock_output_swizzle_heuristic(%matrix_c: memref<1x1280x2048xf16>) attributes{arch = "gfx90a:sramecc+:xnack-", block_size = 256 : i32, grid_size = 320 : i32, output_swizzle = 2 : i64, kernel} { + %registers = rock.alloc() : memref<32xf16, #priv> + %registers2 = rock.alloc() : memref<32xf16, #priv> + %blockid = rock.workgroup_id : index + 
%threadid = rock.workitem_id : index + + %c22 = arith.constant 22 : index + %c320 = arith.constant 320 : index + %c20 = arith.constant 20 : index + %c352 = arith.constant 352 : index + %16 = arith.divui %blockid, %c320 : index + %17 = arith.remui %blockid, %c320 : index + %18 = arith.divui %17, %c352 : index + %19 = arith.muli %18, %c22 : index + %20 = arith.subi %c20, %19 : index + %21 = arith.minui %20, %c22 : index + %22 = arith.remui %17, %21 : index + %23 = arith.addi %19, %22 : index + %24 = arith.remui %17, %c352 : index + %25 = arith.divui %24, %21 : index + + %28 = rock.alloc() : memref<16384xi8, #wg> + %29 = rock.alloc() : memref<16384xi8, #wg> + + %c0 = arith.constant 0 : index + %view_29 = memref.view %29[%c0][] : memref<16384xi8, #wg> to memref<8192xf16, #wg> + %view_29_2 = rock.transform %view_29 by ((d1 * 256 + d0) * 8 + d2)> by [ ["flattenBlock"] at [0]>] bounds = [256, 4, 8] -> [8192]> : memref<8192xf16, #gpu.address_space> to memref<256x4x8xf16, #gpu.address_space> + %view_29_3 = rock.transform %view_29_2 by (d0, d1 floordiv 8, d1 mod 8)> by [ ["tid"] at [0]>, ["iter", "numElements"] at [1, 2]>] bounds = [256, 32] -> [256, 4, 8]> : memref<256x4x8xf16, #gpu.address_space> to memref<256x32xf16, #gpu.address_space> + rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%view_29_3) [%threadid] -> %registers : memref<256x32xf16, #wg> -> memref<32xf16, #priv> + + %view_28 = memref.view %28[%c0][] : memref<16384xi8, #wg> to memref<8192xf16, #wg> + %view_28_2 = rock.transform %view_28 by ((d1 * 256 + d0) * 8 + d2)> by [ ["flattenBlock"] at [0]>] bounds = [256, 4, 8] -> [8192]> : memref<8192xf16, #gpu.address_space> to memref<256x4x8xf16, #gpu.address_space> + %view_28_3 = rock.transform %view_28_2 by (d0, d1 floordiv 8, d1 mod 8)> by [ ["tid"] at [0]>, ["iter", "numElements"] at [1, 2]>] bounds = [256, 32] -> [256, 4, 8]> : memref<256x4x8xf16, #gpu.address_space> to memref<256x32xf16, #gpu.address_space> + rock.threadwise_read_into {forceUnroll, 
useIndexDiffs} [](%view_28_3) [%threadid] -> %registers2 : memref<256x32xf16, #wg> -> memref<32xf16, #priv> + + // add registers + %load = rock.in_bounds_load %registers[%c0] : memref<32xf16, #priv>, index -> vector<32xf16> + %load2 = rock.in_bounds_load %registers2[%c0] : memref<32xf16, #priv>, index -> vector<32xf16> + %add = arith.addf %load, %load2 : vector<32xf16> + rock.in_bounds_store %add -> %registers[%c0] : vector<32xf16> -> memref<32xf16, #priv>, index + + // CHECK: rock.alloc() : memref<16384xi8, #gpu.address_space> + // CHECK: rock.threadwise_write_all + // CHECK: rock.lds_barrier + // CHECK: rock.threadwise_read_into + // CHECK: rock.threadwise_write_all + rock.threadwise_write_all {forceUnroll, useIndexDiffs} %registers -> [#rock.transform_map (d0, d1, d2, d3 floordiv 64, (d3 mod 64) floordiv 32, d3 mod 32, d4 floordiv 16, 0, (d4 mod 16) floordiv 4, d4 mod 4)> by [ ["g_block", "m_block", "n_block"] at [0, 1, 2]>, ["wave", "m_tid", "n_tid"] at [3, 4, 5]>, ["i", "j", "vec_group", "vec_item"] at [6, 7, 8, 9]>] bounds = [1, 20, 16, 256, 32] -> [1, 20, 16, 4, 2, 32, 2, 1, 4, 4]>, #rock.transform_map (d0, d1, d2, d3 floordiv 2, d3 mod 2, d4, d5, 0, d6, 0, 0, d8, d9)> by [ ["g_block", "m_block", "n_block"] at [0, 1, 2]>, ["wave_m", "wave_n"] at [3, 4]>, ["m_tid", "n_tid"] at [5, 6]>, ["m_i", "n_i"] at [7, 8]>, ["blk_row", "blk_col"] at [9, 10]>, ["vec_group", "vec_item"] at [11, 12]>] bounds = [1, 20, 16, 4, 2, 32, 2, 1, 4, 4] -> [1, 20, 16, 2, 2, 2, 32, 1, 2, 1, 1, 4, 4]>, #rock.transform_map (d0, ((((d1 + d7) * 2 + d3 + d9) * 4 + d11) * 2 + d5) * 4 + d12, ((d2 * 2 + d8) * 2 + d4 + d10) * 32 + d6)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 20, 16, 2, 2, 2, 32, 1, 2, 1, 1, 4, 4] -> [1, 1280, 2048]>, #rock.transform_map (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 1280, 2048] -> [1, 1280, 2048]>, #rock.transform_map (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] 
at [2]>] bounds = [1, 1280, 2048] -> [1, 1280, 2048]>](%matrix_c) [%16, %23, %25, %threadid] by set : memref<32xf16, #gpu.address_space> -> memref<1x1280x2048xf16> + return +} + +// CHECK-LABEL: func.func @rock_output_swizzle_enabled +func.func @rock_output_swizzle_enabled(%matrix_c: memref<1x1280x2048xf16>) attributes{arch = "gfx90a:sramecc+:xnack-", block_size = 256 : i32, grid_size = 320 : i32, output_swizzle = 1 : i64, kernel} { + %registers = rock.alloc() : memref<32xf16, #priv> + %registers2 = rock.alloc() : memref<32xf16, #priv> + %blockid = rock.workgroup_id : index + %threadid = rock.workitem_id : index + + %c22 = arith.constant 22 : index + %c320 = arith.constant 320 : index + %c20 = arith.constant 20 : index + %c352 = arith.constant 352 : index + %16 = arith.divui %blockid, %c320 : index + %17 = arith.remui %blockid, %c320 : index + %18 = arith.divui %17, %c352 : index + %19 = arith.muli %18, %c22 : index + %20 = arith.subi %c20, %19 : index + %21 = arith.minui %20, %c22 : index + %22 = arith.remui %17, %21 : index + %23 = arith.addi %19, %22 : index + %24 = arith.remui %17, %c352 : index + %25 = arith.divui %24, %21 : index + + %28 = rock.alloc() : memref<16384xi8, #wg> + %29 = rock.alloc() : memref<16384xi8, #wg> + + %c0 = arith.constant 0 : index + %view_29 = memref.view %29[%c0][] : memref<16384xi8, #wg> to memref<8192xf16, #wg> + %view_29_2 = rock.transform %view_29 by ((d1 * 256 + d0) * 8 + d2)> by [ ["flattenBlock"] at [0]>] bounds = [256, 4, 8] -> [8192]> : memref<8192xf16, #gpu.address_space> to memref<256x4x8xf16, #gpu.address_space> + %view_29_3 = rock.transform %view_29_2 by (d0, d1 floordiv 8, d1 mod 8)> by [ ["tid"] at [0]>, ["iter", "numElements"] at [1, 2]>] bounds = [256, 32] -> [256, 4, 8]> : memref<256x4x8xf16, #gpu.address_space> to memref<256x32xf16, #gpu.address_space> + rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%view_29_3) [%threadid] -> %registers : memref<256x32xf16, #wg> -> memref<32xf16, #priv> + + %view_28 = 
memref.view %28[%c0][] : memref<16384xi8, #wg> to memref<8192xf16, #wg> + %view_28_2 = rock.transform %view_28 by ((d1 * 256 + d0) * 8 + d2)> by [ ["flattenBlock"] at [0]>] bounds = [256, 4, 8] -> [8192]> : memref<8192xf16, #gpu.address_space> to memref<256x4x8xf16, #gpu.address_space> + %view_28_3 = rock.transform %view_28_2 by (d0, d1 floordiv 8, d1 mod 8)> by [ ["tid"] at [0]>, ["iter", "numElements"] at [1, 2]>] bounds = [256, 32] -> [256, 4, 8]> : memref<256x4x8xf16, #gpu.address_space> to memref<256x32xf16, #gpu.address_space> + rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%view_28_3) [%threadid] -> %registers2 : memref<256x32xf16, #wg> -> memref<32xf16, #priv> + + // add registers + %load = rock.in_bounds_load %registers[%c0] : memref<32xf16, #priv>, index -> vector<32xf16> + %load2 = rock.in_bounds_load %registers2[%c0] : memref<32xf16, #priv>, index -> vector<32xf16> + %add = arith.addf %load, %load2 : vector<32xf16> + rock.in_bounds_store %add -> %registers[%c0] : vector<32xf16> -> memref<32xf16, #priv>, index + + // CHECK: rock.alloc() : memref<16384xi8, #gpu.address_space> + // CHECK: rock.threadwise_write_all + // CHECK: rock.lds_barrier + // CHECK: rock.threadwise_read_into + // CHECK: rock.threadwise_write_all + rock.threadwise_write_all {forceUnroll, useIndexDiffs} %registers -> [#rock.transform_map (d0, d1, d2, d3 floordiv 64, (d3 mod 64) floordiv 32, d3 mod 32, d4 floordiv 16, 0, (d4 mod 16) floordiv 4, d4 mod 4)> by [ ["g_block", "m_block", "n_block"] at [0, 1, 2]>, ["wave", "m_tid", "n_tid"] at [3, 4, 5]>, ["i", "j", "vec_group", "vec_item"] at [6, 7, 8, 9]>] bounds = [1, 20, 16, 256, 32] -> [1, 20, 16, 4, 2, 32, 2, 1, 4, 4]>, #rock.transform_map (d0, d1, d2, d3 floordiv 2, d3 mod 2, d4, d5, 0, d6, 0, 0, d8, d9)> by [ ["g_block", "m_block", "n_block"] at [0, 1, 2]>, ["wave_m", "wave_n"] at [3, 4]>, ["m_tid", "n_tid"] at [5, 6]>, ["m_i", "n_i"] at [7, 8]>, ["blk_row", "blk_col"] at [9, 10]>, ["vec_group", "vec_item"] at [11, 12]>] 
bounds = [1, 20, 16, 4, 2, 32, 2, 1, 4, 4] -> [1, 20, 16, 2, 2, 2, 32, 1, 2, 1, 1, 4, 4]>, #rock.transform_map (d0, ((((d1 + d7) * 2 + d3 + d9) * 4 + d11) * 2 + d5) * 4 + d12, ((d2 * 2 + d8) * 2 + d4 + d10) * 32 + d6)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 20, 16, 2, 2, 2, 32, 1, 2, 1, 1, 4, 4] -> [1, 1280, 2048]>, #rock.transform_map (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 1280, 2048] -> [1, 1280, 2048]>, #rock.transform_map (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 1280, 2048] -> [1, 1280, 2048]>](%matrix_c) [%16, %23, %25, %threadid] by set : memref<32xf16, #gpu.address_space> -> memref<1x1280x2048xf16> + return +} + +// CHECK-LABEL: func.func @rock_output_swizzle_disabled +func.func @rock_output_swizzle_disabled(%matrix_c: memref<1x1280x2048xf16>) attributes{arch = "gfx90a:sramecc+:xnack-", block_size = 256 : i32, grid_size = 320 : i32, output_swizzle = 0 : i64, kernel} { + %registers = rock.alloc() : memref<32xf16, #priv> + %registers2 = rock.alloc() : memref<32xf16, #priv> + %blockid = rock.workgroup_id : index + %threadid = rock.workitem_id : index + + %c22 = arith.constant 22 : index + %c320 = arith.constant 320 : index + %c20 = arith.constant 20 : index + %c352 = arith.constant 352 : index + %16 = arith.divui %blockid, %c320 : index + %17 = arith.remui %blockid, %c320 : index + %18 = arith.divui %17, %c352 : index + %19 = arith.muli %18, %c22 : index + %20 = arith.subi %c20, %19 : index + %21 = arith.minui %20, %c22 : index + %22 = arith.remui %17, %21 : index + %23 = arith.addi %19, %22 : index + %24 = arith.remui %17, %c352 : index + %25 = arith.divui %24, %21 : index + + %28 = rock.alloc() : memref<16384xi8, #wg> + %29 = rock.alloc() : memref<16384xi8, #wg> + // CHECK: rock.alloc() : memref<16384xi8, #gpu.address_space> + // CHECK: rock.alloc() : memref<16384xi8, #gpu.address_space> + + %c0 = arith.constant 0 : index + 
%view_29 = memref.view %29[%c0][] : memref<16384xi8, #wg> to memref<8192xf16, #wg> + %view_29_2 = rock.transform %view_29 by ((d1 * 256 + d0) * 8 + d2)> by [ ["flattenBlock"] at [0]>] bounds = [256, 4, 8] -> [8192]> : memref<8192xf16, #gpu.address_space> to memref<256x4x8xf16, #gpu.address_space> + %view_29_3 = rock.transform %view_29_2 by (d0, d1 floordiv 8, d1 mod 8)> by [ ["tid"] at [0]>, ["iter", "numElements"] at [1, 2]>] bounds = [256, 32] -> [256, 4, 8]> : memref<256x4x8xf16, #gpu.address_space> to memref<256x32xf16, #gpu.address_space> + rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%view_29_3) [%threadid] -> %registers : memref<256x32xf16, #wg> -> memref<32xf16, #priv> + + %view_28 = memref.view %28[%c0][] : memref<16384xi8, #wg> to memref<8192xf16, #wg> + %view_28_2 = rock.transform %view_28 by ((d1 * 256 + d0) * 8 + d2)> by [ ["flattenBlock"] at [0]>] bounds = [256, 4, 8] -> [8192]> : memref<8192xf16, #gpu.address_space> to memref<256x4x8xf16, #gpu.address_space> + %view_28_3 = rock.transform %view_28_2 by (d0, d1 floordiv 8, d1 mod 8)> by [ ["tid"] at [0]>, ["iter", "numElements"] at [1, 2]>] bounds = [256, 32] -> [256, 4, 8]> : memref<256x4x8xf16, #gpu.address_space> to memref<256x32xf16, #gpu.address_space> + rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%view_28_3) [%threadid] -> %registers2 : memref<256x32xf16, #wg> -> memref<32xf16, #priv> + + // add registers + %load = rock.in_bounds_load %registers[%c0] : memref<32xf16, #priv>, index -> vector<32xf16> + %load2 = rock.in_bounds_load %registers2[%c0] : memref<32xf16, #priv>, index -> vector<32xf16> + %add = arith.addf %load, %load2 : vector<32xf16> + rock.in_bounds_store %add -> %registers[%c0] : vector<32xf16> -> memref<32xf16, #priv>, index + + // CHECK-NOT: rock.alloc() : memref<16384xi8, #gpu.address_space> + rock.threadwise_write_all {forceUnroll, useIndexDiffs} %registers -> [#rock.transform_map (d0, d1, d2, d3 floordiv 64, (d3 mod 64) floordiv 32, d3 mod 32, d4 
floordiv 16, 0, (d4 mod 16) floordiv 4, d4 mod 4)> by [ ["g_block", "m_block", "n_block"] at [0, 1, 2]>, ["wave", "m_tid", "n_tid"] at [3, 4, 5]>, ["i", "j", "vec_group", "vec_item"] at [6, 7, 8, 9]>] bounds = [1, 20, 16, 256, 32] -> [1, 20, 16, 4, 2, 32, 2, 1, 4, 4]>, #rock.transform_map (d0, d1, d2, d3 floordiv 2, d3 mod 2, d4, d5, 0, d6, 0, 0, d8, d9)> by [ ["g_block", "m_block", "n_block"] at [0, 1, 2]>, ["wave_m", "wave_n"] at [3, 4]>, ["m_tid", "n_tid"] at [5, 6]>, ["m_i", "n_i"] at [7, 8]>, ["blk_row", "blk_col"] at [9, 10]>, ["vec_group", "vec_item"] at [11, 12]>] bounds = [1, 20, 16, 4, 2, 32, 2, 1, 4, 4] -> [1, 20, 16, 2, 2, 2, 32, 1, 2, 1, 1, 4, 4]>, #rock.transform_map (d0, ((((d1 + d7) * 2 + d3 + d9) * 4 + d11) * 2 + d5) * 4 + d12, ((d2 * 2 + d8) * 2 + d4 + d10) * 32 + d6)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 20, 16, 2, 2, 2, 32, 1, 2, 1, 1, 4, 4] -> [1, 1280, 2048]>, #rock.transform_map (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 1280, 2048] -> [1, 1280, 2048]>, #rock.transform_map (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 1280, 2048] -> [1, 1280, 2048]>](%matrix_c) [%16, %23, %25, %threadid] by set : memref<32xf16, #gpu.address_space> -> memref<1x1280x2048xf16> + return +} diff --git a/mlir/test/Dialect/Rock/lowering_to_threadwise_accel.mlir b/mlir/test/Dialect/Rock/lowering_to_threadwise_accel.mlir index 7faa2da4250a..aea453692d32 100644 --- a/mlir/test/Dialect/Rock/lowering_to_threadwise_accel.mlir +++ b/mlir/test/Dialect/Rock/lowering_to_threadwise_accel.mlir @@ -48,7 +48,7 @@ func.func @rock_gemm_schedulev2(%arg0: memref<1x128x128xf16>, %arg1: memref<1x12 // CHECK-SAME: scheduleVersion = 2 // CHECK: name = "MMA" // CHECK: pipeline = #rock.pipeline<1> - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 3600 : i32, params = #rock.mfma_gemm_params} : memref<1x128x128xf16>, 
memref<1x128x115200xf16>, memref<1x128x115200xf32> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 3600 : i32, params = #rock.mfma_gemm_params} : memref<1x128x128xf16>, memref<1x128x115200xf16>, memref<1x128x115200xf32> return } @@ -94,7 +94,7 @@ func.func @rock_gemm_schedulev1(%arg0: memref<1x128x128xf16>, %arg1: memref<1x12 // CHECK-SAME: scheduleVersion = 1 // CHECK: name = "MMA" // CHECK: pipeline = #rock.pipeline<2> - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 3600 : i32, params = #rock.mfma_gemm_params} : memref<1x128x128xf16>, memref<1x128x115200xf16>, memref<1x128x115200xf32> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 3600 : i32, params = #rock.mfma_gemm_params} : memref<1x128x128xf16>, memref<1x128x115200xf16>, memref<1x128x115200xf32> return } @@ -140,7 +140,7 @@ func.func @rock_conv_gkc01_n01gc_ngk01_0_schedulev2(%arg0: memref<1x32x32xf16>, // CHECK: rock.threadwise_gemm_accel %[[outReg]] += %[[AReg]] * %[[BReg]] // CHECK-SAME: scheduleVersion = 2 // CHECK: name = "MMA" - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 400 : i32, params = #rock.mfma_gemm_params} : memref<1x32x32xf16>, memref<1x32x25600xf16>, memref<1x32x25600xf32> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 400 : i32, params = #rock.mfma_gemm_params} : memref<1x32x32xf16>, memref<1x32x25600xf16>, memref<1x32x25600xf32> return } @@ -181,7 +181,7 @@ func.func @rock_conv_gkc01_n01gc_ngk01_0_schedulev1(%arg0: memref<1x32x32xf16>, // CHECK: rock.threadwise_gemm_accel %[[outReg]] += %[[AReg]] * %[[BReg]] // CHECK-SAME: scheduleVersion = 1 // CHECK: name = "MMA" - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 400 : i32, params = #rock.mfma_gemm_params} : 
memref<1x32x32xf16>, memref<1x32x25600xf16>, memref<1x32x25600xf32> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2) storeMethod( set) {blockSize = 256 : i32, gridSize = 400 : i32, params = #rock.mfma_gemm_params} : memref<1x32x32xf16>, memref<1x32x25600xf16>, memref<1x32x25600xf32> return } @@ -230,7 +230,7 @@ func.func @rock_scaled_gemm_transA(%arg0: memref<1x128x64xf4E2M1FN>, %arg1: memr // CHECK: affine.for // CHECK: rock.threadwise_gemm_accel {{.*}} scaled by {{.*}} * {{.*}} scaled by {{.*}} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> * memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> // CHECK: name = "MMA" - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) {blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) {blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> return } @@ -266,7 +266,7 @@ func.func @rock_scaled_gemm_transB(%arg0: memref<1x128x64xf4E2M1FN>, %arg1: memr // CHECK: affine.for // CHECK: rock.threadwise_gemm_accel {{.*}} scaled by {{.*}} * {{.*}} scaled by {{.*}} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> * memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> // CHECK: name = "MMA" - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) 
{blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) {blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> return } @@ -302,7 +302,7 @@ func.func @rock_scaled_gemm_transAB(%arg0: memref<1x128x64xf4E2M1FN>, %arg1: mem // CHECK: affine.for // CHECK: rock.threadwise_gemm_accel {{.*}} scaled by {{.*}} * {{.*}} scaled by {{.*}} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> * memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> // CHECK: name = "MMA" - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) {blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) {blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> return } @@ -338,7 +338,7 @@ func.func @rock_scaled_gemm_no_transpose(%arg0: memref<1x128x64xf4E2M1FN>, %arg1 // CHECK: affine.for // CHECK: rock.threadwise_gemm_accel {{.*}} scaled by {{.*}} * {{.*}} scaled by {{.*}} : memref<1x1xvector<16xf32>, #gpu.address_space> += memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> * 
memref<1x2xvector<32xf4E2M1FN>, #gpu.address_space> scaled by memref<1x2xvector<32xf8E8M0FNU>, #gpu.address_space> // CHECK: name = "MMA" - rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) {blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> + rock.gridwise_gemm_accel(%arg0, %arg1, %arg2, %arg3, %arg4) storeMethod( set) {blockSize = 256 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x128x64xf4E2M1FN>, memref<1x128x64xf4E2M1FN>, memref<1x64x64xf32>, memref<1x128x64xf8E8M0FNU>, memref<1x128x64xf8E8M0FNU> return } @@ -354,8 +354,8 @@ func.func @gridwise_attn_schedulev2(%arg0: memref<1x384x64xf32>, %arg1: memref<1 rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg3) preSoftmaxOps = {} { blockSize = 64 : i32, gridSize = 24 : i32, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, diff --git a/mlir/test/Dialect/Rock/lowering_top_level.mlir b/mlir/test/Dialect/Rock/lowering_top_level.mlir index 6462f16cb6c7..34216fb426ab 100644 --- a/mlir/test/Dialect/Rock/lowering_top_level.mlir +++ b/mlir/test/Dialect/Rock/lowering_top_level.mlir @@ -31,8 +31,8 @@ #general_gemm_params0 = #rock.general_gemm_params #general_gemm_params1 = #rock.general_gemm_params -#xdlops_gemm_params0 = #rock.mfma_gemm_params -#xdlops_gemm_params1 = #rock.mfma_gemm_params +#xdlops_gemm_params0 = #rock.mfma_gemm_params +#xdlops_gemm_params1 = #rock.mfma_gemm_params func.func @rock_conv(%filter : memref<1x128x8x3x3xf32>, %input : memref<128x1x8x32x32xf32>, %output : memref<128x1x128x30x30xf32>) attributes {arch = "amdgcn-amd-amdhsa:gfx906"} { rock.conv(%filter, %input, %output) features = none { diff --git 
a/mlir/test/Dialect/Rock/lowering_wmma_gemm.mlir b/mlir/test/Dialect/Rock/lowering_wmma_gemm.mlir index 248a57587ece..84b447cb68ee 100644 --- a/mlir/test/Dialect/Rock/lowering_wmma_gemm.mlir +++ b/mlir/test/Dialect/Rock/lowering_wmma_gemm.mlir @@ -25,7 +25,7 @@ func.func @rock_accel_gemm_wmma(%matrixA : memref<1x4xvector<16xf16>, 5>, kpack = 16, splitKFactor = 3, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<8xf32>, 5> += memref<1x4xvector<16xf16>, 5> * memref<1x4xvector<16xf16>, 5> return @@ -55,7 +55,7 @@ func.func @rock_accel_gemm_wmma_gfx12(%matrixA : memref<1x4xvector<8xf16>, 5>, kpack = 8, splitKFactor = 3, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<8xf32>, 5> += memref<1x4xvector<8xf16>, 5> * memref<1x4xvector<8xf16>, 5> return @@ -87,7 +87,7 @@ func.func @rock_accel_gemm_wmma_repeats(%matrixA : memref<1x4xvector<16xf16>, 5> kpack = 16, splitKFactor = 3, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x2xvector<8xf32>, 5> += memref<1x4xvector<16xf16>, 5> * memref<1x4xvector<16xf16>, 5> return @@ -119,7 +119,7 @@ func.func @rock_accel_gemm_wmma_repeats_int8(%matrixA : memref<1x4xvector<16xi8> kpack = 16, splitKFactor = 3, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x2xvector<8xi32>, 5> += memref<1x4xvector<16xi8>, 5> * memref<1x4xvector<16xi8>, 5> return @@ -151,7 +151,7 @@ func.func @rock_accel_gemm_wmma_partial_repeats_int8(%matrixA : memref<1x2xvecto kpack = 16, splitKFactor = 3, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x2xvector<8xi32>, 5> += memref<1x2xvector<16xi8>, 5> * memref<1x2xvector<16xi8>, 5> return diff 
--git a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir index e0953335a0d4..02903feb7977 100644 --- a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir +++ b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir @@ -28,7 +28,7 @@ func.func @rock_accel_gemm_reduction_nokpack(%matrixA : memref<1x2xf32, 5>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x2xvector<16xf32>, 5> += memref<1x2xf32, 5> * memref<1x2xf32, 5> return @@ -59,7 +59,7 @@ func.func @rock_accel_gemm_reduction_kpack_f32(%matrixA : memref<1x2xf32, 5>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x2xvector<16xf32>, 5> += memref<1x2xf32, 5> * memref<1x2xf32, 5> return @@ -90,7 +90,7 @@ func.func @rock_accel_gemm_reduction_kpack_i8(%matrixA : memref<1x4xvector<4xi8> mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xi32>, 5> += memref<1x4xvector<4xi8>, 5> * memref<1x4xvector<4xi8>, 5> return @@ -120,7 +120,7 @@ func.func @accel_gemm_gfx90a_i8(%matrixA : memref<1x4xvector<4xi8>, 5>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xi32>, 5> += memref<1x4xvector<4xi8>, 5> * memref<1x4xvector<4xi8>, 5> return @@ -147,7 +147,7 @@ func.func @accel_gemm_gfx942_i8(%matrixA : memref<1x4xvector<8xi8>, 5>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xi32>, 5> += memref<1x4xvector<8xi8>, 5> * memref<1x4xvector<8xi8>, 5> return @@ -174,7 +174,7 @@ func.func 
@accel_gemm_gfx908_bf16(%matrixA : memref<1x4xvector<2xbf16>, 5>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xf32>, 5> += memref<1x4xvector<2xbf16>, 5> * memref<1x4xvector<2xbf16>, 5> return @@ -201,7 +201,7 @@ func.func @accel_gemm_gfx90a_bf16(%matrixA : memref<1x4xvector<4xbf16>, 5>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xf32>, 5> += memref<1x4xvector<4xbf16>, 5> * memref<1x4xvector<4xbf16>, 5> return @@ -230,7 +230,7 @@ func.func @accel_gemm_fp8_bf8(%matrixA : memref<1x4xvector<8xf8E4M3FNUZ>, #gpu.a mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x2xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xf8E4M3FNUZ>, #gpu.address_space> * memref<1x4xvector<8xf8E5M2FNUZ>, #gpu.address_space> return @@ -259,7 +259,7 @@ func.func @accel_gemm_fp8_bf8_ocp(%matrixA : memref<1x4xvector<8xf8E4M3FN>, #gpu mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x2xvector<16xf32>, #gpu.address_space> += memref<1x4xvector<8xf8E4M3FN>, #gpu.address_space> * memref<1x4xvector<8xf8E5M2>, #gpu.address_space> return @@ -286,7 +286,7 @@ func.func @accel_gemm_gfx950_f16_16x16x32(%matrixA : memref<1x2xvector<8xf16>, 5 mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<4xf32>, 5> += memref<1x2xvector<8xf16>, 5> * memref<1x2xvector<8xf16>, 5> return @@ -313,7 +313,7 @@ func.func @accel_gemm_gfx950_bf16_16x16x32(%matrixA : memref<1x2xvector<8xbf16>, mnPerXdl = 16, splitKFactor = 
1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<4xf32>, 5> += memref<1x2xvector<8xbf16>, 5> * memref<1x2xvector<8xbf16>, 5> return @@ -340,7 +340,7 @@ func.func @accel_gemm_gfx950_f16_32x32x16(%matrixA : memref<1x2xvector<8xf16>, 5 mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<8xf16>, 5> * memref<1x2xvector<8xf16>, 5> return @@ -367,7 +367,7 @@ func.func @accel_gemm_gfx950_bf16_32x32x16(%matrixA : memref<1x2xvector<8xbf16>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<8xbf16>, 5> * memref<1x2xvector<8xbf16>, 5> return @@ -394,7 +394,7 @@ func.func @accel_gemm_gfx950_i8_32x32x32(%matrixA : memref<1x4xvector<16xi8>, 5> mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xi32>, 5> += memref<1x4xvector<16xi8>, 5> * memref<1x4xvector<16xi8>, 5> return @@ -421,7 +421,7 @@ func.func @accel_gemm_gfx950_i8_16x16x64(%matrixA : memref<1x2xvector<16xi8>, 5> mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<4xi32>, 5> += memref<1x2xvector<16xi8>, 5> * memref<1x2xvector<16xi8>, 5> return @@ -449,7 +449,7 @@ func.func @accel_gemm_gfx950_f32_16x16x128_fp4(%matrixA : memref<1x1xvector<32xf mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<4xf32>, 5> += memref<1x1xvector<32xf4E2M1FN>, 5> * 
memref<1x1xvector<32xf4E2M1FN>, 5> return @@ -477,7 +477,7 @@ func.func @accel_gemm_gfx950_f32_32x32x64_fp4(%matrixA : memref<1x1xvector<32xf4 mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xf32>, 5> += memref<1x1xvector<32xf4E2M1FN>, 5> * memref<1x1xvector<32xf4E2M1FN>, 5> return @@ -505,7 +505,7 @@ func.func @accel_gemm_gfx950_f32_64x64x512_fp4_1(%matrixA : memref<1x16xvector<3 mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x16xvector<4xf32>, 5> += memref<1x16xvector<32xf4E2M1FN>, 5> * memref<1x16xvector<32xf4E2M1FN>, 5> return @@ -535,7 +535,7 @@ func.func @accel_gemm_gfx950_f32_64x64x512_fp4_2(%matrixA : memref<1x8xvector<32 mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xf32>, 5> += memref<1x8xvector<32xf4E2M1FN>, 5> * memref<1x8xvector<32xf4E2M1FN>, 5> return @@ -572,7 +572,7 @@ func.func @accel_gemm_gfx950_f32_16x16x128_fp4_scaled(%matrixA : memref<1x1xvect mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<4xf32>, 5> += memref<1x1xvector<32xf4E2M1FN>, 5> scaled by memref<1x1xvector<32xf8E8M0FNU>, 5> * memref<1x1xvector<32xf4E2M1FN>, 5> scaled by memref<1x1xvector<32xf8E8M0FNU>, 5> return @@ -608,7 +608,7 @@ func.func @accel_gemm_gfx950_f32_32x32x64_fp4_scaled(%matrixA : memref<1x1xvecto mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<16xf32>, 5> += memref<1x1xvector<32xf4E2M1FN>, 5> scaled by memref<1x1xvector<32xf8E8M0FNU>, 5> * 
memref<1x1xvector<32xf4E2M1FN>, 5> scaled by memref<1x1xvector<32xf8E8M0FNU>, 5> return @@ -644,7 +644,7 @@ func.func @accel_gemm_gfx950_f32_16x16x512_fp4_scaled_multi(%matrixA : memref<1x mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> * memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> return diff --git a/mlir/test/Dialect/Rock/ops.mlir b/mlir/test/Dialect/Rock/ops.mlir index 753910659116..ae96d2b79b11 100644 --- a/mlir/test/Dialect/Rock/ops.mlir +++ b/mlir/test/Dialect/Rock/ops.mlir @@ -230,7 +230,7 @@ func.func @rock_gridwise_gemm_accel(%A : memref<2x1024x1024xf32>, %B : memref<2x mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x1024x1024xf32>, memref<2x1024x2048xf32>, memref<2x1024x2048xf32> return @@ -253,7 +253,7 @@ func.func @rock_gridwise_scaled_gemm_accel(%A : memref<2x1024x1024xf4E2M1FN>, %B mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x1024x1024xf4E2M1FN>, memref<2x1024x2048xf4E2M1FN>, memref<2x1024x2048xf32>, memref<2x1024x1024xf8E8M0FNU>, memref<2x1024x2048xf8E8M0FNU> return @@ -286,7 +286,7 @@ func.func @rock_blockwise_gemm_accel_scaled(%matrixA : memref<256xvector<2xf4E2M mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<4xvector<16xf32>, #gpu.address_space> += memref<4xf4E2M1FN, #gpu.address_space> from memref<256xvector<2xf4E2M1FN>, #gpu.address_space> scaled by memref<4xf8E8M0FNU, #gpu.address_space> from memref<256xvector<2xf8E8M0FNU>, #gpu.address_space> * 
memref<4xf4E2M1FN, #gpu.address_space> from memref<256xvector<2xf4E2M1FN>, #gpu.address_space> scaled by memref<4xf8E8M0FNU, #gpu.address_space> from memref<256xvector<2xf8E8M0FNU>, #gpu.address_space> return @@ -313,7 +313,7 @@ func.func @rock_threadwise_gemm_accel_scaled(%matrixA : memref<1x4xvector<4xf4E2 kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<32xf32>, 5> += memref<1x4xvector<4xf4E2M1FN>, 5> scaled by memref<1x4xvector<4xf8E8M0FNU>, 5> * memref<1x4xvector<4xf4E2M1FN>, 5> scaled by memref<1x4xvector<4xf8E8M0FNU>, 5> return @@ -368,8 +368,8 @@ func.func @gridwise_attn_atomic_add(%arg0: memref<1x384x64xf32>, %arg1: memref<1 rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg3) preSoftmaxOps = {} { blockSize = 64 : i32, gridSize = 24 : i32, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, diff --git a/mlir/test/Dialect/Rock/ops_2.mlir b/mlir/test/Dialect/Rock/ops_2.mlir index 833eacf66372..e97e23392781 100644 --- a/mlir/test/Dialect/Rock/ops_2.mlir +++ b/mlir/test/Dialect/Rock/ops_2.mlir @@ -149,7 +149,7 @@ func.func @rock_accel_gemm_one_result(%matrixA : memref<1x16xf32, 5>, kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<32xf32>, 5> += memref<1x16xf32, 5> * memref<1x16xf32, 5> return @@ -176,7 +176,7 @@ func.func @rock_accel_gemm_one_result_f4(%matrixA : memref<1x1xvector<32xf4E2M1F kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<4xf32>, 5> += memref<1x1xvector<32xf4E2M1FN>, 5> * memref<1x1xvector<32xf4E2M1FN>, 5> return @@ -208,7 +208,7 @@ 
func.func @rock_accel_gemm_two_results(%matrixA : memref<1x16xf32, 5>, kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2x2xvector<32xf32>, 5> += memref<1x16xf32, 5> * memref<1x16xf32, 5> return @@ -237,7 +237,7 @@ func.func @rock_blockwise_gemm_accel_one_result(%matrixA : memref<12288xf32, 3>, kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<1xvector<32xf32>, 5> += memref<32xf32, 5> from memref<12288xf32, 3> * memref<16xf32, 5> from memref<12288xf32, 3> return @@ -266,7 +266,7 @@ func.func @rock_blockwise_gemm_accel_two_results(%matrixA : memref<12288xf32, 3> kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> } : memref<2xvector<32xf32>, 5> += memref<32xf32, 5> from memref<12288xf32, 3> * memref<16xf32, 5> from memref<12288xf32, 3> return diff --git a/mlir/test/Dialect/Rock/ops_blockwise_f16.mlir b/mlir/test/Dialect/Rock/ops_blockwise_f16.mlir index 47e990defeda..bec275799947 100644 --- a/mlir/test/Dialect/Rock/ops_blockwise_f16.mlir +++ b/mlir/test/Dialect/Rock/ops_blockwise_f16.mlir @@ -43,7 +43,9 @@ func.func @rock_xdlops_gemm_accel_one_result_f16(%matrixA : memref<1x4xvector<4x kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<32xf32>, 5> += memref<1x4xvector<4xf16>, 5> * memref<1x4xvector<4xf16>, 5> return @@ -70,7 +72,9 @@ func.func @rock_xdlops_gemm_accel_two_results_f16(%matrixA : memref<1x4xvector<4 kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<1x1xvector<32xf32>, 5> += memref<1x4xvector<4xf16>, 5> * 
memref<1x4xvector<4xf16>, 5> return @@ -100,7 +104,9 @@ func.func @rock_blockwise_gemm_accel_one_result_f16(%matrixA : memref<8192xf16, kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<1xvector<32xf32>, 5> += memref<4xvector<4xf16>, 5> from memref<8192xf16, 3> * memref<4xvector<4xf16>, 5> from memref<4096xf16, 3> return @@ -129,7 +135,9 @@ func.func @rock_blockwise_gemm_accel_two_results_f16(%matrixA : memref<8192xf16, kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, + wavesPerEU = 0, + gridGroupSize = 0, forceUnroll = true> } : memref<2xvector<32xf32>, 5> += memref<4xvector<4xf16>, 5> from memref<8192xf16, 3> * memref<4xvector<4xf16>, 5> from memref<4096xf16, 3> return diff --git a/mlir/test/Dialect/Rock/ops_error.mlir b/mlir/test/Dialect/Rock/ops_error.mlir index 15a86fee4590..d19bb75f0dc2 100644 --- a/mlir/test/Dialect/Rock/ops_error.mlir +++ b/mlir/test/Dialect/Rock/ops_error.mlir @@ -7,8 +7,8 @@ func.func @gridwise_attn_atomic_add_fail(%arg0: memref<1x384x64xf32>, %arg1: mem rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg3) preSoftmaxOps = {} { blockSize = 64 : i32, gridSize = 24 : i32, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, storeMethod = #rock, splitKV = 1 : i32, @@ -304,7 +304,7 @@ func.func @gemm_scaled_inputs_not_float4e2m1(%a: memref<2x64x128xf16>, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> func.func @gridwise_gemm_accel_scale_presence_a_only(%A: memref<1x4x8xf4E2M1FN>, %B: memref<1x4x16xf4E2M1FN>, %C: memref<1x8x16xf32>, %scaleA: memref<1x4x8xf8E8M0FNU>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { @@ -429,7 +429,7 @@ func.func 
@rock_gridwise_gemm_accel_invalid_out_dtype(%A: memref<2x1024x1024xf4E mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> func.func @blockwise_gemm_accel_scale_buffer_presence_a_only( @@ -949,7 +949,7 @@ func.func @blockwise_gemm_accel_invalid_arch( kpack = 1, splitKFactor = 1, scheduleVersion = 1, - outputSwizzle = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> // Error case: Only scaleA provided diff --git a/mlir/test/Dialect/Rock/regularize.mlir b/mlir/test/Dialect/Rock/regularize.mlir index 07533b8ac50d..ce1019b23173 100644 --- a/mlir/test/Dialect/Rock/regularize.mlir +++ b/mlir/test/Dialect/Rock/regularize.mlir @@ -9,7 +9,7 @@ func.func private @bert_part_11__part_0(%arg0: memref<1x12x12x32xf32> {mhal.read %4 = rock.transform %3 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmK"] at [1]>, ["gemmM"] at [2]>] bounds = [12, 32, 16] -> [12, 32, 12]> : memref<12x32x12xf32> to memref<12x32x16xf32> %5 = rock.transform %0 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmK"] at [1]>, ["gemmN"] at [2]>] bounds = [12, 32, 16] -> [12, 32, 12]> : memref<12x32x12xf32> to memref<12x32x16xf32> %6 = rock.transform %2 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [12, 16, 16] -> [12, 12, 12]> : memref<12x12x12xf32> to memref<12x16x16xf32> - rock.gridwise_gemm_accel(%4, %5, %6) storeMethod( set) {blockSize = 64 : i32, gridSize = 12 : i32, params = #rock.mfma_gemm_params} : memref<12x32x16xf32>, memref<12x32x16xf32>, memref<12x16x16xf32> + rock.gridwise_gemm_accel(%4, %5, %6) storeMethod( set) {blockSize = 64 : i32, gridSize = 12 : i32, params = #rock.mfma_gemm_params} : memref<12x32x16xf32>, memref<12x32x16xf32>, memref<12x16x16xf32> %7 = rock.transform %2 by (d0 * 12 + d1, d2, d3)> by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>] bounds = [1, 12, 12, 12] -> [12, 12, 12]> : memref<12x12x12xf32> to 
memref<1x12x12x12xf32> %8 = memref.collapse_shape %7 [[0, 1], [2], [3]] : memref<1x12x12x12xf32> into memref<12x12x12xf32> %9 = memref.collapse_shape %arg2 [] : memref<1x1x1x1xf32> into memref diff --git a/mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir b/mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir index 6f139693a76c..8506119fab41 100644 --- a/mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir +++ b/mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir @@ -34,7 +34,7 @@ func.func @mlir_convolution_multi_reduce(%arg0: memref<320xf32>, %arg1: memref<3 // CHECK: %[[GEMM_OUT_C_TR3:.+]] = rock.transform %[[GEMM_OUT_C_TR2]] by (d0, d1 floordiv 64, (d1 mod 64) floordiv 2, d1 mod 2, d2 floordiv 256, (d2 mod 256) floordiv 128, d2 mod 128)> by [ ["G"] at [0]>, ["m_rh", "m_nr", "m_rl"] at [1, 2, 3]>, ["n_rh", "n_nr", "n_rl"] at [4, 5, 6]>] bounds = [1, 320, 8192] -> [1, 5, 32, 2, 32, 2, 128]> : memref<1x5x32x2x32x2x128xf32> to memref<1x320x8192xf32> %13 = rock.transform %8 by (d2 floordiv 4096, d0, d1, (d2 mod 4096) floordiv 64, d2 mod 64)> by [ ["go"] at [1]>, ["ko"] at [2]>, ["no", "0o", "1o"] at [0, 3, 4]>] bounds = [1, 320, 8192] -> [2, 1, 320, 64, 64]> : memref<2x1x320x64x64xf32> to memref<1x320x8192xf32> // CHECK: rock.gridwise_gemm_accel(%[[GEMM_IN_A_TR3]], %[[GEMM_IN_B_TR3]], %[[GEMM_OUT_C_TR3]]) - rock.gridwise_gemm_accel(%9, %12, %13) storeMethod( set) {blockSize = 256 : i32, gridSize = 320 : i32, params = #rock.mfma_gemm_params} : memref<1x36x320xf32>, memref<1x36x8192xf32>, memref<1x320x8192xf32> + rock.gridwise_gemm_accel(%9, %12, %13) storeMethod( set) {blockSize = 256 : i32, gridSize = 320 : i32, params = #rock.mfma_gemm_params} : memref<1x36x320xf32>, memref<1x36x8192xf32>, memref<1x320x8192xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x32x10x64x64xf32> %alloc_1 = memref.alloc() : memref<2621440xf32> %14 = rock.transform %alloc_1 by ((((d0 * 32 + d1) * 10 + d2) * 64 + d3) * 64 + d4)> by [ 
["dim0"] at [0]>] bounds = [2, 32, 10, 64, 64] -> [2621440]> : memref<2621440xf32> to memref<2x32x10x64x64xf32> diff --git a/mlir/test/Dialect/Rock/test-fusion-and-pipeline.mlir b/mlir/test/Dialect/Rock/test-fusion-and-pipeline.mlir index 7918905dd241..2645ce889d28 100644 --- a/mlir/test/Dialect/Rock/test-fusion-and-pipeline.mlir +++ b/mlir/test/Dialect/Rock/test-fusion-and-pipeline.mlir @@ -85,7 +85,7 @@ module { } {name = "LDSWrite"} rock.stage { %21 = rock.workitem_id : index - rock.threadwise_gemm_accel %20 += %18 * %19 at[%21, %21, %21] {params = #rock.mfma_gemm_params} : memref<1x1xvector<4xf32>, #gpu.address_space> += memref<1x2xvector<4xf16>, #gpu.address_space> * memref<1x2xvector<4xf16>, #gpu.address_space> + rock.threadwise_gemm_accel %20 += %18 * %19 at[%21, %21, %21] {params = #rock.mfma_gemm_params} : memref<1x1xvector<4xf32>, #gpu.address_space> += memref<1x2xvector<4xf16>, #gpu.address_space> * memref<1x2xvector<4xf16>, #gpu.address_space> rock.yield } {name = "MMA"} } {pipeline = #rock.pipeline<2>} diff --git a/mlir/test/Dialect/Rock/test_multi_buffer.mlir b/mlir/test/Dialect/Rock/test_multi_buffer.mlir index 2ac656043576..c1d749c2819f 100644 --- a/mlir/test/Dialect/Rock/test_multi_buffer.mlir +++ b/mlir/test/Dialect/Rock/test_multi_buffer.mlir @@ -22,7 +22,7 @@ #map17 = affine_map<(d0, d1) -> (d0 floordiv 8, d0 mod 8, 0, d1 floordiv 8, d1 mod 8)> #map18 = affine_map<(d0, d1, d2, d3, d4) -> ((d0 + d2) * 8 + d4, d3 * 32 + d1)> #map19 = affine_map<(d0, d1) -> (d0 floordiv 32, d0 mod 32, 0, d1 floordiv 8, d1 mod 8)> -#xldops_gemm_params = #rock.mfma_gemm_params +#xldops_gemm_params = #rock.mfma_gemm_params #transform_map = #rock.transform_map<#map by [ ["gemmG"] at [0]>, ["gemmK", "gemmM"] at [2, 1]>] bounds = [1, 1024, 384] -> [1, 384, 1024]> #transform_map1 = #rock.transform_map<#map1 by [ ["g"] at [0]>, ["k"] at [1]>, ["m"] at [2]>, [] at []>] bounds = [16, 1, 6, 16, 32, 8, 2, 8] -> [1, 1024, 384]> #transform_map2 = #rock.transform_map<#map2 by [ 
["k_loop", "g_block", "m_block", "n_block"] at [0, 1, 2, 3]>, ["m_thread", "k_thread"] at [4, 5]>, ["m_iter", "k_iter"] at [6, 7]>] bounds = [16, 1, 6, 16, 256, 16] -> [16, 1, 6, 16, 32, 8, 2, 8]> diff --git a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir index d0b91192f5d9..510f671ed3cb 100644 --- a/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir @@ -158,8 +158,8 @@ func.func @gridwise_attn_simple(%arg0: memref<1x384x64xf32>, %arg1: memref<1x64x rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg3) preSoftmaxOps = {} { blockSize = 64 : i32, gridSize = 24 : i32, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, @@ -199,8 +199,8 @@ func.func @gridwise_attn_schedulev2(%arg0: memref<1x384x64xf32>, %arg1: memref<1 rock.gridwise_attention_accel(%0, %arg1, %arg2, %arg3) preSoftmaxOps = {} { blockSize = 64 : i32, gridSize = 24 : i32, - params0 = #rock.mfma_gemm_params, - params1 = #rock.mfma_gemm_params, + params0 = #rock.mfma_gemm_params, + params1 = #rock.mfma_gemm_params, firstGemmIndices = array, splitKV = 1 : i32, storeMethod = #rock, diff --git a/mlir/test/Dialect/Rock/toblockwise_gemm_accel_lowering.mlir b/mlir/test/Dialect/Rock/toblockwise_gemm_accel_lowering.mlir index 9846c7d4bf0f..0fe170601667 100644 --- a/mlir/test/Dialect/Rock/toblockwise_gemm_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/toblockwise_gemm_accel_lowering.mlir @@ -1,7 +1,7 @@ // RUN: rocmlir-opt -split-input-file -rock-gridwise-gemm-to-blockwise -canonicalize -verify-diagnostics %s | FileCheck %s -#xdlops_gemm_params1 = #rock.mfma_gemm_params -#xdlops_gemm_params2 = #rock.mfma_gemm_params +#xdlops_gemm_params1 = #rock.mfma_gemm_params +#xdlops_gemm_params2 = 
#rock.mfma_gemm_params // CHECK-LABEL: @fp8_bf8_xdlops func.func @fp8_bf8_xdlops(%arg0: memref<1x128x128xf8E4M3FNUZ>, %arg1: memref<1x128x115200xf8E5M2FNUZ>, %arg2: memref<1x128x115200xf32>) attributes {block_size = 256 : i32, grid_size = 900 : i32, arch = "amdgcn-amd-amdhsa:gfx942", numCU = 228 : i32} { diff --git a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml index 420c728b8f6c..6a480db242cf 100644 --- a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml +++ b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml @@ -29,7 +29,7 @@ prefix = "--transO=" [[axis]] name = "perf_config" -values = ["attn:v3:64,64,192,32,32,96,16,8,1,1,2,1", "attn:v3:64,64,96,8,32,48,16,8,1,1,2,1"] +values = ["attn:v3:64,64,192,32,32,96,16,8,1,1,2,0,1", "attn:v3:64,64,96,8,32,48,16,8,1,1,2,0,1"] prefix = "--perf_config " ## attention variant diff --git a/mlir/test/e2e/GemmVariantsNonPowerOfTwoTileSize.toml b/mlir/test/e2e/GemmVariantsNonPowerOfTwoTileSize.toml index 38a78a474e3e..002b73ec1865 100644 --- a/mlir/test/e2e/GemmVariantsNonPowerOfTwoTileSize.toml +++ b/mlir/test/e2e/GemmVariantsNonPowerOfTwoTileSize.toml @@ -24,7 +24,7 @@ prefix = "-t " [[axis]] name = "perf_config" -values = ["v4:96,96,8,48,48,16,8,1,1,2,1,1", "v4:192,96,8,96,48,16,8,1,1,2,1,1"] +values = ["v4:96,96,8,48,48,16,8,1,1,2,0,0,1,1", "v4:192,96,8,96,48,16,8,1,1,2,0,0,1,1"] prefix = "--perf_config " ## Gemm variants diff --git a/mlir/test/fusion/bug-1546-compile-failure.mlir b/mlir/test/fusion/bug-1546-compile-failure.mlir index 0dd83cfe90a6..4ce7dd7c45e4 100644 --- a/mlir/test/fusion/bug-1546-compile-failure.mlir +++ b/mlir/test/fusion/bug-1546-compile-failure.mlir @@ -15,7 +15,7 @@ func.func @mlir_dot_mul(%arg0: memref<6xf32>, %arg1: memref<12xf32>, %arg2: memr %3 = rock.transform %2 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmK"] at [1]>, ["gemmM"] at [2]>] bounds = [1, 16, 16] -> [1, 3, 2]> : memref<1x3x2xf32> to memref<1x16x16xf32> %4 = rock.transform %1 by 
(d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmK"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 16, 16] -> [1, 3, 4]> : memref<1x3x4xf32> to memref<1x16x16xf32> %5 = rock.transform %alloc by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 16, 16] -> [1, 2, 4]> : memref<1x2x4xf32> to memref<1x16x16xf32> - rock.gridwise_gemm_accel(%3, %4, %5) storeMethod( set) {blockSize = 64 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x16x16xf32>, memref<1x16x16xf32>, memref<1x16x16xf32> + rock.gridwise_gemm_accel(%3, %4, %5) storeMethod( set) {blockSize = 64 : i32, gridSize = 1 : i32, params = #rock.mfma_gemm_params} : memref<1x16x16xf32>, memref<1x16x16xf32>, memref<1x16x16xf32> %6 = rock.transform %alloc by (0, d0 floordiv 4, d0 mod 4)> by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [8] -> [1, 2, 4]> : memref<1x2x4xf32> to memref<8xf32> %7 = rock.transform %alloc by (0, d0, d1)> by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [2, 4] -> [1, 2, 4]> : memref<1x2x4xf32> to memref<2x4xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4xf32> diff --git a/mlir/test/fusion/bug-1550-reduction-fusion-compile-failure.mlir b/mlir/test/fusion/bug-1550-reduction-fusion-compile-failure.mlir index c0c7cbed0df4..e2a35ce4a689 100644 --- a/mlir/test/fusion/bug-1550-reduction-fusion-compile-failure.mlir +++ b/mlir/test/fusion/bug-1550-reduction-fusion-compile-failure.mlir @@ -19,7 +19,7 @@ func.func @mlir_convolution_reshape_mul_reshape_reduce_sum_reshape_mul_mul_resha %10 = rock.transform %5 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmK"] at [1]>, ["gemmM"] at [2]>] bounds = [1, 64, 384] -> [1, 36, 320]> : memref<1x36x320xf32> to memref<1x64x384xf32> %11 = rock.transform %8 by (d0, d1, d2)> by [ ["gemmG"] at [0]>, ["gemmK"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 64, 8192] -> [1, 36, 8192]> : memref<1x36x8192xf32> to memref<1x64x8192xf32> %12 = rock.transform %9 by (d0, d1, d2)> by [ ["gemmG"] at 
[0]>, ["gemmM"] at [1]>, ["gemmN"] at [2]>] bounds = [1, 384, 8192] -> [1, 320, 8192]> : memref<1x320x8192xf32> to memref<1x384x8192xf32> - rock.gridwise_gemm_accel(%10, %11, %12) storeMethod( set) {blockSize = 256 : i32, gridSize = 1536 : i32, params = #rock.mfma_gemm_params} : memref<1x64x384xf32>, memref<1x64x8192xf32>, memref<1x384x8192xf32> + rock.gridwise_gemm_accel(%10, %11, %12) storeMethod( set) {blockSize = 256 : i32, gridSize = 1536 : i32, params = #rock.mfma_gemm_params} : memref<1x64x384xf32>, memref<1x64x8192xf32>, memref<1x384x8192xf32> %13 = rock.transform %alloc by (d0, d1 * 10 + d2, d3, d4)> by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>] bounds = [2, 32, 10, 64, 64] -> [2, 320, 64, 64]> : memref<2x320x64x64xf32> to memref<2x32x10x64x64xf32> %14 = rock.transform %alloc by (d0 floordiv 1310720, (d0 mod 1310720) floordiv 4096, (d0 mod 4096) floordiv 64, d0 mod 64)> by [ ["col0", "col1", "col2", "col3"] at [0, 1, 2, 3]>] bounds = [2621440] -> [2, 320, 64, 64]> : memref<2x320x64x64xf32> to memref<2621440xf32> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x32x10x64x64xf32> diff --git a/mlir/test/fusion/linalg-generic-with-atomic-store.mlir b/mlir/test/fusion/linalg-generic-with-atomic-store.mlir index d3e20b62c33a..c8a54cc45439 100644 --- a/mlir/test/fusion/linalg-generic-with-atomic-store.mlir +++ b/mlir/test/fusion/linalg-generic-with-atomic-store.mlir @@ -6,7 +6,7 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx90a"} { func.func @rock_gemm(%arg0: memref<1x1024x1024xf32>, %arg1: memref<1x1024x512xf32>, %arg2: memref<1x1024x512xi16> {rock.prefill = 0.000000e+00 : f32}) attributes {block_size = 256 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx90a"} { %alloc = memref.alloc() : memref<1x1024x512xf32> - rock.gemm %alloc = %arg0 * %arg1 features = mfma|dot|atomic_add|atomic_add_f16 storeMethod = atomic_add {arch = "amdgcn-amd-amdhsa:gfx90a", derivedBlockSize = 256 : i32, params = #rock.mfma_gemm_params} 
: memref<1x1024x512xf32> = memref<1x1024x1024xf32> * memref<1x1024x512xf32> + rock.gemm %alloc = %arg0 * %arg1 features = mfma|dot|atomic_add|atomic_add_f16 storeMethod = atomic_add {arch = "amdgcn-amd-amdhsa:gfx90a", derivedBlockSize = 256 : i32, params = #rock.mfma_gemm_params} : memref<1x1024x512xf32> = memref<1x1024x1024xf32> * memref<1x1024x512xf32> // expected-error @+1 {{'linalg.generic' op is infusible with non-`Set` store method}} linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%alloc : memref<1x1024x512xf32>) outs(%arg2 : memref<1x1024x512xi16>) { diff --git a/mlir/test/fusion/pr-e2e/attention/mixr-attention-padded-scale-cross.mlir b/mlir/test/fusion/pr-e2e/attention/mixr-attention-padded-scale-cross.mlir index 820eb00d2fc9..e54cb357bb5e 100644 --- a/mlir/test/fusion/pr-e2e/attention/mixr-attention-padded-scale-cross.mlir +++ b/mlir/test/fusion/pr-e2e/attention/mixr-attention-padded-scale-cross.mlir @@ -9,7 +9,7 @@ module { %0 = migraphx.dot %arg0, %arg1: <1x128x64xf32, 8192x64x1>, <1x64x27xf32, 1728x27x1> -> <1x128x27xf32, 3456x27x1> %biased = migraphx.mul %0, %arg3 : <1x128x27xf32, 3456x27x1>, <1x128x27xf32, 3456x27x1> -> <1x128x27xf32, 3456x27x1> %1 = migraphx.softmax %biased{axis = 2 : i64} : <1x128x27xf32, 3456x27x1> -> <1x128x27xf32, 3456x27x1> - %2 = migraphx.dot %1, %arg2 {perf_config = "attn:v3:32,32,32,32,32,32,32,1,1,1,2,1"} : <1x128x27xf32, 3456x27x1>, <1x27x64xf32, 1728x64x1> -> <1x128x64xf32, 8192x64x1> + %2 = migraphx.dot %1, %arg2 {perf_config = "attn:v3:32,32,32,32,32,32,32,1,1,1,2,0,1"} : <1x128x27xf32, 3456x27x1>, <1x27x64xf32, 1728x64x1> -> <1x128x64xf32, 8192x64x1> return %2 : !migraphx.shaped<1x128x64xf32, 8192x64x1> } } diff --git a/mlir/test/rocmlir-driver/populate_perf_config_gemm.mlir b/mlir/test/rocmlir-driver/populate_perf_config_gemm.mlir index e593e60dec62..6a529f78bbdd 100644 --- a/mlir/test/rocmlir-driver/populate_perf_config_gemm.mlir +++ 
b/mlir/test/rocmlir-driver/populate_perf_config_gemm.mlir @@ -9,7 +9,7 @@ // CHECK-SAME: features = mfma|dot // CHECK-SAME: arch = "amdgcn-amd-amdhsa:gfx1030" // CHECK-SAME: perf_config = "v3:128,64,4,64,64,1,1,1,2,1,1" -// AFFIX: #rock.mfma_gemm_params +// AFFIX: #rock.mfma_gemm_params // GRIDWISE: rock.gridwise_gemm_accel // RUN: rocmlir-gen --operation gemm -t f32 --arch gfx1030 --mfma on -n 128 -k 8 -m 256 --perf_config "v3:128,64,4,64,32,1,1,2,2,1,1" | FileCheck %s --check-prefix=GEN_V2 @@ -20,5 +20,5 @@ // CHECK-SAME: features = mfma|dot // CHECK-SAME: arch = "amdgcn-amd-amdhsa:gfx1030" // CHECK-SAME: perf_config = "v3:128,64,4,64,64,1,1,2,2,1,1" -// AFFIX_V2: #rock.mfma_gemm_params +// AFFIX_V2: #rock.mfma_gemm_params // GRIDWISE_V2: rock.gridwise_gemm_accel diff --git a/mlir/test/rocmlir-gen/emit-tuning-space.mlir b/mlir/test/rocmlir-gen/emit-tuning-space.mlir index 4c572be5a432..a0a1d5ba53fd 100644 --- a/mlir/test/rocmlir-gen/emit-tuning-space.mlir +++ b/mlir/test/rocmlir-gen/emit-tuning-space.mlir @@ -2,25 +2,25 @@ // CHECK-NAVI: v3:64,32,32,4,2,4,1,1,2 // RUN: rocmlir-gen --arch gfx90a --operation=gemm -t f32 -g 1 -m 64 -k 128 -n 64 --num_cu=104 --emit-tuning-space=full | FileCheck %s --check-prefixes=CHECK-MI -// CHECK-MI: v4:64,64,8,32,32,16,4,4,1,2,1,1 +// CHECK-MI: v4:64,64,8,32,32,16,4,4,1,2,0,0,1,1 // RUN: rocmlir-gen --arch gfx950 --operation=gemm -t f32 -g 1 -m 64 -k 128 -n 64 --num_cu=256 --emit-tuning-space=exhaustive | FileCheck %s --check-prefixes=CHECK-EXHAUSTIVE-DOUBLEBUFFER -// CHECK-EXHAUSTIVE-DOUBLEBUFFER: v4:64,64,8,16,16,16,4,4,2,2,1,1 +// CHECK-EXHAUSTIVE-DOUBLEBUFFER: v4:64,64,8,16,16,16,4,4,2,2,0,0,1,1 // RUN: rocmlir-gen --arch gfx950 --operation=gemm -t f32 -g 1 -m 64 -k 128 -n 64 --num_cu=256 --emit-tuning-space=exhaustive | FileCheck %s --check-prefixes=CHECK-EXHAUSTIVE-DIRECTTOLDS-SINGLE -// CHECK-EXHAUSTIVE-DIRECTTOLDS-SINGLE: v4:64,64,8,16,16,16,4,4,3,2,1,1 +// CHECK-EXHAUSTIVE-DIRECTTOLDS-SINGLE: 
v4:64,64,8,16,16,16,4,4,3,2,0,0,1,1 // RUN: rocmlir-gen --arch gfx950 --operation=gemm -t f32 -g 1 -m 64 -k 128 -n 64 --num_cu=256 --emit-tuning-space=exhaustive | FileCheck %s --check-prefixes=CHECK-EXHAUSTIVE-DIRECTTOLDS-DOUBLE -// CHECK-EXHAUSTIVE-DIRECTTOLDS-DOUBLE: v4:64,64,8,16,16,16,4,4,4,2,1,1 +// CHECK-EXHAUSTIVE-DIRECTTOLDS-DOUBLE: v4:64,64,8,16,16,16,4,4,4,2,0,0,1,1 // RUN: rocmlir-gen --arch gfx950 --operation=attention -t f32 -g 1 -head_dim_qk 32 -head_dim_v 32 -num_heads_q 128 -num_heads_kv 128 -seq_len_q 1024 -seq_len_k 1024 --num_cu=256 --emit-tuning-space=exhaustive | FileCheck %s --check-prefixes=CHECK-EXHAUSTIVE-ATTN-DIRECTTOLDS-SINGLE -// CHECK-EXHAUSTIVE-ATTN-DIRECTTOLDS-SINGLE: attn:v3:32,64,32,16,32,32,16,8,1,3,2,1 +// CHECK-EXHAUSTIVE-ATTN-DIRECTTOLDS-SINGLE: attn:v3:32,64,32,16,32,32,16,8,1,3,2,0,1 // RUN: rocmlir-gen --arch gfx950 --operation=attention -t f32 -g 1 -head_dim_qk 32 -head_dim_v 32 -num_heads_q 128 -num_heads_kv 128 -seq_len_q 1024 -seq_len_k 1024 --num_cu=256 --emit-tuning-space=exhaustive | FileCheck %s --check-prefixes=CHECK-EXHAUSTIVE-ATTN-DIRECTTOLDS-DOUBLE -// CHECK-EXHAUSTIVE-ATTN-DIRECTTOLDS-DOUBLE: attn:v3:32,64,32,16,32,32,16,8,1,4,2,1 +// CHECK-EXHAUSTIVE-ATTN-DIRECTTOLDS-DOUBLE: attn:v3:32,64,32,16,32,32,16,8,1,4,2,0,1 // RUN: rocmlir-gen -p --arch gfx1100 --operation=attention -t f16 --emit-tuning-space=exhaustive | FileCheck %s --check-prefixes=CHECK-SCHEDULING-ATTENTION -// CHECK-SCHEDULING-ATTENTION: attn:v3:128,128,256,64,32,32,16,16,1,2,2,1 +// CHECK-SCHEDULING-ATTENTION: attn:v3:128,128,256,64,32,32,16,16,1,2,2,0,1 // RUN: rocmlir-gen -p --arch gfx1100 --operation=gemm -t f16 --emit-tuning-space=exhaustive | FileCheck %s --check-prefixes=CHECK-SCHEDULING-GEMM -// CHECK-SCHEDULING-GEMM: v4:256,256,8,64,128,16,16,1,2,2,1,1 +// CHECK-SCHEDULING-GEMM: v4:256,256,8,64,128,16,16,1,2,2,0,0,1,1 diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp index 
80d8dce09f5b..4be5c5bf2338 100644 --- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp +++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp @@ -1397,7 +1397,7 @@ static Value makeNDMemRef(OpBuilder &b, Value var, uint32_t ndim) { static std::pair getMandNPerBlock(OpBuilder builder, const GenParams ¶ms) { - // default perfConfig is attn:v3:32,32,32,32,32,32,16,1,1,1,2,1 + // default perfConfig is attn:v3:32,32,32,32,32,32,16,1,1,1,2,0,1 // keep in sync with AffixTuningParameters.cpp if (params.perfConfig.empty()) return {32, 32}; diff --git a/mlir/unittests/Dialect/Rock/InitParamsAccelTests.cpp b/mlir/unittests/Dialect/Rock/InitParamsAccelTests.cpp index 51f06596128b..7ebade4eb3a2 100644 --- a/mlir/unittests/Dialect/Rock/InitParamsAccelTests.cpp +++ b/mlir/unittests/Dialect/Rock/InitParamsAccelTests.cpp @@ -21,7 +21,7 @@ namespace { TEST(V4Config, First) { InitParamsAccel validParams; bool isValidPerfConfig = - validParams.deserialize("v4:64,64,8,32,32,16,4,2,3,2,1,1"); + validParams.deserialize("v4:64,64,8,32,32,16,4,2,3,2,0,0,1,1"); EXPECT_EQ(isValidPerfConfig, true); EXPECT_EQ(validParams.gemmBThreadCopyMoreGemmKPack, true); @@ -39,12 +39,14 @@ TEST(V4Config, First) { EXPECT_EQ(validParams.gemmMPerBlock, 64); EXPECT_EQ(validParams.gemmNPerBlock, 64); EXPECT_EQ(validParams.gemmKPerBlock, 8); + EXPECT_EQ(validParams.wavesPerEU, 0); + EXPECT_EQ(validParams.gridGroupSize, 0); } TEST(V4Config, Second) { InitParamsAccel validParams; bool isValidPerfConfig = - validParams.deserialize("v4:128,64,8,64,32,32,4,9,2,2,0,1"); + validParams.deserialize("v4:128,64,8,64,32,32,4,9,2,0,8,64,0,1"); EXPECT_EQ(isValidPerfConfig, true); EXPECT_EQ(validParams.gemmBThreadCopyMoreGemmKPack, true); @@ -57,11 +59,13 @@ TEST(V4Config, Second) { EXPECT_EQ(validParams.gemmNPerWaveOrMnPerXdl, 0); EXPECT_EQ(validParams.gemmScheduleVersion, 2); EXPECT_EQ(validParams.gemmMnPerXdl, 32); - EXPECT_EQ(validParams.outputSwizzle, 2); + EXPECT_EQ(validParams.outputSwizzle, 0); EXPECT_EQ(validParams.getVersion(), 
InitParamsAccel::Version::V4); EXPECT_EQ(validParams.gemmMPerBlock, 128); EXPECT_EQ(validParams.gemmNPerBlock, 64); EXPECT_EQ(validParams.gemmKPerBlock, 8); + EXPECT_EQ(validParams.wavesPerEU, 8); + EXPECT_EQ(validParams.gridGroupSize, 64); } //===----------------------------------------------------------------------===// diff --git a/mlir/utils/performance/attentionSweeps.py b/mlir/utils/performance/attentionSweeps.py index 6280657c9e30..4c792bafd0e1 100755 --- a/mlir/utils/performance/attentionSweeps.py +++ b/mlir/utils/performance/attentionSweeps.py @@ -152,7 +152,8 @@ def sample_attn_shape(): [4, 8, 16], # kPack [1], # splitKFactor [1, 2, 3, 4], # scheduleVersion - [2], # outputSwizzle + [0, 1, 2], # outputSwizzle + [0, 1, 2, 4, 8], # wavesPerEU [0, 1] # forceUnroll )) @@ -168,7 +169,8 @@ def sample_attn_shape(): [4, 8, 16], # kPack [1], # splitKFactor [1, 2, 3, 4], # scheduleVersion - [2], # outputSwizzle + [0, 1, 2], # outputSwizzle + [0, 1, 2, 4, 8, 16], # wavesPerEU [0, 1] # forceUnroll )) diff --git a/mlir/utils/performance/parameterSweeps.py b/mlir/utils/performance/parameterSweeps.py index efccf7b97e89..bb34bfa01d8c 100755 --- a/mlir/utils/performance/parameterSweeps.py +++ b/mlir/utils/performance/parameterSweeps.py @@ -47,7 +47,11 @@ class Version(enum.Enum): def __init__(self, config: Sequence[int], version: Version = Version.V4): self._config = config self._version = version - self._version_map = {PerfConfig.Version.V2: "v2", PerfConfig.Version.V3: "v3", PerfConfig.Version.V4: "v4"} + self._version_map = { + PerfConfig.Version.V2: "v2", + PerfConfig.Version.V3: "v3", + PerfConfig.Version.V4: "v4" + } def __str__(self): suffix = ','.join(str(v) for v in self._config) @@ -428,7 +432,7 @@ def to_conv_structure_type_test(params, options: Options) -> MLIROnlyConfig: # NPerWave (exponent) range(2, 8), # MNPerXdl (exponent) - range(4, 5), # 16 only + range(4, 5), # 16 only # KPack (exponent) range(2, 5), # splitKFactor (exponent) @@ -462,16 +466,23 @@ def 
to_conv_structure_type_test(params, options: Options) -> MLIROnlyConfig: # GEMM Schedule Version range(1, 3)) + def to_wmma_perf_config_test(params, options: Options) -> MLIROnlyConfig: n, g, c, hi, wi, k, y, x, sw, sh, phl, phr, pwl, pwr, dh, dw = \ 512, 1, 512, 1, 1, 512, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1 op, layout, dtype, m_per_block, n_per_block, k_per_block, m_per_wave, \ n_per_wave, mn_per_xdl, kpack, split_k, gemm_schedule = params - # no mn_per_xdl for wmma + # set heuristic settings + # TODO: randomly perf_configs sample instead of brute force + output_swizzle = 2 + waves_per_eu = 0 + grid_group_size = 0 + perf_config_tuple = (1 << m_per_block, 1 << n_per_block, 1 << k_per_block, 1 << m_per_wave, - 1 << n_per_wave, 1 << kpack, 1 << split_k, gemm_schedule, 2, 1, 1) + 1 << n_per_wave, 1 << mn_per_xdl, 1 << kpack, 1 << split_k, gemm_schedule, + output_swizzle, waves_per_eu, grid_group_size, 1, 1) return MLIROnlyConfig(dtype, op, layout, n, c, hi, wi, k, y, x, sh, sw, phl, phr, pwl, pwr, dh, - dw, g, options.arch, PerfConfig(perf_config_tuple, PerfConfig.Version.V3)) + dw, g, options.arch, PerfConfig(perf_config_tuple, PerfConfig.Version.V4)) def to_mfma_perf_config_test(params, options: Options) -> MLIROnlyConfig: @@ -479,8 +490,15 @@ def to_mfma_perf_config_test(params, options: Options) -> MLIROnlyConfig: 512, 1, 512, 1, 1, 512, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1 op, layout, dtype, m_per_block, n_per_block, k_per_block, m_per_wave, \ n_per_wave, mn_per_xdl, kpack, split_k, gemm_schedule = params + # set heuristic settings + # TODO: randomly perf_configs sample instead of brute force + output_swizzle = 2 + waves_per_eu = 0 + grid_group_size = 0 + perf_config_tuple = (1 << m_per_block, 1 << n_per_block, 1 << k_per_block, 1 << m_per_wave, - 1 << n_per_wave, 1 << mn_per_xdl, 1 << kpack, 1 << split_k, gemm_schedule, 2, 1, 1) + 1 << n_per_wave, 1 << mn_per_xdl, 1 << kpack, 1 << split_k, gemm_schedule, + output_swizzle, waves_per_eu, grid_group_size, 1, 1) return 
MLIROnlyConfig(dtype, op, layout, n, c, hi, wi, k, y, x, sh, sw, phl, phr, pwl, pwr, dh, dw, g, options.arch, PerfConfig(perf_config_tuple, PerfConfig.Version.V4))