Skip to content

Commit f8992ac

Browse files
committed
Use pipelining for attention
1 parent e4ab0c1 commit f8992ac

File tree

52 files changed

+1520
-803
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+1520
-803
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,12 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
297297
+ Twine(getOutputSwizzle()) + ","
298298
+ Twine(getForceUnroll())).toVector(perfStr);
299299
}
300+
// Builds an AttnPerfConfigAttr in the same context as this one, passing
// `newScheduleVersion` in the schedule-version slot and forwarding every
// other field from this attribute's getters unchanged.
AttnPerfConfigAttr withScheduleVersion(int64_t newScheduleVersion) const {
  auto *ctx = getContext();
  return AttnPerfConfigAttr::get(ctx,
                                 getMPerBlockG0(),
                                 getMPerBlockG1(),
                                 getNPerBlockG0(),
                                 getKpackPerBlock(),
                                 getMPerWave(),
                                 getMnPerXdl(),
                                 getKpack(),
                                 getSplitKFactor(),
                                 newScheduleVersion,
                                 getOutputSwizzle(),
                                 getForceUnroll());
}
300306
}];
301307

302308
let builders = [

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,10 +1372,10 @@ def Rock_BlockwiseGemmAccelOp
13721372
I32Attr:$inNPerThread, UnitAttr:$rotateMWithK, UnitAttr:$rotateNWithK,
13731373
UnitAttr:$loadAfromLDS, UnitAttr:$loadBfromLDS,
13741374
UnitAttr:$splitKAcrossThreadsFirstA,
1375-
UnitAttr:$splitKAcrossThreadsFirstB, UnitAttr:$directToLDS,
1376-
UnitAttr:$ldsLayoutMxK, UnitAttr:$ldsLayoutNxK,
1377-
MemRefOf<LdsBufferTypes>:$bufferA, MemRefOf<LdsBufferTypes>:$bufferB,
1378-
MemRefOf<AccelResTypes>:$matrixC,
1375+
UnitAttr:$splitKAcrossThreadsFirstB, UnitAttr:$directToLDSA,
1376+
UnitAttr:$directToLDSB, UnitAttr:$ldsLayoutMxK,
1377+
UnitAttr:$ldsLayoutNxK, MemRefOf<LdsBufferTypes>:$bufferA,
1378+
MemRefOf<LdsBufferTypes>:$bufferB, MemRefOf<AccelResTypes>:$matrixC,
13791379
Optional<MemRefOf<LdsBufferTypes>>:$bufferScaleA,
13801380
Optional<MemRefOf<LdsBufferTypes>>:$bufferScaleB,
13811381
TypeAttr:$elementTypeA, TypeAttr:$elementTypeB,
@@ -1414,9 +1414,10 @@ def Rock_BlockwiseLoadTileOp
14141414
TypeAttr:$elementTypeA, TypeAttr:$elementTypeB,
14151415
TypeAttr:$elementTypeALoad, TypeAttr:$elementTypeBLoad,
14161416
UnitAttr:$rotateWithK, UnitAttr:$swapThreadIterSubDims,
1417-
UnitAttr:$LDSLayoutDxK, Variadic<Index>:$sourceIndices, I64Attr:$G,
1418-
I64Attr:$M, I64Attr:$N, OptionalAttr<Rock_GemmFeaturesAttr>:$features,
1419-
I32Attr:$blockSize, RockAccelTuningParamAttrInterface:$params)> {
1417+
UnitAttr:$LDSLayoutDxK, UnitAttr:$splitKAcrossThreadsFirst,
1418+
Variadic<Index>:$sourceIndices, I64Attr:$G, I64Attr:$M, I64Attr:$N,
1419+
OptionalAttr<Rock_GemmFeaturesAttr>:$features, I32Attr:$blockSize,
1420+
RockAccelTuningParamAttrInterface:$params)> {
14201421
let summary =
14211422
"Blockwise load tile from global memory to LDS and/or registers";
14221423
let description = [{

mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "mlir/Dialect/Rock/IR/Rock.h"
66
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
77
#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
8+
#include "mlir/Dialect/Rock/IR/RockTypes.h"
89
#include "mlir/Dialect/Rock/Passes.h"
910
#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
1011
#include "mlir/Dialect/Rock/Tuning/UtilityParams.h"
@@ -19,6 +20,7 @@
1920
#include "llvm/Support/Debug.h"
2021
#include "llvm/Support/LogicalResult.h"
2122
#include "llvm/Support/raw_ostream.h"
23+
#include <optional>
2224

2325
namespace mlir {
2426
namespace rock {
@@ -51,6 +53,41 @@ struct AffixTuningParameters
5153
};
5254
} // anonymous namespace
5355

56+
/// Resolves the schedule version that should be applied to `op`.
///
/// \param funcOp function that may carry a rock::ScheduleVersionAttr under the
///               attribute's mnemonic; the attribute is removed from the
///               function when it is consumed here.
/// \param op     operation that may carry a "perf_config" string attribute;
///               diagnostics are emitted on it.
/// \returns failure when both attributes are present or when the resolved
///          value does not symbolize to a GemmLoadTileType; std::nullopt when
///          only "perf_config" is present; otherwise the function-level value,
///          or GemmLoadTileType::Default when neither attribute is set.
static FailureOr<std::optional<int64_t>> getScheduleVersion(func::FuncOp funcOp,
                                                            Operation *op) {
  auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic();
  const bool hasPerfConfig = op->hasAttrOfType<StringAttr>("perf_config");

  // One typed lookup serves as both the presence test and the value source;
  // a null result means the attribute is absent or has a different type. This
  // replaces a repeated hasAttrOfType query and an unchecked dyn_cast.
  auto scheduleVersionAttr =
      funcOp->getAttrOfType<rock::ScheduleVersionAttr>(scheduleVersionAttrName);

  if (scheduleVersionAttr && hasPerfConfig) {
    return op->emitError(
        "kernel has both perf_config and schedule_version attribute "
        "set. Please modify schedule version directly inside "
        "perf_config and remove schedule_version\n");
  }

  std::optional<int64_t> scheduleVersion = std::nullopt;
  if (scheduleVersionAttr) {
    // Consume the function-level attribute so it is not seen again.
    funcOp->removeAttr(scheduleVersionAttrName);
    scheduleVersion = scheduleVersionAttr.getScheduleVersion();
  } else if (!hasPerfConfig) {
    // set default schedule
    scheduleVersion = static_cast<int64_t>(GemmLoadTileType::Default);
  }

  // check scheduleVersion is valid
  if (scheduleVersion.has_value() &&
      !rock::symbolizeGemmLoadTileType(scheduleVersion.value()).has_value())
    return op->emitOpError("schedule version value is incorrect");

  return scheduleVersion;
}
90+
5491
void AffixTuningParameters::runOnOperation() {
5592
func::FuncOp func = getOperation();
5693
// currently, in rocMLIR we only support one Fusion Root per function.
@@ -131,30 +168,18 @@ void AffixTuningParameters::setUtilityKernelSizes(Value arg, T utilityOp) {
131168
void AffixTuningParameters::affixTuningParametersImpl(
132169
RockGemmWrapperInterface op) {
133170
OpBuilder b(op.getContext());
134-
auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic();
135171
auto funcParent = op->getParentOfType<func::FuncOp>();
136172
std::string perfConfig;
137-
if (funcParent->hasAttrOfType<rock::ScheduleVersionAttr>(
138-
scheduleVersionAttrName) &&
139-
op->hasAttrOfType<StringAttr>("perf_config")) {
140-
op->emitError("kernel has both perf_config and schedule_version attribute "
141-
"set. Please modify schedule version directly inside "
142-
"perf_config and remove schedule_version\n");
143-
signalPassFailure();
144-
return;
145-
}
146173
if (auto perfConfigAttr =
147174
op->template getAttrOfType<StringAttr>("perf_config")) {
148175
perfConfig = perfConfigAttr.getValue().str();
149176
}
150-
// by default rocMLIR selects GEMM Schedule V1
151-
auto scheduleVersion = 1;
152-
if (funcParent->hasAttrOfType<rock::ScheduleVersionAttr>(
153-
scheduleVersionAttrName)) {
154-
scheduleVersion = dyn_cast<rock::ScheduleVersionAttr>(
155-
funcParent->removeAttr(scheduleVersionAttrName))
156-
.getScheduleVersion();
157-
}
177+
FailureOr<std::optional<int64_t>> maybeScheduleVersion =
178+
getScheduleVersion(funcParent, op);
179+
if (failed(maybeScheduleVersion))
180+
return signalPassFailure();
181+
182+
std::optional<int64_t> scheduleVersion = maybeScheduleVersion.value();
158183

159184
GemmFeatures features = rock::getFeatures(op);
160185
if (isAccel(features)) {
@@ -165,9 +190,9 @@ void AffixTuningParameters::affixTuningParametersImpl(
165190
// update schedule version to what is provided by the user if and only if
166191
// user hasn't provided perfConfig, otherwise just keep whatever is inside
167192
// perfConfig
168-
if (!op->hasAttrOfType<StringAttr>("perf_config")) {
169-
validParams.gemmScheduleVersion = scheduleVersion;
170-
}
193+
if (scheduleVersion.has_value())
194+
validParams.gemmScheduleVersion = scheduleVersion.value();
195+
171196
if (failed(status)) {
172197
// Try again if allowed.
173198
if (fallBackNoConfig) {
@@ -233,9 +258,8 @@ void AffixTuningParameters::affixTuningParametersImpl(
233258
// update schedule version to what is provided by the user if and only if
234259
// user hasn't provided perfConfig, otherwise just keep whatever was
235260
// obtained from perfConfig
236-
if (!op->hasAttrOfType<StringAttr>("perf_config")) {
237-
validParams.gemmScheduleVersion = scheduleVersion;
238-
}
261+
if (scheduleVersion.has_value())
262+
validParams.gemmScheduleVersion = scheduleVersion.value();
239263

240264
Attribute gemmParams = populateParams.getGemmParamsAttr(b, validParams);
241265
op.setGemmParamsAttr(gemmParams);
@@ -289,6 +313,13 @@ void AffixTuningParameters::affixTuningParametersImpl(
289313
"with matrix accelerator extentions");
290314
return signalPassFailure();
291315
}
316+
auto funcParent = op->getParentOfType<func::FuncOp>();
317+
FailureOr<std::optional<int64_t>> maybeScheduleVersion =
318+
getScheduleVersion(funcParent, op);
319+
if (failed(maybeScheduleVersion))
320+
return signalPassFailure();
321+
322+
std::optional<int64_t> scheduleVersion = maybeScheduleVersion.value();
292323

293324
Attribute params0 = op.getGemm0Params().value_or(nullptr);
294325
// set a default one if params is not provided
@@ -305,6 +336,11 @@ void AffixTuningParameters::affixTuningParametersImpl(
305336
op.emitError("perf config string has an incorrect format.");
306337
return signalPassFailure();
307338
}
339+
340+
if (scheduleVersion.has_value())
341+
attnPerfConfig =
342+
attnPerfConfig.withScheduleVersion(scheduleVersion.value());
343+
308344
GemmFeatures features = rock::getFeatures(op);
309345
RockAccelTuningParamAttrInterface accelParams0;
310346
if (bitEnumContainsAny(features, GemmFeatures::mfma)) {

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,24 +466,26 @@ struct BlockwiseGemmAccelRewritePattern
466466
// considered a temporary hack until we have a proper way of "searching"
467467
// through different schedules (either heuristically or automatically)
468468

469-
bool directToLDS = op.getDirectToLDS();
469+
bool directToLDSA = op.getDirectToLDSA();
470+
bool directToLDSB = op.getDirectToLDSB();
470471
Value wrappedLDSBufferForLoadA, wrappedLDSBufferForLoadB;
471472
if (loadAFromLDS) {
472473
wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad(
473474
b, loc, op.getMatrixA(), op.getBlockSize(), op.getInMPerThread(), "m",
474-
op.getRotateMWithK(), directToLDS, op.getLdsLayoutMxK(),
475+
op.getRotateMWithK(), directToLDSA, op.getLdsLayoutMxK(),
475476
op.getSplitKAcrossThreadsFirstA());
476477
}
477478
if (loadBFromLDS) {
478479
wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad(
479480
b, loc, op.getMatrixB(), op.getBlockSize(), op.getInNPerThread(), "n",
480-
op.getRotateNWithK(), directToLDS, op.getLdsLayoutNxK(),
481+
op.getRotateNWithK(), directToLDSB, op.getLdsLayoutNxK(),
481482
op.getSplitKAcrossThreadsFirstB());
482483
}
483484

484485
auto loadBuffer = [&](Value buffer, Value wrappedLDSBufferForLoad,
485486
Value loopVar, Type argType, int64_t repeats,
486-
bool loadFromLDS, bool isA) -> Value {
487+
bool loadFromLDS, bool directToLDS,
488+
bool isA) -> Value {
487489
Value inputBuffer = buffer;
488490
SmallVector<int64_t> shape;
489491
if (directToLDS) {
@@ -545,7 +547,7 @@ struct BlockwiseGemmAccelRewritePattern
545547

546548
Value bufferA = adaptor.getBufferA();
547549
bufferA = loadBuffer(bufferA, wrappedLDSBufferForLoadA, i, argTypeA,
548-
mRepeats, loadAFromLDS, true);
550+
mRepeats, loadAFromLDS, directToLDSA, true);
549551
Value viewA =
550552
accelEmitterPtr->generateThreadwiseViewBufferA(b, loc, bufferA);
551553

@@ -557,7 +559,7 @@ struct BlockwiseGemmAccelRewritePattern
557559

558560
Value bufferB = adaptor.getBufferB();
559561
bufferB = loadBuffer(bufferB, wrappedLDSBufferForLoadB, j, argTypeB,
560-
nRepeats, loadBFromLDS, false);
562+
nRepeats, loadBFromLDS, directToLDSB, false);
561563
Value viewB =
562564
accelEmitterPtr->generateThreadwiseViewBufferB(b, loc, bufferB);
563565

mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ class LoweringBlockwiseLoadTileOp final
7676
const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
7777
Value tid, StringRef dName, Value ldsView, Value regs, int64_t blockSize,
7878
int64_t inDPerThread, bool rotateDWithK, bool forceUnroll,
79-
bool directToLDS, bool ldsLayoutDxK) const {
79+
bool directToLDS, bool ldsLayoutDxK,
80+
bool doSplitKAcrossThreadsFirst) const {
8081

8182
// wrapLDSBufferForLoad is reading a single set of Ks into private memory
8283
// A/B[m/n, 0:kBasePerThread]
8384
Value ldsViewForLoad = accelEmitterPtr->wrapLDSBufferForLoad(
8485
b, loc, ldsView, blockSize, inDPerThread, dName, rotateDWithK,
85-
directToLDS, ldsLayoutDxK);
86+
directToLDS, ldsLayoutDxK, doSplitKAcrossThreadsFirst);
8687

8788
// We enhance the transformation from wrapLDSBufferForLoad using a builder
8889
// that, given a single index, splits it into "m"("n") and "k" and lets
@@ -158,6 +159,7 @@ class LoweringBlockwiseLoadTileOp final
158159
bool doRotateWithK = op.getRotateWithK();
159160
bool doSwapThreadIterSubDims = op.getSwapThreadIterSubDims();
160161
bool ldsLayoutDxK = op.getLDSLayoutDxK();
162+
bool doSplitKAcrossThreadsFirst = op.getSplitKAcrossThreadsFirst();
161163
LDSLayoutConfigDim ldsLayoutConfig{doRotateWithK, doSwapThreadIterSubDims,
162164
ldsLayoutDxK};
163165

@@ -406,7 +408,8 @@ class LoweringBlockwiseLoadTileOp final
406408
generateReadLoop(loc, b, accelEmitterPtr, tid, dName, ldsViewForGemm,
407409
destRegisters, blockSize, copyDPerThread,
408410
ldsLayoutConfig.doRotateWithK, forceUnroll,
409-
directToLDS, ldsLayoutConfig.ldsLayoutDxK);
411+
directToLDS, ldsLayoutConfig.ldsLayoutDxK,
412+
doSplitKAcrossThreadsFirst);
410413
if (stageLDSReadNew)
411414
rock::YieldOp::create(b, loc);
412415
}

0 commit comments

Comments
 (0)