
Commit f860ccc

Added changes for correct computations:

PyTorch's attention already supplies the correct scale for a base-e softmax, so the files below are changed to support exp computation instead of exp2:

1. ConvertTorchUnstructuredToLinalgExt.cpp: the FlexAttentionOpConversion pattern now passes use_exp2 = false, which the decomposition honors.
2. AggregatedOpInterfaceImpl.cpp: accepts the use_exp2 flag as a decomposition-config attribute and calls computeSubAndExp accordingly.
3. LinalgExtOps.td: added getUseExp2AttrStr() to both online_attention and attention.
4. ReshapeFusion.cpp: createCollapsedOp was recreating the attention op and stripping all of its attributes before StripCompilationPass; it now preserves decomposition_config, which is necessary for correct decomposition.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Parent: fca1fe8

File tree: 4 files changed, +59 −22 lines

compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp

Lines changed: 11 additions & 0 deletions
@@ -409,12 +409,23 @@ struct FlexAttentionOpConversion
 
     indexingMaps.push_back(oMap);
 
+    // Create decomposition config with use_exp2 = false: PyTorch's attention
+    // already supplies the correct scale for a base-e softmax, so the
+    // decomposition must use exp rather than exp2.
+    SmallVector<NamedAttribute> decompositionConfigAttrs;
+    decompositionConfigAttrs.push_back(
+        rewriter.getNamedAttr("use_exp2", rewriter.getBoolAttr(false)));
+    DictionaryAttr decompositionConfig =
+        rewriter.getDictionaryAttr(decompositionConfigAttrs);
+
     // Create attention op
     auto attentionOp = IREE::LinalgExt::AttentionOp::create(
         rewriter, loc, outputTensor.getType(), builtinQuery, builtinKey,
         builtinValue, scale, outputTensor,
         rewriter.getAffineMapArrayAttr(indexingMaps), mask);
 
+    // Set decomposition config
+    attentionOp.setDecompositionConfigAttr(decompositionConfig);
+
     {
       OpBuilder::InsertionGuard g(rewriter);
       Block *block = rewriter.createBlock(&attentionOp.getRegion());
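For reference, a minimal sketch (not from the commit; the helper name readUseExp2 is illustrative) of how a consumer can read the flag back out of the dictionary attribute built above:

    // Illustrative helper, mirroring the lookup the decomposition performs.
    static bool readUseExp2(mlir::DictionaryAttr config, bool defaultValue) {
      if (!config)
        return defaultValue;
      if (auto flag = config.getAs<mlir::BoolAttr>("use_exp2"))
        return flag.getValue();
      return defaultValue;
    }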

compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp

Lines changed: 32 additions & 22 deletions
@@ -294,15 +294,14 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
   return genericOp.getResult(0);
 }
 
-// Compute output = exp2(output - input)
-static Value computeSubAndExp2(OpBuilder &builder, Location loc,
-                               AffineMap inputMap, AffineMap outputMap,
-                               Value input, Value output) {
+// Compute output = exp2/exp(output - input) depending on useExp2 flag.
+static Value computeSubAndExp(OpBuilder &builder, Location loc,
+                              AffineMap inputMap, AffineMap outputMap,
+                              Value input, Value output, bool useExp2) {
   SmallVector<AffineMap> compressedMaps =
       compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
   inputMap = compressedMaps[0];
   outputMap = compressedMaps[1];
-
   SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
                                                  utils::IteratorType::parallel);
   auto genericOp = linalg::GenericOp::create(

@@ -313,8 +312,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc,
         Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
                                         /*isUnsignedCast=*/false);
         Value diff = arith::SubFOp::create(b, loc, args[1], in);
-        Value weight = math::Exp2Op::create(b, loc, diff);
-        linalg::YieldOp::create(b, loc, weight);
+        Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff)
+                                    : math::ExpOp::create(b, loc, diff);
+        linalg::YieldOp::create(b, loc, weight->getResult(0));
       });
   return genericOp.getResult(0);
 }

@@ -350,15 +350,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
                               std::optional<AffineMap> maskMap,
                               SmallVector<OpFoldResult> iterationDomain,
                               Type sElementType, Region &elementwiseRegion,
-                              DictionaryAttr qkAttrs, bool lowPrecision) {
+                              DictionaryAttr qkAttrs, bool lowPrecision,
+                              bool useExp2) {
   MLIRContext *ctx = b.getContext();
-  // Since we use exp2 for attention instead of the original exp, we have to
+  // If using exp2 for attention instead of the original exp, we have to
   // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
   // have better support for exp2 (we verified that we gain some speedup on
   // some GPUs).
-  Value log2e = arith::ConstantOp::create(
-      b, loc, b.getFloatAttr(scale.getType(), M_LOG2E));
-  scale = arith::MulFOp::create(b, loc, scale, log2e);
+  if (useExp2) {
+    Value log2e = arith::ConstantOp::create(
+        b, loc, b.getFloatAttr(scale.getType(), M_LOG2E));
+    scale = arith::MulFOp::create(b, loc, scale, log2e);
+  }
 
   auto qETy = getElementTypeOrSelf(query.getType());
 
@@ -445,9 +448,12 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
   DictionaryAttr config = getDecompositionConfigAttr();
 
   DictionaryAttr qkAttrs, pvAttrs;
+  bool useExp2 = true; // Default to exp2 for backward compatibility
   if (config) {
     qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
     pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+    if (auto useExp2Attr = config.getAs<BoolAttr>(getUseExp2AttrStr()))
+      useExp2 = useExp2Attr.getValue();
   }
   Value output = getOutput();
 
@@ -470,9 +476,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
   Type f32Type = b.getF32Type();
 
   // ---- QK Matmul + elementwise math ----
-  Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap,
-                                    kMap, sMap, getMaskMap(), sizes, f32Type,
-                                    getRegion(), qkAttrs, lowPrecision);
+  Value s = computeQKAndElementwise(
+      loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
+      sizes, f32Type, getRegion(), qkAttrs, lowPrecision, useExp2);
 
   // ---- Softmax ----
 
@@ -512,9 +518,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
   // max = rowMax(S)
   Value max = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, maxFill);
 
-  // P = exp2(S - max)
+  // P = exp2(S - max) or exp(S - max) depending on useExp2 flag
   AffineMap pMap = sMap;
-  Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s);
+  Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, useExp2);
 
   // sum = rowSum(P)
   Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);

@@ -564,9 +570,12 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   DictionaryAttr config = getDecompositionConfigAttr();
 
   DictionaryAttr qkAttrs, pvAttrs;
+  bool useExp2 = true; // Default to exp2 for backward compatibility
   if (config) {
     qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
     pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+    if (auto useExp2Attr = config.getAs<BoolAttr>(getUseExp2AttrStr()))
+      useExp2 = useExp2Attr.getValue();
   }
 
   FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(

@@ -587,7 +596,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   // ---- QK Matmul + elementwise math ----
   Value s = computeQKAndElementwise(
       loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
-      sizes, elementType, getRegion(), qkAttrs, lowPrecision);
+      sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2);
 
   // TODO: This decomposition should be in a seperate op called
   // "online softmax".

@@ -597,20 +606,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap maxMap = getMaxMap();
   Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);
 
-  // norm = exp2(oldMax - newMax)
+  // norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2
   // normMap = maxMap
   AffineMap normMap = getMaxMap();
-  Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax);
+  Value norm =
+      computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2);
 
   // normSum = norm * oldSum
   AffineMap sumMap = getSumMap();
   Value normSum = elementwiseValueInPlace<arith::MulFOp>(b, loc, sumMap,
                                                          normMap, oldSum, norm);
 
-  // P = exp2(S - newMax)
+  // P = exp2(S - newMax) or exp(S - newMax) depending on useExp2
   // PMap = SMap
   AffineMap pMap = sMap;
-  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
+  Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2);
 
   // newSum = normSum + rowSum(P)
   Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
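The online-attention hunks implement the running-softmax update in whichever base b the flag selects: norm = b^(oldMax − newMax), newSum = norm · oldSum + rowSum(b^(S − newMax)). A scalar C++ sketch of that recurrence (illustrative only; the real code operates on tensors via linalg generics):

    #include <cmath>
    #include <vector>

    // Scalar model of the online softmax accumulator; base is 2.0 or e,
    // mirroring the use_exp2 choice.
    struct OnlineSoftmax {
      double base, maxVal = -INFINITY, sum = 0.0;
      void update(const std::vector<double> &row) {
        double newMax = maxVal;
        for (double s : row) newMax = std::max(newMax, s);
        // norm = base^(oldMax - newMax) rescales the old partial sum.
        double norm = std::pow(base, maxVal - newMax);
        sum = norm * sum;
        // newSum = normSum + rowSum(base^(S - newMax)).
        for (double s : row) sum += std::pow(base, s - newMax);
        maxVal = newMax;
      }
    };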

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 4 additions & 0 deletions
@@ -871,6 +871,8 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
     // Attributes to set on QK and PV matmul after decomposition.
     static StringRef getQKAttrStr() { return "qk_attrs"; }
     static StringRef getPVAttrStr() { return "pv_attrs"; }
+    // Flag to control whether to use exp2 (with log2(e) scaling) or exp.
+    static StringRef getUseExp2AttrStr() { return "use_exp2"; }
   }];
 
   let hasCanonicalizer = 1;

@@ -1013,6 +1015,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention",
     // Attributes to set on QK and PV matmul after decomposition.
     static StringRef getQKAttrStr() { return "qk_attrs"; }
     static StringRef getPVAttrStr() { return "pv_attrs"; }
+    // Flag to control whether to use exp2 (with log2(e) scaling) or exp.
+    static StringRef getUseExp2AttrStr() { return "use_exp2"; }
   }];
 
   let hasCanonicalizer = 1;
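Since both ops return the same "use_exp2" string, code that builds the config can key the dictionary off the accessor rather than a string literal; a small sketch (surrounding rewriter context assumed, not part of the commit):

    // Sketch: use the op-provided key instead of repeating "use_exp2".
    StringRef key = IREE::LinalgExt::AttentionOp::getUseExp2AttrStr();
    DictionaryAttr config = rewriter.getDictionaryAttr(
        rewriter.getNamedAttr(key, rewriter.getBoolAttr(false)));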

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

Lines changed: 12 additions & 0 deletions
@@ -1055,6 +1055,12 @@ static Operation *createCollapsedOp(AttentionOp origOp,
       rewriter, origOp.getLoc(), resultTypes, inputOperands[0],
       inputOperands[1], inputOperands[2], inputOperands[3], outputOperands[0],
       rewriter.getAffineMapArrayAttr(indexingMaps), maskOperand);
+
+  // Preserve decomposition_config attribute from original op
+  if (auto config = origOp.getDecompositionConfigAttr()) {
+    collapsedOp.setDecompositionConfigAttr(config);
+  }
+
   rewriter.inlineRegionBefore(origOp.getRegion(), collapsedOp.getRegion(),
                               collapsedOp.getRegion().begin());
   return collapsedOp;

@@ -1152,6 +1158,12 @@ struct DropAttentionUnitDims final
       newOperands.take_front(attentionOp.getNumDpsInputs()),
       newOperands.take_back(attentionOp.getNumDpsInits()),
       b.getAffineMapArrayAttr(newIndexingMaps));
+
+  // Preserve decomposition_config attribute from original op
+  if (auto config = attentionOp.getDecompositionConfigAttr()) {
+    newOp.setDecompositionConfigAttr(config);
+  }
+
   b.cloneRegionBefore(attentionOp.getRegion(), newOp.getRegion(),
                       newOp.getRegion().begin());
   return newOp;
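Both fixes follow the same pattern, which could be factored into a tiny helper (hypothetical, not part of the commit):

    // Illustrative helper: carry the decomposition config across op rebuilds;
    // without it, decomposition silently falls back to the exp2 default.
    template <typename AttentionLikeOp>
    static void preserveDecompositionConfig(AttentionLikeOp from,
                                            AttentionLikeOp to) {
      if (auto config = from.getDecompositionConfigAttr())
        to.setDecompositionConfigAttr(config);
    }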
