Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ def RockAccelTuningParamAttrInterface : AttrInterface<"RockAccelTuningParamAttrI
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Return waves_per_eu, this is a hint to the backend compiler to tune the number of wavefronts that are capable of fitting within the resources of an EU.
}],
/*retType=*/"int64_t",
/*methodName=*/"getWavesPerEU",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Group size for layout on the distribution of the workgroups.
}],
/*retType=*/"int64_t",
/*methodName=*/"getGridGroupSize",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>

// TODO: more methods here as needed
Expand Down
31 changes: 26 additions & 5 deletions mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
"int64_t":$nPerBlockG0, "int64_t":$kpackPerBlock, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$kpack,
"int64_t":$splitKFactor, "int64_t":$scheduleVersion,
"int64_t":$outputSwizzle, "bool":$forceUnroll);
"int64_t":$outputSwizzle, "int64_t":$wavesPerEU, "bool":$forceUnroll);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not gridGroupSize?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gridGroupSize is only used for GEMMs/convs, not attention/g+g kernels.


let extraClassDeclaration = [{
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
Expand All @@ -296,13 +296,14 @@ def Rock_AttnPerfConfig : Rock_Attr<"AttnPerfConfig", [RockTuningParamAttrInterf
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getForceUnroll())).toVector(perfStr);
}
AttnPerfConfigAttr withScheduleVersion(int64_t newScheduleVersion) const {
return AttnPerfConfigAttr::get(
getContext(), getMPerBlockG0(), getMPerBlockG1(), getNPerBlockG0(),
getKpackPerBlock(), getMPerWave(), getNPerWave(), getMnPerXdl(), getKpack(),
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getForceUnroll());
getSplitKFactor(), newScheduleVersion, getOutputSwizzle(), getWavesPerEU(), getForceUnroll());
}
}];

Expand All @@ -325,7 +326,7 @@ def Rock_MfmaGemmParamsAttr
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
"bool":$forceUnroll);
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);

let extraClassDeclaration = [{
void getPerfConfigStr(::llvm::SmallVectorImpl<char> &perfStr) {
Expand All @@ -339,6 +340,8 @@ def Rock_MfmaGemmParamsAttr
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getGridGroupSize()) + ","
+ Twine(getForceUnroll()) + ","
+ "1") /* *ThreadCopyMore* */
.toVector(perfStr);
Expand All @@ -359,7 +362,7 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn
"int64_t":$nPerBlock, "int64_t":$kpack, "int64_t":$mPerWave,
"int64_t":$nPerWave, "int64_t":$mnPerXdl, "int64_t":$splitKFactor,
"int64_t":$scheduleVersion, "int64_t":$outputSwizzle,
"bool":$forceUnroll);
"int64_t":$wavesPerEU, "int64_t":$gridGroupSize, "bool":$forceUnroll);

let extraClassDeclaration = [{
void getPerfConfigStr(SmallVectorImpl<char> &perfStr) {
Expand All @@ -373,6 +376,8 @@ def Rock_WmmaGemmParamsAttr : Rock_Attr<"WmmaGemmParams", [RockTuningParamAttrIn
+ Twine(getSplitKFactor()) + ","
+ Twine(getScheduleVersion()) + ","
+ Twine(getOutputSwizzle()) + ","
+ Twine(getWavesPerEU()) + ","
+ Twine(getGridGroupSize()) + ","
+ Twine(getForceUnroll()) + ","
+ "1") /* *ThreadCopyMore* */
.toVector(perfStr);
Expand Down Expand Up @@ -459,9 +464,25 @@ def Rock_ScheduleVersionAttr : Rock_Attr<"ScheduleVersion"> {
}];
}

// It is a temporary attribute
def Rock_EnableSplitKForTuning : Rock_Attr<"EnableSplitKForTuning"> {
let mnemonic = "enable_splitk_for_tuning";
let description = [{
Whether we tune for split-k. If unset, split-k=1.
}];
}

def Rock_WavesPerEU : Rock_Attr<"WavesPerEU"> {
let mnemonic = "waves_per_eu";
let description = [{
This is a hint to the backend compiler to tune the number of wavefronts that are capable of fitting within the resources of an EU.
}];
}

def Rock_OutputSwizzle : Rock_Attr<"OutputSwizzle"> {
let mnemonic = "output_swizzle";
let description = [{
Whether we run the output swizzle pass. 0 -> disabled, 1 -> enabled, 2 -> heuristic.
}];
}

def Rock_PrefillAttr : Rock_Attr<"Prefill"> {
Expand Down
31 changes: 29 additions & 2 deletions mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
};

struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
// TODO: remove once we generate new quick tuning list
constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock,
int64_t kPerBlock, int64_t mPerWave,
int64_t nPerWave, int64_t mnPerXdl, int64_t kPack,
Expand All @@ -152,12 +153,28 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl),
gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor),
gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle),
wavesPerEU(0), gridGroupSize(0),
gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK),
gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {}

constexpr InitParamsAccel(int64_t mPerBlock, int64_t nPerBlock,
int64_t kPerBlock, int64_t mPerWave,
int64_t nPerWave, int64_t mnPerXdl, int64_t kPack,
int64_t splitKFactor, int64_t scheduleVersion,
int64_t outputSwizzle, int64_t wavesPerEU,
int64_t gridGroupSize, bool aThreadCopyMoreGemmK,
bool bThreadCopyMoreGemmKPack)
: InitParams{mPerBlock, nPerBlock, kPerBlock}, gemmMPerWave(mPerWave),
gemmNPerWave(nPerWave), gemmMnPerXdl(mnPerXdl),
gemmNPerWaveOrMnPerXdl(0), gemmKPack(kPack), splitKFactor(splitKFactor),
gemmScheduleVersion(scheduleVersion), outputSwizzle(outputSwizzle),
wavesPerEU(wavesPerEU), gridGroupSize(gridGroupSize),
gemmAThreadCopyMoreGemmK(aThreadCopyMoreGemmK),
gemmBThreadCopyMoreGemmKPack(bThreadCopyMoreGemmKPack) {}

constexpr InitParamsAccel()
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, false,
false) {}
: InitParamsAccel(0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 2LL, 0LL,
0LL, false, false) {}

InitParamsAccel(MfmaGemmParamsAttr attr)
: InitParams{attr.getMPerBlock(), attr.getNPerBlock(),
Expand All @@ -167,6 +184,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
gemmScheduleVersion(attr.getScheduleVersion()),
outputSwizzle(attr.getOutputSwizzle()),
wavesPerEU(attr.getWavesPerEU()),
gridGroupSize(attr.getGridGroupSize()),
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
gemmBThreadCopyMoreGemmKPack(false) {};

Expand All @@ -178,6 +197,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
gemmKPack(attr.getKpack()), splitKFactor(attr.getSplitKFactor()),
gemmScheduleVersion(attr.getScheduleVersion()),
outputSwizzle(attr.getOutputSwizzle()),
wavesPerEU(attr.getWavesPerEU()),
gridGroupSize(attr.getGridGroupSize()),
gemmAThreadCopyMoreGemmK(attr.getForceUnroll()),
gemmBThreadCopyMoreGemmKPack(false) {};

Expand All @@ -191,6 +212,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
int64_t splitKFactor;
int64_t gemmScheduleVersion;
int64_t outputSwizzle;
int64_t wavesPerEU;
int64_t gridGroupSize;
bool gemmAThreadCopyMoreGemmK;
bool gemmBThreadCopyMoreGemmKPack;

Expand All @@ -214,6 +237,10 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
f(self.gemmScheduleVersion);
f(self.outputSwizzle);
}
if (self.version >= Version::V4) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding this here instead of adding it in the previous check if (self.version >= Version::V4)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because we want to visit the params in the right order.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see serialize() in Serializable.h

f(self.wavesPerEU);
f(self.gridGroupSize);
}
f(self.gemmAThreadCopyMoreGemmK);
f(self.gemmBThreadCopyMoreGemmKPack);
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Rock/Tuning/Serializable.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct Serializable {
}

bool checkVersionFormat(const std::string &s) {
const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 12};
const int32_t maxNumTokensArray[] = {0, 8, 9, 11, 14};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we haven't done a release since v4 (I think), is it ok to keep the same version?

const int32_t versionIdx = static_cast<int32_t>(version);
if (versionIdx < 1 || versionIdx >= static_cast<int32_t>(Version::Count)) {
llvm_unreachable("Unknown version of the perfConfig");
Expand Down
128 changes: 76 additions & 52 deletions mlir/lib/Conversion/RockToGPU/RockToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,63 @@ struct WorkgroupIdRewritePattern
};
} // namespace

static void runWavesPerEUHeuristic(OpBuilder b, gpu::GPUFuncOp gpuFunc,
int64_t ldsUsage) {
LLVM_DEBUG(llvm::dbgs() << "Using heuristic to set wavesPerEU...\n");
if (!gpuFunc->hasAttrOfType<IntegerAttr>("block_size")) {
LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n");
return;
}
int64_t blockSize =
gpuFunc->getAttrOfType<IntegerAttr>("block_size").getInt();
if (!gpuFunc->hasAttrOfType<IntegerAttr>("grid_size")) {
LLVM_DEBUG(llvm::dbgs() << "gridSize not found in gpuFunc.\n");
return;
}
int64_t gridSize = gpuFunc->getAttrOfType<IntegerAttr>("grid_size").getInt();
FailureOr<StringAttr> maybeArch = rock::getArch(gpuFunc);
if (succeeded(maybeArch)) {
StringAttr arch = maybeArch.value();
rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch);
FailureOr<int64_t> maybeNumCU = rock::getNumCU(gpuFunc);
int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU);
int64_t totalEUs = archInfo.numEUPerCU * numCU;
int64_t wavesPerBlock = (blockSize / archInfo.waveSize);
int64_t totalWaves = wavesPerBlock * gridSize;
int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU;
int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs;
int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid);
LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n");
LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n");
LLVM_DEBUG(llvm::dbgs() << " numEUPerCU:" << archInfo.numEUPerCU << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n");
LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n");
// limit wavesPerEU based on lds usage
if (ldsUsage > 0) {
wavesPerEU =
std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage);
}
// Currently limiting wavesPerEU to be two
// it is a future to ticket to remove this constraint with further
// analysis
constexpr int64_t wavesPerEUUpperBound = 2;
wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound);
if (wavesPerEU > 1) {
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n");
gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU));
} else {
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set"
<< "\n");
}
} else {
LLVM_DEBUG(llvm::dbgs() << "arch not found.\n");
}
}

void LowerRockOpsToGPUPass::runOnOperation() {
ModuleOp op = getOperation();
MLIRContext *ctx = op.getContext();
Expand Down Expand Up @@ -204,6 +261,11 @@ void LowerRockOpsToGPUPass::runOnOperation() {
gridSize = cast<IntegerAttr>(gridSizeAttr).getInt();
gpuFunc.setKnownGridSizeAttr(b.getDenseI32ArrayAttr({gridSize, 1, 1}));

auto wavesPerEUAttr = theFunc->getAttr(rock::WavesPerEUAttr::getMnemonic());
if (wavesPerEUAttr) {
gpuFunc->setAttr(rock::WavesPerEUAttr::getMnemonic(), wavesPerEUAttr);
}

FailureOr<StringAttr> maybeArch = rock::getArch(theFunc);
if (succeeded(maybeArch)) {
gpuFunc->setAttr("arch", maybeArch.value());
Expand Down Expand Up @@ -391,61 +453,23 @@ void LowerRockOpsToGPUPass::runOnOperation() {
gpuFunc->setAttr("rock.shared_buffer_size",
b.getI32IntegerAttr(ldsUsage));
}
LLVM_DEBUG(llvm::dbgs() << "Attempting to set wavesPerEU...\n");
if (!gpuFunc->hasAttrOfType<IntegerAttr>("block_size")) {
LLVM_DEBUG(llvm::dbgs() << "blockSize not found in gpuFunc.\n");
return;
}
int64_t blockSize =
gpuFunc->getAttrOfType<IntegerAttr>("block_size").getInt();
if (!gpuFunc->hasAttrOfType<IntegerAttr>("grid_size")) {
LLVM_DEBUG(llvm::dbgs() << "gridSize not found in gpuFunc.\n");
return;
}
int64_t gridSize =
gpuFunc->getAttrOfType<IntegerAttr>("grid_size").getInt();
FailureOr<StringAttr> maybeArch = rock::getArch(gpuFunc);
if (succeeded(maybeArch)) {
StringAttr arch = maybeArch.value();
rock::AmdArchInfo archInfo = rock::lookupArchInfo(arch);
FailureOr<int64_t> maybeNumCU = rock::getNumCU(gpuFunc);
int64_t numCU = maybeNumCU.value_or(archInfo.minNumCU);
int64_t totalEUs = archInfo.numEUPerCU * numCU;
int64_t wavesPerBlock = (blockSize / archInfo.waveSize);
int64_t totalWaves = wavesPerBlock * gridSize;
int64_t wavesPerEUPerBlock = wavesPerBlock / archInfo.numEUPerCU;
int64_t wavesPerEUPerGrid = (totalWaves + totalEUs - 1) / totalEUs;
int64_t wavesPerEU = std::max(wavesPerEUPerBlock, wavesPerEUPerGrid);
LLVM_DEBUG(llvm::dbgs() << "wavesPerEU:" << wavesPerEU << "\n");
LLVM_DEBUG(llvm::dbgs() << " blockSize:" << blockSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " waveSize:" << archInfo.waveSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " gridSize:" << gridSize << "\n");
LLVM_DEBUG(llvm::dbgs() << " numCU:" << numCU << "\n");
LLVM_DEBUG(llvm::dbgs()
<< " numEUPerCU:" << archInfo.numEUPerCU << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "maxSharedMemPerWG:" << archInfo.maxSharedMemPerWG << "\n");
LLVM_DEBUG(llvm::dbgs() << "ldsUsage:" << ldsUsage << "\n");
// limit wavesPerEU based on lds usage
if (ldsUsage > 0) {
wavesPerEU =
std::min(wavesPerEU, archInfo.totalSharedMemPerCU / ldsUsage);
}
// Currently limiting wavesPerEU to be two
// it is a future to ticket to remove this constraint with further
// analysis
constexpr int64_t wavesPerEUUpperBound = 2;
wavesPerEU = std::min(wavesPerEU, wavesPerEUUpperBound);
if (wavesPerEU > 1) {
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu:" << wavesPerEU << "\n");
// if waves_per_eu is set, use it
if (gpuFunc->hasAttrOfType<IntegerAttr>(
rock::WavesPerEUAttr::getMnemonic())) {
int64_t wavesPerEU =
gpuFunc
->getAttrOfType<IntegerAttr>(rock::WavesPerEUAttr::getMnemonic())
.getInt();
// zero means, use heuristic
if (wavesPerEU != 0) {
gpuFunc->setAttr("rocdl.waves_per_eu", b.getI32IntegerAttr(wavesPerEU));
} else {
LLVM_DEBUG(llvm::dbgs() << "waves_per_eu not set"
<< "\n");
LLVM_DEBUG(llvm::dbgs() << "Setting waves_per_eu using tuning param\n");
// we are done
return;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it returning from here ? Doesn't it need to set some other attributes ?

Copy link
Contributor Author

@dhernandez0 dhernandez0 Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we are inside op.walk([](gpu::GPUFuncOp gpuFunc) {, only two attributes are set: rock.shared_buffer_size and rocdl.waves_per_eu. The code below is just the heuristic to set rocdl.waves_per_eu.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also was confused by this but yes, it looks correct. However, to make it more readable, we could move all the code below into a function like "runWavesPerEUHeuristic", that way it would be easier to understand why a return here

}
} else {
LLVM_DEBUG(llvm::dbgs() << "arch not found.\n");
}
// no "waves_per_eu" attribute, use heuristic
runWavesPerEUHeuristic(b, gpuFunc, ldsUsage);
});

if (gpuModCount == 0) {
Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3254,17 +3254,19 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr,
expectedNumTokens = 11;
break;
case 3:
expectedNumTokens = 12;
expectedNumTokens = 13;
break;
default:
llvm_unreachable("Unknown version of the perfConfig");
}
SmallVector<StringRef, 11> tokens;
SmallVector<StringRef> tokens;
SmallVector<int64_t> params;
tokens.reserve(expectedNumTokens);
params.reserve(expectedNumTokens);
rest.split(tokens, ',');
if (tokens.size() != expectedNumTokens) {
return {};
}
SmallVector<int64_t, 11> params;
llvm::transform(tokens, std::back_inserter(params), [](StringRef s) {
int param;
llvm::to_integer(s, param);
Expand Down Expand Up @@ -3298,11 +3300,12 @@ AttnPerfConfigAttr AttnPerfConfigAttr::get(StringAttr perfConfigStrAttr,
int64_t splitKFactor = version > 1 ? params[lastIdx++] : 1;
int64_t scheduleVersion = version > 1 ? params[lastIdx++] : 1;
int64_t outputSwizzle = version > 1 ? params[lastIdx++] : 2;
int64_t wavesPerEU = isV3 ? params[lastIdx++] : 0; // 0 -> use heuristic
int64_t forceUnroll = params[expectedNumTokens - 1] == 1;
return AttnPerfConfigAttr::get(
perfConfigStrAttr.getContext(), mPerBlockG0, mPerBlockG1, nPerBlockG0,
kpackPerBlock, mPerWave, nPerWave, mnPerXdl, kpack, splitKFactor,
scheduleVersion, outputSwizzle, forceUnroll);
scheduleVersion, outputSwizzle, wavesPerEU, forceUnroll);
}

//===-----------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ void AffixTuningParameters::affixTuningParametersImpl(
Attribute params0 = op.getGemm0Params().value_or(nullptr);
// set a default one if params is not provided
StringAttr perfConfigStrAttr =
builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,1");
builder.getStringAttr("attn:v3:32,32,32,32,32,32,16,1,1,1,2,0,1");
if (!params0) {
if (StringAttr mayBePerfConfigStrAttr =
dyn_cast_or_null<StringAttr>(op->getAttr("perf_config"))) {
Expand Down
Loading
Loading