
Conversation

@justinrosner (Contributor) commented Nov 26, 2025

Motivation

This PR extends the problem key to include information about reduction fusions, in addition to the conv/dot/attention ops that were already described. The goal is to eliminate collisions when querying the problem cache, so that fused and standalone modules no longer need to share the same perfConfig.

Technical Details

When reduction operations are detected in the output fusions, the problem key now includes the following fields:

-fusion_reduce count=<N> <method>:axis<N>:stride<N>[:<hasPointwise>] ...
  • fusion_reduce: Marker flag indicating that reduction fusion information follows
  • count=<N>: The total number of reduction operations fused with the output
  • <method>:axis<N>:stride<N>: The reduction operation being performed (sum or max), the axis on which that reduction takes place, and the stride of that axis
    • Different axes can have different data access patterns, so I thought it was best to track this value along with its stride
  • hasPointwise: Flag indicating whether there are intermediate pointwise/elementwise operations between the gemm/conv/attention output and this particular reduction

E.g.,

-fusion_reduce count=2 sum:axis2 sum:axis2:hasPointwise
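
For illustration, here is a minimal sketch of how such a key segment could be assembled. The ReductionInfo struct, its fields, and the appendFusionReduceKey helper are assumptions made for the example, not the exact code in RockTuningImpl.cpp:

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical per-reduction record; the real struct in the PR may differ.
struct ReductionInfo {
  llvm::StringRef method; // "sum" or "max"
  int64_t axis;
  int64_t stride;
  bool hasPointwise;
};

// Append the "-fusion_reduce ..." segment to an existing problem-key stream.
static void appendFusionReduceKey(llvm::raw_ostream &problemOS,
                                  llvm::ArrayRef<ReductionInfo> reductions) {
  if (reductions.empty())
    return;
  problemOS << " -fusion_reduce count=" << reductions.size();
  for (const ReductionInfo &r : reductions) {
    problemOS << " " << r.method << ":axis" << r.axis << ":stride" << r.stride;
    if (r.hasPointwise)
      problemOS << ":hasPointwise";
  }
}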

Test Plan

  • Nightly CI

Test Result

Submission Checklist

Copilot AI (Contributor) left a comment

Pull request overview

This PR extends the problem key generation for the tuning database to include information about reduction fusion operations. Previously, only the primary operation (GEMM, convolution, attention) was described in the problem key, leading to cache collisions when fused and standalone modules shared the same key. By adding reduction fusion details, each unique fusion pattern now gets its own cache entry with appropriate performance configurations.

Key Changes:

  • Extended the C++ problem key generation to detect and encode reduction operations (sum/max) with their axes and whether intermediate pointwise operations exist
  • Modified Python configuration classes to preserve the full original command line (including fusion info) for accurate tuning DB lookups
  • Added comprehensive test coverage for various fusion scenarios across different operation types

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 3 comments.

File: Description
  • mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp: Implements fusion analysis logic to detect reduction operations and append fusion information to problem keys
  • mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h: Exposes the validOperationGemmOut() function for use in fusion detection
  • mlir/lib/Dialect/Rock/utility/fusionUtils.cpp: Changes visibility of validOperationGemmOut() from static to public
  • mlir/utils/performance/perfRunner.py: Adds a to_tuning_key() method to configuration classes and updates DB lookup logic to use full problem keys with fusion info
  • mlir/test/fusion/problem-key-tests/*.mlir: Comprehensive test suite covering GEMM, GEMM+GEMM, and convolution operations with various reduction fusion patterns


@umangyadav (Member)

> method:axis: The reduction operation being performed (can be sum or max) and the axis on which that reduction takes place
> Different axes can have different data access patterns, so I thought it was best to track this value

axis doesn't tell us the rank of the underlying tensor. If, let's say, x is 5D with axis = 2 and y is 4D with axis = 2, they will have different strides and can lead to a different number of DPP instructions being used.
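
A quick hypothetical illustration of that point, assuming contiguous row-major layouts and made-up shapes: the same axis index ends up with very different strides depending on rank.

#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major contiguous stride of a given axis: product of the trailing dims.
static int64_t strideOf(const std::vector<int64_t> &shape, size_t axis) {
  int64_t stride = 1;
  for (size_t i = axis + 1; i < shape.size(); ++i)
    stride *= shape[i];
  return stride;
}

int main() {
  std::vector<int64_t> x = {2, 4, 8, 16, 32}; // 5D tensor
  std::vector<int64_t> y = {2, 4, 8, 16};     // 4D tensor
  // Both reduce along axis 2, but the strides differ: 512 vs. 16.
  std::printf("x axis-2 stride: %lld\n", (long long)strideOf(x, 2));
  std::printf("y axis-2 stride: %lld\n", (long long)strideOf(y, 2));
  return 0;
}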

@justinrosner (Contributor Author)

> axis doesn't tell us the rank of the underlying tensor. If, let's say, x is 5D with axis = 2 and y is 4D with axis = 2, they will have different strides and can lead to a different number of DPP instructions being used.

I updated the key to add the rank of the reduction's input to avoid these collisions.

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 1 comment.




FailureOr<Value> baseValue = getBaseValue(start);
if (failed(baseValue))
return {false, false}; // Could not find base value
Contributor

Which cases could lead to this? I wonder if we should make the function return FailureOr<...> and fail here.

Contributor Author

It fails like this when the start value (i.e., the output of the reduction) cannot be traced back to a memref::AllocOp or a BlockArgument. I do like using FailureOr, so I'm updating to that.
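
For example, something along these lines (a self-contained sketch of the FailureOr pattern with stand-in helpers; the real function names and signatures differ):

#include "mlir/Support/LogicalResult.h"

using mlir::failed;
using mlir::failure;
using mlir::FailureOr;

// Stand-in for getBaseValue(): fails when no base value can be found.
static FailureOr<int> getBaseValueId(bool found) {
  if (!found)
    return failure();
  return 42;
}

// Instead of returning a sentinel like {false, false}, propagate the failure
// to the caller and let it decide what to do.
static FailureOr<bool> traceHasPointwise(bool baseFound, bool sawPointwise) {
  FailureOr<int> base = getBaseValueId(baseFound);
  if (failed(base))
    return failure();
  return sawPointwise;
}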

if (writers) {
for (OpOperand *writerOperand : *writers) {
auto genericOp = dyn_cast<linalg::GenericOp>(writerOperand->getOwner());
if (!genericOp) {
Contributor

Why do we "continue" if it's not a genericOp?
We might also be able to find another genericOp among the later writers even if this one is not a genericOp, right?

Contributor

I mean, we could have:

out = gemm(...)
b = linalg(out)
a = non_linalg(b)
return a

Contributor Author

My understanding of .getWriters() from BufferDependencyAnalysis is that it will do just that. It already has logic that traces through so that any rock.transforms will be skipped. Are there other non-linalg ops that we care about? Wouldn't that mean that it's a non-fusible operation and should be skipped?

@dhernandez0 (Contributor) commented Dec 2, 2025

It does that, yes, but then if the op is not a linalg.generic you just skip it here; why do we want to do that?

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated no new comments.



@justinrosner force-pushed the 1720-problem-key-extension branch from dea97c1 to d3a6a7a on December 1, 2025 at 18:01
@justinrosner (Contributor Author)

> axis doesn't tell us the rank of the underlying tensor. If, let's say, x is 5D with axis = 2 and y is 4D with axis = 2, they will have different strides and can lead to a different number of DPP instructions being used.

@umangyadav Also just updated to get the stride information from the memref type.

@pfultz2 commented Dec 1, 2025

Reductions can have multiple axes, although we reshape them to one axis in MIGraphX due to limitations with TOSA, but there is an important difference between {N,C,H,W} -> {N,C,1,1} and {N,C,H,W} -> {N,C/8,1,1,1} which this problem key doesn't capture.

Perhaps the key should encode the space of the reduction as a tensor, in the order it has coming out of the gemm or conv (i.e., a transpose is unlikely, but if it does happen it should be the order before the transpose, not after). So something like this: 1x1xHxW or 1x1x8xHxW (assuming H and W are integers).
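
As a hypothetical sketch of what that encoding could look like (ignoring the reshape/rank-change case above; the function, shapes, and axes here are assumptions for the example):

#include <cstdint>
#include <set>
#include <string>
#include <vector>

// Keep the extents of reduced axes, collapse kept axes to 1, and serialize
// the result as a shape string such as "1x1x64x64".
static std::string reductionSpaceKey(const std::vector<int64_t> &inShape,
                                     const std::set<size_t> &reduceAxes) {
  std::string key;
  for (size_t i = 0; i < inShape.size(); ++i) {
    if (i)
      key += "x";
    key += std::to_string(reduceAxes.count(i) ? inShape[i] : 1);
  }
  return key;
}

// e.g. reductionSpaceKey({N, C, H, W}, {2, 3}) yields "1x1xHxW" once concrete
// integers are substituted for N, C, H and W.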

@pfultz2 commented Dec 1, 2025

There could also be multiple reductions with different methods. For dynamic quantization, we could fuse a min and a max reduction, so I don't know if this handles specifying the method for each reduction.

@justinrosner (Contributor Author) commented Dec 1, 2025

> There could also be multiple reductions with different methods. For dynamic quantization, we could fuse a min and a max reduction, so I don't know if this handles specifying the method for each reduction.

This can handle multiple reductions with separate methods, but it relies on earlier pattern matching in TosaToRock to correctly identify reductions and create rock.reduce ops. I was taking a look at RockAttrDefs.td and we only have max and sum reductions, not min. @umangyadav @dhernandez0 do you know why that is?

@justinrosner (Contributor Author)

> Reductions can have multiple axes, although we reshape them to one axis in MIGraphX due to limitations with TOSA, but there is an important difference between {N,C,H,W} -> {N,C,1,1} and {N,C,H,W} -> {N,C/8,1,1,1} which this problem key doesn't capture.
>
> Perhaps the key should encode the space of the reduction as a tensor, in the order it has coming out of the gemm or conv (i.e., a transpose is unlikely, but if it does happen it should be the order before the transpose, not after). So something like this: 1x1xHxW or 1x1x8xHxW (assuming H and W are integers).

@pfultz2 do you have an example of MIGraphX IR with {N,C,H,W} -> {N,C/8,1,1,1}? I want to see if we have the rock.transforms in our IR that would enable us to detect something like this (and make sure that they haven't been optimized away at the point where --emit-tuning-key is called).

};

// Helper to get the base value (allocation or block argument) from a value
static FailureOr<Value> getBaseValue(Value v) {
Member

This could be moved to loweringUtils.h
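
For reference, a rough sketch of what such a helper might look like, based on the description elsewhere in this thread (walk through view-like ops until an allocation or block argument is found); the PR's actual traversal may differ:

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Walk through view-like ops until we reach a memref.alloc or a block
// argument; fail if the chain ends anywhere else.
static FailureOr<Value> getBaseValue(Value v) {
  while (true) {
    if (isa<BlockArgument>(v))
      return v;
    Operation *def = v.getDefiningOp();
    if (!def)
      return failure();
    if (isa<memref::AllocOp>(def))
      return v;
    if (auto view = dyn_cast<ViewLikeOpInterface>(def)) {
      v = view.getViewSource();
      continue;
    }
    return failure();
  }
}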

}

// Helper to trace backwards from a value to see if it reaches the target
// Returns success(hasPointwise) if target is reached, failure otherwise
Member

Suggested change
// Returns success(hasPointwise) if target is reached, failure otherwise
// Returns success if target is reached via traversing through pointwise, failure otherwise

SmallVector<ReductionInfo> reductions;

bool hasReduction() const { return !reductions.empty(); }
int numReductionOutputs() const { return reductions.size(); }
Member

nit:

Suggested change
int numReductionOutputs() const { return reductions.size(); }
int numReductions() const { return reductions.size(); }

numReductionOutputs is ambiguous. It could mean how many reduction operations are being returned from the function or how many outputs each reduction op has (if there is such an op with multiple outputs).

case ReduceMethod::Sum:
problemOS << "sum";
break;
case ReduceMethod::Max:
Member

Can there be some other kind of reduce operation? Something other than Sum/Max?

I suggest adding llvm_unreachable as the default case.
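
Something like this sketch (the enum and stream name are stand-ins for the surrounding code):

#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

enum class ReduceMethod { Sum, Max }; // stand-in for the rock enum

static void printReduceMethod(llvm::raw_ostream &problemOS, ReduceMethod method) {
  switch (method) {
  case ReduceMethod::Sum:
    problemOS << "sum";
    break;
  case ReduceMethod::Max:
    problemOS << "max";
    break;
  default:
    llvm_unreachable("unsupported reduction method in problem key");
  }
}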

#transform_map8 = #rock.transform_map<#map5 by [<Unmerge{128} ["exp1"] at [1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>, <AddDim{1} ["unit0"] at [0] -> [] at []>] bounds = [1, 128, 256] -> [128, 256]>
#transform_map9 = #rock.transform_map<#map6 by [<Merge{1, 128, 1} ["dim0"] at [0] -> ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]>
module {
func.func private @gemm_mul_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} {
Member

nit:
use num_cu = 304 for gfx942 everywhere

// Extract stride for the reduction dimension
SmallVector<int64_t> strides;
int64_t offset;
if (succeeded(memrefType.getStridesAndOffset(strides, offset))) {
Member

Q: Is there a test with stride != 1?
I think the memref strides we use are always 1. Layout information is encoded by a series of rock.transforms, so I don't think there is an easy way to get that information out.

Contributor

I don't think the stride from the memref is the real stride in rocMLIR. We can have transforms that do padding, etc.

Contributor

Also, as @pfultz2 suggested, we might want to add a list of axes from the input tensor shape that get reduced, instead of the output axis.

Contributor

My question as well is whether we are overcomplicating things; maybe knowing that we are doing a reduction is enough. Does it help tuning in any way to know the axis?

@umangyadav (Member)

> There could also be multiple reductions with different methods. For dynamic quantization, we could fuse a min and a max reduction, so I don't know if this handles specifying the method for each reduction.
>
> This can handle multiple reductions with separate methods, but it relies on earlier pattern matching in TosaToRock to correctly identify reductions and create rock.reduce ops. I was taking a look at RockAttrDefs.td and we only have max and sum reductions, not min. @umangyadav @dhernandez0 do you know why that is?

Most probably because we haven't had the need to enable other reductions.

@dhernandez0 (Contributor)

> There could also be multiple reductions with different methods. For dynamic quantization, we could fuse a min and a max reduction, so I don't know if this handles specifying the method for each reduction.
>
> This can handle multiple reductions with separate methods, but it relies on earlier pattern matching in TosaToRock to correctly identify reductions and create rock.reduce ops. I was taking a look at RockAttrDefs.td and we only have max and sum reductions, not min. @umangyadav @dhernandez0 do you know why that is?

AFAIK, we only enable sum reductions at the output of the gemm. That is because there's no atomic max for CDNA (well there's only atomic_max_f64).

Regarding @pfultz2's point about axis: he's right, we are returning the reduction axis of the output tensor, but I guess they are more interested in the axis from the input tensor's point of view? We can use getLowerSubDimensions to get which axes we are reducing from the input, @justinrosner.

trans_scale_b = False

# Store the original command line for accurate tuning DB lookups
# (including fusion info which we don't parse but need for cache key)
@dhernandez0 (Contributor) commented Dec 2, 2025

Why don't we parse the fusion key? If you have two gemms:

-t f16 -m 64 -n 64 -k 64
-t f16 -m 64 -n 64 -k 64 -reduction ...

They should be treated as different problems, not the same. That's the point of all of these changes, right? To tune them as different problems.

Contributor

So, when we run exhaustive tuning + quick tuning, they are different problems as well.

@dhernandez0 (Contributor)

Before we continue working on this, I think we should have a better understanding of the perf numbers; it seems the initial results didn't show a benefit: https://github.com/ROCm/rocMLIR-internal/issues/1720#issuecomment-3582852236. Is this worth it?
