Skip to content

Commit 4fbd50b

Browse files
committed
Addressing PR comments
1 parent 56d7a92 commit 4fbd50b

24 files changed

+25
-45
lines changed

mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,6 @@ FailureOr<VectorDimInfo> getVectorDim(Location loc, Value matrix, Type elemType,
262262
// Get the LDS size of the memref
263263
std::optional<int64_t> getWorkgroupMemorySize(MemRefType type);
264264

265-
// Return trip count for scf::ForOp
266-
std::optional<int64_t> getConstantLoopTripCount(scf::ForOp loopOp);
267-
268265
} // end namespace rock
269266
} // end namespace mlir
270267
#endif

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,14 +2262,15 @@ struct GridwiseAttentionAccelRewritePattern
22622262
// LDS Barrier (issue 1811): some threads might be loading from LDS
22632263
// while others are in the next iteration (here), writing to LDS. This
22642264
// barrier prevents that.
2265-
std::optional<uint64_t> mLoopIters = std::nullopt;
2265+
std::optional<APInt> mLoopIters = std::nullopt;
22662266
// mLoopOp can be a dynamic loop if we are using KV Cache or Causal
22672267
// masking. If that's the case, we can't know the number of iterations
22682268
// at compile time.
22692269
if (!dynamicMLoop)
2270-
mLoopIters = getConstantLoopTripCount(mLoopOp);
2270+
mLoopIters = mLoopOp.getStaticTripCount();
22712271

2272-
bool mIterOneIter = mLoopIters.has_value() && mLoopIters.value() == 1;
2272+
bool mIterOneIter =
2273+
mLoopIters.has_value() && mLoopIters.value().getSExtValue() == 1;
22732274
if (!mIterOneIter) {
22742275
LLVM_DEBUG(llvm::dbgs() << "adding a barrier in the first gemm loop\n");
22752276
LDSBarrierOp::create(rewriter, loc);

mlir/lib/Dialect/Rock/utility/loweringUtils.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,21 +1184,3 @@ std::optional<int64_t> mlir::rock::getWorkgroupMemorySize(MemRefType type) {
11841184
}
11851185
return std::nullopt;
11861186
}
1187-
1188-
std::optional<int64_t> mlir::rock::getConstantLoopTripCount(scf::ForOp loopOp) {
1189-
Value lowerBound = loopOp.getLowerBound();
1190-
Value upperBound = loopOp.getUpperBound();
1191-
Value step = loopOp.getStep();
1192-
auto lbConst = getConstantIntValue(lowerBound);
1193-
auto ubConst = getConstantIntValue(upperBound);
1194-
auto stepConst = getConstantIntValue(step);
1195-
if (lbConst && ubConst && stepConst) {
1196-
int64_t tripCount =
1197-
(ubConst.value() - lbConst.value() + stepConst.value() - 1) /
1198-
stepConst.value();
1199-
// Handle the case where tripCount might be negative
1200-
tripCount = std::max(tripCount, 0L);
1201-
return tripCount;
1202-
}
1203-
return std::nullopt;
1204-
}

mlir/test/e2e/AttentionDirectToLDS.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
if (not config.arch_support_mfma) and (not config.arch_support_wmma):
1+
if not (config.arch_support_mfma or config.arch_support_wmma):
22
config.unsupported = True
33

44
if not 'direct_to_lds_32b' in config.features and not 'direct_to_lds_128b' in config.features:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
if (not config.arch_support_mfma) and (not config.arch_support_wmma):
1+
if not (config.arch_support_mfma or config.arch_support_wmma):
22
config.unsupported = True

mlir/test/e2e/ConvElementwiseGemmDirectToLDS.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
if (not config.arch_support_mfma) and (not config.arch_support_wmma):
1+
if not (config.arch_support_mfma or config.arch_support_wmma):
22
config.unsupported = True
33

44
if not 'direct_to_lds_32b' in config.features and not 'direct_to_lds_128b' in config.features:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
if (not config.arch_support_mfma) and (not config.arch_support_wmma):
1+
if not (config.arch_support_mfma or config.arch_support_wmma):
22
config.unsupported = True

mlir/test/e2e/GemmElementwiseGemmDirectToLDS.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
if (not config.arch_support_mfma) and (not config.arch_support_wmma):
1+
if not (config.arch_support_mfma or config.arch_support_wmma):
22
config.unsupported = True
33

44
if not 'direct_to_lds_32b' in config.features and not 'direct_to_lds_128b' in config.features:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
if (not config.arch_support_mfma) and (not config.arch_support_wmma):
1+
if not (config.arch_support_mfma or config.arch_support_wmma):
22
config.unsupported = True

mlir/test/e2e/PrAttentionBF16.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
if (not config.arch_support_mfma) and (not config.arch_support_wmma):
1+
if not (config.arch_support_mfma or config.arch_support_wmma):
22
config.unsupported = True

0 commit comments

Comments
 (0)