@@ -294,15 +294,14 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
294294 return genericOp.getResult (0 );
295295}
296296
297- // Compute output = exp2(output - input)
298- static Value computeSubAndExp2 (OpBuilder &builder, Location loc,
299- AffineMap inputMap, AffineMap outputMap,
300- Value input, Value output) {
297+ // Compute output = exp2(output - input) when useExp2 is true, otherwise exp(output - input).
298+ static Value computeSubAndExp (OpBuilder &builder, Location loc,
299+ AffineMap inputMap, AffineMap outputMap,
300+ Value input, Value output, bool useExp2 ) {
301301 SmallVector<AffineMap> compressedMaps =
302302 compressUnusedDims (SmallVector<AffineMap>{inputMap, outputMap});
303303 inputMap = compressedMaps[0 ];
304304 outputMap = compressedMaps[1 ];
305-
306305 SmallVector<utils::IteratorType> iteratorTypes (inputMap.getNumDims (),
307306 utils::IteratorType::parallel);
308307 auto genericOp = linalg::GenericOp::create (
@@ -313,8 +312,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc,
313312 Value in = convertScalarToDtype (b, loc, args[0 ], args[1 ].getType (),
314313 /* isUnsignedCast=*/ false );
315314 Value diff = arith::SubFOp::create (b, loc, args[1 ], in);
316- Value weight = math::Exp2Op::create (b, loc, diff);
317- linalg::YieldOp::create (b, loc, weight);
315+ Operation *weight = useExp2 ? math::Exp2Op::create (b, loc, diff)
316+ : math::ExpOp::create (b, loc, diff);
317+ linalg::YieldOp::create (b, loc, weight->getResult (0 ));
318318 });
319319 return genericOp.getResult (0 );
320320}
@@ -350,15 +350,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
350350 std::optional<AffineMap> maskMap,
351351 SmallVector<OpFoldResult> iterationDomain,
352352 Type sElementType , Region &elementwiseRegion,
353- DictionaryAttr qkAttrs, bool lowPrecision) {
353+ DictionaryAttr qkAttrs, bool lowPrecision,
354+ bool useExp2) {
354355 MLIRContext *ctx = b.getContext ();
355- // Since we use exp2 for attention instead of the original exp, we have to
356+ // If using exp2 for attention instead of the original exp, we have to
356357 // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
357358 // have better support for exp2 (we verified that we gain some speedup on
358359 // some GPUs).
359- Value log2e = arith::ConstantOp::create (
360- b, loc, b.getFloatAttr (scale.getType (), M_LOG2E));
361- scale = arith::MulFOp::create (b, loc, scale, log2e);
360+ if (useExp2) {
361+ Value log2e = arith::ConstantOp::create (
362+ b, loc, b.getFloatAttr (scale.getType (), M_LOG2E));
363+ scale = arith::MulFOp::create (b, loc, scale, log2e);
364+ }
362365
363366 auto qETy = getElementTypeOrSelf (query.getType ());
364367
@@ -445,9 +448,12 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
445448 DictionaryAttr config = getDecompositionConfigAttr ();
446449
447450 DictionaryAttr qkAttrs, pvAttrs;
451+ bool useExp2 = true ; // Default to exp2 for backward compatibility
448452 if (config) {
449453 qkAttrs = config.getAs <DictionaryAttr>(getQKAttrStr ());
450454 pvAttrs = config.getAs <DictionaryAttr>(getPVAttrStr ());
455+ if (auto useExp2Attr = config.getAs <BoolAttr>(getUseExp2AttrStr ()))
456+ useExp2 = useExp2Attr.getValue ();
451457 }
452458 Value output = getOutput ();
453459
@@ -470,9 +476,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
470476 Type f32Type = b.getF32Type ();
471477
472478 // ---- QK Matmul + elementwise math ----
473- Value s = computeQKAndElementwise (loc, b, query, key, getScale (), mask, qMap,
474- kMap , sMap , getMaskMap (), sizes, f32Type ,
475- getRegion (), qkAttrs, lowPrecision);
479+ Value s = computeQKAndElementwise (
480+ loc, b, query, key, getScale (), mask, qMap, kMap , sMap , getMaskMap (),
481+ sizes, f32Type, getRegion (), qkAttrs, lowPrecision, useExp2 );
476482
477483 // ---- Softmax ----
478484
@@ -512,9 +518,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
512518 // max = rowMax(S)
513519 Value max = reduce<arith::MaximumFOp>(b, loc, sMap , maxMap, s, maxFill);
514520
515- // P = exp2(S - max)
521+ // P = exp2(S - max) or exp(S - max) depending on useExp2 flag
516522 AffineMap pMap = sMap ;
517- Value p = computeSubAndExp2 (b, loc, maxMap, sMap , max, s);
523+ Value p = computeSubAndExp (b, loc, maxMap, sMap , max, s, useExp2 );
518524
519525 // sum = rowSum(P)
520526 Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);
@@ -564,9 +570,12 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
564570 DictionaryAttr config = getDecompositionConfigAttr ();
565571
566572 DictionaryAttr qkAttrs, pvAttrs;
573+ bool useExp2 = true ; // Default to exp2 for backward compatibility
567574 if (config) {
568575 qkAttrs = config.getAs <DictionaryAttr>(getQKAttrStr ());
569576 pvAttrs = config.getAs <DictionaryAttr>(getPVAttrStr ());
577+ if (auto useExp2Attr = config.getAs <BoolAttr>(getUseExp2AttrStr ()))
578+ useExp2 = useExp2Attr.getValue ();
570579 }
571580
572581 FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get (
@@ -587,7 +596,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
587596 // ---- QK Matmul + elementwise math ----
588597 Value s = computeQKAndElementwise (
589598 loc, b, query, key, getScale (), mask, qMap, kMap , sMap , getMaskMap (),
590- sizes, elementType, getRegion (), qkAttrs, lowPrecision);
599+ sizes, elementType, getRegion (), qkAttrs, lowPrecision, useExp2 );
591600
592601 // TODO: This decomposition should be in a separate op called
593602 // "online softmax".
@@ -597,20 +606,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
597606 AffineMap maxMap = getMaxMap ();
598607 Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap , maxMap, s, oldMax);
599608
600- // norm = exp2(oldMax - newMax)
609+ // norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2
601610 // normMap = maxMap
602611 AffineMap normMap = getMaxMap ();
603- Value norm = computeSubAndExp2 (b, loc, maxMap, normMap, newMax, oldMax);
612+ Value norm =
613+ computeSubAndExp (b, loc, maxMap, normMap, newMax, oldMax, useExp2);
604614
605615 // normSum = norm * oldSum
606616 AffineMap sumMap = getSumMap ();
607617 Value normSum = elementwiseValueInPlace<arith::MulFOp>(b, loc, sumMap,
608618 normMap, oldSum, norm);
609619
610- // P = exp2(S - newMax)
620+ // P = exp2(S - newMax) or exp(S - newMax) depending on useExp2
611621 // PMap = SMap
612622 AffineMap pMap = sMap ;
613- Value p = computeSubAndExp2 (b, loc, maxMap, sMap , newMax, s);
623+ Value p = computeSubAndExp (b, loc, maxMap, sMap , newMax, s, useExp2 );
614624
615625 // newSum = normSum + rowSum(P)
616626 Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
0 commit comments