Skip to content

Commit 38a8f9d

Browse files
committed
Use pipelining for outer loop of attention
1 parent 12e26fd commit 38a8f9d

File tree

11 files changed

+672
-266
lines changed

11 files changed

+672
-266
lines changed

mlir/include/mlir/Dialect/Rock/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace rock {
5050
#define GEN_PASS_DECL_ROCKFINDFIRSTGEMMINDEXPASS
5151
#define GEN_PASS_DECL_ROCKREMOVEOUTPUTALLOCPASS
5252
#define GEN_PASS_DECL_ROCKBLOCKWISELOADTILETOTHREADWISEPASS
53+
#define GEN_PASS_DECL_ROCKPREPAREPIPELINEPASS
5354
#define GEN_PASS_DECL_ROCKANNOTATELIVENESSPASS
5455

5556
#define GEN_PASS_REGISTRATION

mlir/include/mlir/Dialect/Rock/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,10 @@ def RockAnnotateLivenessPass
210210
let dependentDialects = ["rock::RockDialect", "memref::MemRefDialect"];
211211
}
212212

213+
def RockPreparePipelinePass
214+
: Pass<"rock-prepare-pipeline", "::mlir::func::FuncOp"> {
215+
let summary = "This pass prepares ops for pipelining";
216+
let dependentDialects = ["rock::RockDialect", "func::FuncDialect"];
217+
}
218+
213219
#endif // MLIR_DIALECT_ROCK_PASSES

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,16 @@ template <>
9292
struct rank<0> {};
9393

9494
template <typename OpType>
95+
<<<<<<< HEAD
9596
static auto
9697
getGemmEffects(rank<1>, OpType &op,
9798
SmallVectorImpl<MemoryEffects::EffectInstance> &effects)
9899
-> decltype(void(op.getScaleA()), void(op.getScaleB())) {
100+
=======
101+
static void
102+
getGemmEffects(OpType &op,
103+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
104+
>>>>>>> 64ae93840ed4 (fixes)
99105
auto *read = MemoryEffects::Read::get();
100106
auto *write = MemoryEffects::Write::get();
101107

mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ void rock::buildKernelPipeline(OpPassManager &pm,
164164
funcPm.addPass(rock::createRockShuffleGemmForReductions());
165165
funcPm.addPass(rock::createRockGridwiseGemmToBlockwisePass());
166166
funcPm.addPass(rock::createRockBlockwiseLoadTileToThreadwisePass());
167+
funcPm.addPass(rock::createRockPreparePipelinePass());
167168

168169
// We want to delay blockwise lowering in the fusion cases
169170
// until after linalg align pass because with reduction fusion

mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,13 @@ class LoweringBlockwiseLoadTileOp final
206206
// we want to insert all allocs and transforms before the loop
207207
Operation *parentOp = op->getParentOp();
208208
assert(parentOp && "BlockwiseLoadTileOp must have a parent op");
209-
if (isa<LoopLikeOpInterface>(parentOp))
210-
b.setInsertionPoint(parentOp);
211-
else
212-
b.setInsertionPoint(op);
213209

210+
// let's add the allocs to the beginning of the loop
211+
if (auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
212+
// workaround for parentLoop.getBody()
213+
assert(parentLoop->getRegions().size() == 1);
214+
b.setInsertionPointToStart(&parentLoop->getRegion(0).front());
215+
}
214216
Value loadBuffer, storeBuffer;
215217
if (loadType == GemmLoadTileType::BypassLDS) {
216218
auto privateMemoryAddressSpace = b.getAttr<gpu::AddressSpaceAttr>(
@@ -233,10 +235,6 @@ class LoweringBlockwiseLoadTileOp final
233235
SmallVector<int64_t, 3> bidGridLengths = {G, mBlocks, nBlocks};
234236
SmallVector<StringRef, 3> bidGridOrder = {"g_block", "m_block", "n_block"};
235237

236-
// Create the stages for the blockwise load tile op
237-
if (isa<LoopLikeOpInterface>(parentOp))
238-
b.setInsertionPoint(op);
239-
240238
auto [stageGlobalRead, stageGlobalReadNew] =
241239
createOrGetStage(b, loc, "GlobalRead", parentOp);
242240
{

mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms
3232
FindFirstGemmIndex.cpp
3333
RemoveOutputAlloc.cpp
3434
BlockwiseLoadTileToThreadwise.cpp
35+
PreparePipeline.cpp
3536
AnnotateLiveness.cpp
3637

3738
ADDITIONAL_HEADER_DIRS

0 commit comments

Comments
 (0)