55#include " mlir/Dialect/Rock/IR/Rock.h"
66#include " mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
77#include " mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
8+ #include " mlir/Dialect/Rock/IR/RockTypes.h"
89#include " mlir/Dialect/Rock/Passes.h"
910#include " mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
1011#include " mlir/Dialect/Rock/Tuning/UtilityParams.h"
1920#include " llvm/Support/Debug.h"
2021#include " llvm/Support/LogicalResult.h"
2122#include " llvm/Support/raw_ostream.h"
23+ #include < optional>
2224
2325namespace mlir {
2426namespace rock {
@@ -51,6 +53,41 @@ struct AffixTuningParameters
5153};
5254} // anonymous namespace
5355
56+ static FailureOr<std::optional<int64_t >> getScheduleVersion (func::FuncOp funcOp,
57+ Operation *op) {
58+ auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic ();
59+
60+ std::optional<int64_t > scheduleVersion = std::nullopt ;
61+ bool hasPerfConfig = op->hasAttrOfType <StringAttr>(" perf_config" );
62+ if (funcOp->hasAttrOfType <rock::ScheduleVersionAttr>(
63+ scheduleVersionAttrName) &&
64+ hasPerfConfig) {
65+ return op->emitError (
66+ " kernel has both perf_config and schedule_version attribute "
67+ " set. Please modify schedule version directly inside "
68+ " perf_config and remove schedule_version\n " );
69+ }
70+ if (funcOp->hasAttrOfType <rock::ScheduleVersionAttr>(
71+ scheduleVersionAttrName)) {
72+ scheduleVersion = dyn_cast<rock::ScheduleVersionAttr>(
73+ funcOp->removeAttr (scheduleVersionAttrName))
74+ .getScheduleVersion ();
75+ } else if (!hasPerfConfig) {
76+ // set default schedule
77+ scheduleVersion = static_cast <int64_t >(GemmLoadTileType::Default);
78+ }
79+
80+ // check scheduleVersion is valid
81+ if (scheduleVersion.has_value ()) {
82+ std::optional<GemmLoadTileType> maybeLoadType =
83+ rock::symbolizeGemmLoadTileType (scheduleVersion.value ());
84+ if (!maybeLoadType.has_value ())
85+ return op->emitOpError (" schedule version value is incorrect" );
86+ }
87+
88+ return scheduleVersion;
89+ }
90+
5491void AffixTuningParameters::runOnOperation () {
5592 func::FuncOp func = getOperation ();
5693 // currently, in rocMLIR we only support one Fusion Root per function.
@@ -131,30 +168,18 @@ void AffixTuningParameters::setUtilityKernelSizes(Value arg, T utilityOp) {
131168void AffixTuningParameters::affixTuningParametersImpl (
132169 RockGemmWrapperInterface op) {
133170 OpBuilder b (op.getContext ());
134- auto scheduleVersionAttrName = rock::ScheduleVersionAttr::getMnemonic ();
135171 auto funcParent = op->getParentOfType <func::FuncOp>();
136172 std::string perfConfig;
137- if (funcParent->hasAttrOfType <rock::ScheduleVersionAttr>(
138- scheduleVersionAttrName) &&
139- op->hasAttrOfType <StringAttr>(" perf_config" )) {
140- op->emitError (" kernel has both perf_config and schedule_version attribute "
141- " set. Please modify schedule version directly inside "
142- " perf_config and remove schedule_version\n " );
143- signalPassFailure ();
144- return ;
145- }
146173 if (auto perfConfigAttr =
147174 op->template getAttrOfType <StringAttr>(" perf_config" )) {
148175 perfConfig = perfConfigAttr.getValue ().str ();
149176 }
150- // by default rocMLIR selects GEMM Schedule V1
151- auto scheduleVersion = 1 ;
152- if (funcParent->hasAttrOfType <rock::ScheduleVersionAttr>(
153- scheduleVersionAttrName)) {
154- scheduleVersion = dyn_cast<rock::ScheduleVersionAttr>(
155- funcParent->removeAttr (scheduleVersionAttrName))
156- .getScheduleVersion ();
157- }
177+ FailureOr<std::optional<int64_t >> maybeScheduleVersion =
178+ getScheduleVersion (funcParent, op);
179+ if (failed (maybeScheduleVersion))
180+ return signalPassFailure ();
181+
182+ std::optional<int64_t > scheduleVersion = maybeScheduleVersion.value ();
158183
159184 GemmFeatures features = rock::getFeatures (op);
160185 if (isAccel (features)) {
@@ -165,9 +190,9 @@ void AffixTuningParameters::affixTuningParametersImpl(
165190 // update schedule version to what is provided by the user if and only if
166191 // user hasn't provided perfConfig, otherwise just keep whatever is inside
167192 // perfConfig
168- if (!op-> hasAttrOfType <StringAttr>( " perf_config " )) {
169- validParams.gemmScheduleVersion = scheduleVersion;
170- }
193+ if (scheduleVersion. has_value ())
194+ validParams.gemmScheduleVersion = scheduleVersion. value () ;
195+
171196 if (failed (status)) {
172197 // Try again if allowed.
173198 if (fallBackNoConfig) {
@@ -233,9 +258,8 @@ void AffixTuningParameters::affixTuningParametersImpl(
233258 // update schedule version to what is provided by the user if and only if
234259 // user hasn't provided perfConfig, otherwise just keep whatever was
235260 // obtained from perfConfig
236- if (!op->hasAttrOfType <StringAttr>(" perf_config" )) {
237- validParams.gemmScheduleVersion = scheduleVersion;
238- }
261+ if (scheduleVersion.has_value ())
262+ validParams.gemmScheduleVersion = scheduleVersion.value ();
239263
240264 Attribute gemmParams = populateParams.getGemmParamsAttr (b, validParams);
241265 op.setGemmParamsAttr (gemmParams);
@@ -289,6 +313,13 @@ void AffixTuningParameters::affixTuningParametersImpl(
289313 " with matrix accelerator extentions" );
290314 return signalPassFailure ();
291315 }
316+ auto funcParent = op->getParentOfType <func::FuncOp>();
317+ FailureOr<std::optional<int64_t >> maybeScheduleVersion =
318+ getScheduleVersion (funcParent, op);
319+ if (failed (maybeScheduleVersion))
320+ return signalPassFailure ();
321+
322+ std::optional<int64_t > scheduleVersion = maybeScheduleVersion.value ();
292323
293324 Attribute params0 = op.getGemm0Params ().value_or (nullptr );
294325 // set a default one if params is not provided
@@ -305,6 +336,11 @@ void AffixTuningParameters::affixTuningParametersImpl(
305336 op.emitError (" perf config string has an incorrect format." );
306337 return signalPassFailure ();
307338 }
339+
340+ if (scheduleVersion.has_value ())
341+ attnPerfConfig =
342+ attnPerfConfig.withScheduleVersion (scheduleVersion.value ());
343+
308344 GemmFeatures features = rock::getFeatures (op);
309345 RockAccelTuningParamAttrInterface accelParams0;
310346 if (bitEnumContainsAny (features, GemmFeatures::mfma)) {
0 commit comments