Extend problem key for reduction fusions #2133
base: develop
Conversation
Pull request overview
This PR extends the problem key generation for the tuning database to include information about reduction fusion operations. Previously, only the primary operation (GEMM, convolution, attention) was described in the problem key, leading to cache collisions when fused and standalone modules shared the same key. By adding reduction fusion details, each unique fusion pattern now gets its own cache entry with appropriate performance configurations.
Key Changes:
- Extended the C++ problem key generation to detect and encode reduction operations (sum/max) with their axes and whether intermediate pointwise operations exist
- Modified Python configuration classes to preserve the full original command line (including fusion info) for accurate tuning DB lookups
- Added comprehensive test coverage for various fusion scenarios across different operation types
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp | Implements fusion analysis logic to detect reduction operations and append fusion information to problem keys |
| mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h | Exposes validOperationGemmOut() function for use in fusion detection |
| mlir/lib/Dialect/Rock/utility/fusionUtils.cpp | Changes visibility of validOperationGemmOut() from static to public |
| mlir/utils/performance/perfRunner.py | Adds to_tuning_key() method to configuration classes and updates DB lookup logic to use full problem keys with fusion info |
| mlir/test/fusion/problem-key-tests/*.mlir | Comprehensive test suite covering GEMM, GEMM+GEMM, and convolution operations with various reduction fusion patterns |
mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir
The axis doesn't tell us the rank of the underlying tensor. If, say, x is 5D with axis = 2 and y is 4D with axis = 2, the two will have different strides and can lead to a different number of DPP instructions being used.
I updated the key to add the rank of the input to the reduction to avoid these collisions.
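A minimal sketch of what that could look like, assuming a hypothetical helper name and that the key is written through the same problemOS stream quoted further down in this review; the exact field spelling here is an assumption, not the PR's verbatim output:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: encoding the rank of the reduction input next to the
// reduced axis keeps a 5-D, axis = 2 reduction from colliding with a 4-D,
// axis = 2 reduction in the tuning database.
static void appendReduceShapeInfo(llvm::raw_ostream &problemOS,
                                  mlir::MemRefType reduceInputType,
                                  int64_t axis) {
  problemOS << "rank" << reduceInputType.getRank() << ":axis" << axis;
}
```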
Pull request overview
Copilot reviewed 16 out of 16 changed files in this pull request and generated 1 comment.
FailureOr<Value> baseValue = getBaseValue(start);
if (failed(baseValue))
  return {false, false}; // Could not find base value
Which cases could lead to this? I wonder if we should make the function return FailureOr<..> and fail here.
It fails like this when the start value (i.e., the output of the reduction) cannot be traced back to a memref::AllocOp or a BlockArgument. I do like using FailureOr, so I'm updating to that.
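For context, a minimal sketch of what a FailureOr-returning getBaseValue could look like under that description (tracing back through views to a memref::AllocOp or a BlockArgument); this is an illustration of the idea, not the PR's actual implementation, and the exact headers depend on the MLIR version in tree:

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"

// Walk producers of `v` until we reach an allocation or a block argument;
// fail if the chain goes through anything we cannot see through.
static mlir::FailureOr<mlir::Value> getBaseValue(mlir::Value v) {
  while (true) {
    if (mlir::isa<mlir::BlockArgument>(v))
      return v;
    mlir::Operation *def = v.getDefiningOp();
    if (!def)
      return mlir::failure();
    if (mlir::isa<mlir::memref::AllocOp>(def))
      return v;
    // View-like ops (subview, cast, ...) just forward their source buffer.
    if (auto view = mlir::dyn_cast<mlir::ViewLikeOpInterface>(def))
      v = view.getViewSource();
    else
      return mlir::failure();
  }
}
```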
if (writers) {
  for (OpOperand *writerOperand : *writers) {
    auto genericOp = dyn_cast<linalg::GenericOp>(writerOperand->getOwner());
    if (!genericOp) {
Why do we "continue" if it's not a genericOp?
We might also be able to find another genericOp in the next one even if this one is not a genericOp, right?
I mean, we could have:
out = gemm(...)
b = linalg(out)
a = non_linalg(b)
return a
My understanding of .getWriters() from BufferDependencyAnalysis is that it will do just that. It already has logic that traces through so that any rock.transforms will be skipped. Are there other non-linalg ops that we care about? Wouldn't that mean that it's a non-fusible operation and should be skipped?
It does that, yes, but then if the op is not a linalg.generic you just skip it here. Why do we want to do that?
Pull request overview
Copilot reviewed 16 out of 16 changed files in this pull request and generated no new comments.
@umangyadav Also just updated to get the stride information from the memref type.
Reductions can have multiple axes, although we reshape to one axis in MIGraphX due to limitations with TOSA, but there is an important difference between {N,C,H,W} -> {N,C,1,1} and {N,C,H,W} -> {N,C/8,1,1,1} which this problem key doesn't capture. Perhaps the key should describe the space of the reduction as a tensor, in the order it comes from the output of the gemm or conv (i.e., a transpose is unlikely, but if it does happen it should be the order before the transpose, not after). So something like this:
There could also be multiple reductions with different methods. For dynamic quantization, we could fuse a
This can handle multiple reductions with separate methods, but it relies on earlier pattern matching in TosaToRock to correctly identify reductions and create
@pfultz2 do you have an example of MIGraphX IR with {N,C,H,W} -> {N,C/8,1,1,1}? I want to see if we have the rock.transforms in our IR that would enable us to detect something like this (and make sure that they haven't been optimized away at the point where
};

// Helper to get the base value (allocation or block argument) from a value
static FailureOr<Value> getBaseValue(Value v) {
This could be moved to loweringUtils.h
}

// Helper to trace backwards from a value to see if it reaches the target
// Returns success(hasPointwise) if target is reached, failure otherwise
Suggested change:
// Returns success(hasPointwise) if target is reached, failure otherwise
// Returns success if target is reached via traversing through pointwise, failure otherwise
SmallVector<ReductionInfo> reductions;

bool hasReduction() const { return !reductions.empty(); }
int numReductionOutputs() const { return reductions.size(); }
nit:
Suggested change:
int numReductionOutputs() const { return reductions.size(); }
int numReductions() const { return reductions.size(); }
numReductionOutputs is ambiguous. It could mean how many reduction operations are being returned from the function or how many outputs each reduction op has (if there is such an op with multiple outputs).
case ReduceMethod::Sum:
  problemOS << "sum";
  break;
case ReduceMethod::Max:
Can there be some other kind of reduce operation, something other than Sum/Max?
I suggest adding llvm_unreachable as the default case.
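A minimal sketch of that suggestion, wrapped in a hypothetical helper (the function name is an assumption; ReduceMethod is the Rock dialect enum used in the quoted diff):

```cpp
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

// Print the reduction method into the problem key; any method that is not
// handled explicitly becomes a hard error instead of a silently incomplete key.
static void appendReduceMethod(llvm::raw_ostream &problemOS,
                               ReduceMethod method) {
  switch (method) {
  case ReduceMethod::Sum:
    problemOS << "sum";
    break;
  case ReduceMethod::Max:
    problemOS << "max";
    break;
  default:
    llvm_unreachable("unsupported ReduceMethod in problem key generation");
  }
}
```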
#transform_map8 = #rock.transform_map<#map5 by [<Unmerge{128} ["exp1"] at [1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>, <AddDim{1} ["unit0"] at [0] -> [] at []>] bounds = [1, 128, 256] -> [128, 256]>
#transform_map9 = #rock.transform_map<#map6 by [<Merge{1, 128, 1} ["dim0"] at [0] -> ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]>
module {
  func.func private @gemm_mul_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} {
nit:
use num_cu = 304 for gfx942 everywhere
// Extract stride for the reduction dimension
SmallVector<int64_t> strides;
int64_t offset;
if (succeeded(memrefType.getStridesAndOffset(strides, offset))) {
Q: Is there a test with stride != 1?
I think the way we use memrefs, the strides are always 1. Layout information is encoded by a series of rock.transforms, so I don't think there is an easy way to get that information out.
I don't think the stride from memref is the real stride in rocmlir. We can have transforms that do padding etc.
also, as @pfultz2 suggested, we might want to add a list of axes from the input tensor shape that get reduced, instead of the output axis.
My question as well is whether we are overcomplicating things; maybe knowing that we are doing a reduction is enough. Does knowing the axis help tuning in any way?
Most probably because we haven't had the need to enable other reductions.
AFAIK, we only enable sum reductions at the output of the gemm. That is because there's no atomic max for CDNA (well, there's only atomic_max_f64). Regarding @pfultz2's point about the axis, he's right: we are returning the reduction axis of the output tensor, but I guess they are more interested in the axis from the input tensor's point of view? We can use getLowerSubDimensions to get which axis we are reducing from the input, @justinrosner.
trans_scale_b = False

# Store the original command line for accurate tuning DB lookups
# (including fusion info which we don't parse but need for cache key)
Why don't we parse the fusion key? If you have two gemms:
-t f16 -m 64 -n 64 -k 64
-t f16 -m 64 -n 64 -k 64 -reduction ...
They should be treated as different problems, not the same. That's the point of all of these changes, right? To tune them as different problems.
So, when we run exhaustive tuning + quick tuning, they are different problems as well.
Before we continue working on this, I think we should have a better understanding of the perf numbers; it seems that with the initial results it didn't help: https://github.com/ROCm/rocMLIR-internal/issues/1720#issuecomment-3582852236 Is this worth it?
Motivation
This PR extends the problem key to also include information about reduction fusions, in addition to the conv/dot/attention ops that were already being described. The goal of doing so is to remove collisions when querying the problem cache so that fused and standalone modules no longer need to share the same perfConfig.
Technical Details
When reduction operations are detected in the output fusions, the problem key now includes the following fields:
- fusion_reduce: Marker flag indicating that reduction fusion information follows
- count=<N>: The total number of reduction operations fused with the output
- method:axis<N>:stride<N>: The reduction operation being performed (can be sum or max) and the axis on which that reduction takes place
- hasPointwise: Flag indicating whether there are intermediate pointwise/elementwise operations between the gemm/conv/attention output and this particular reduction
E.g.,
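As a rough, hypothetical sketch of how these fields could be assembled in the C++ key generation (the ReductionInfo field layout, the helper name, and the separators here are assumptions, not the PR's verbatim code):

```cpp
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical per-reduction record carrying the quantities listed above.
struct ReductionInfo {
  ReduceMethod method; // Sum or Max (Rock dialect enum)
  int64_t axis;        // reduced axis on the fused output
  int64_t stride;      // stride of that axis in the output buffer
  bool hasPointwise;   // pointwise ops between the gemm/conv/attention output and the reduce
};

// Append the reduction-fusion fields described above to the problem key.
static void appendReductionSuffix(llvm::raw_ostream &problemOS,
                                  llvm::ArrayRef<ReductionInfo> reductions) {
  problemOS << " fusion_reduce count=" << reductions.size();
  for (const ReductionInfo &r : reductions) {
    problemOS << " " << (r.method == ReduceMethod::Sum ? "sum" : "max")
              << ":axis" << r.axis << ":stride" << r.stride
              << " hasPointwise=" << (r.hasPointwise ? 1 : 0);
  }
}
```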
Test Plan
Test Result
Submission Checklist