@@ -2295,6 +2295,220 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
22952295};
22962296} // namespace
22972297
2298+ namespace {
2299+ // Decompose scaled dot product attention into matmul/softmax pipeline when
2300+ // there is no masking, dropout, causal, or GQA behaviour.
2301+ class DecomposeAtenScaledDotProductAttentionOp
2302+ : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
2303+ public:
2304+ using OpRewritePattern::OpRewritePattern;
2305+ LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
2306+ PatternRewriter &rewriter) const override {
2307+ Location loc = op.getLoc();
2308+
2309+ if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
2310+ return rewriter.notifyMatchFailure(
2311+ op, "attention mask decomposition not implemented");
2312+
2313+ double dropoutP;
2314+ if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
2315+ dropoutP != 0.0)
2316+ return rewriter.notifyMatchFailure(
2317+ op, "expected dropout_p to be the constant 0.0");
2318+
2319+ bool isCausal;
2320+ if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
2321+ isCausal)
2322+ return rewriter.notifyMatchFailure(op,
2323+ "causal attention not supported yet");
2324+
2325+ bool enableGqa;
2326+ if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
2327+ enableGqa)
2328+ return rewriter.notifyMatchFailure(op,
2329+ "grouped-query attention unsupported");
2330+
2331+ Value query = op.getQuery();
2332+ Value key = op.getKey();
2333+ Value value = op.getValue();
2334+
2335+ auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
2336+ auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
2337+ auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
2338+ if (!queryTensorType || !keyTensorType || !valueTensorType)
2339+ return rewriter.notifyMatchFailure(op, "expected tensor inputs");
2340+ if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
2341+ !valueTensorType.hasSizes())
2342+ return rewriter.notifyMatchFailure(
2343+ op, "expected tensor inputs to have known shapes");
2344+ auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
2345+ auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
2346+ auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
2347+ if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
2348+ return rewriter.notifyMatchFailure(op,
2349+ "expected value tensor semantics");
2350+
2351+ Value oneInt =
2352+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
2353+ Value zeroInt =
2354+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
2355+ Value rank = rewriter.create<AtenDimOp>(loc, query);
2356+ Value lastDim = rewriter.create<AtenSubIntOp>(loc, rank, oneInt);
2357+ Value headDim = rewriter.create<AtenSizeIntOp>(loc, query, lastDim);
2358+ Value seqDimIndex = rewriter.create<AtenSubIntOp>(loc, lastDim, oneInt);
2359+ Value seqLen = rewriter.create<AtenSizeIntOp>(loc, query, seqDimIndex);
2360+ Value keySeqLen = rewriter.create<AtenSizeIntOp>(loc, key, seqDimIndex);
2361+ ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
2362+ bool hasExplicitHeadDim = querySizes.size() >= 4;
2363+ Value numHeadsSize = hasExplicitHeadDim
2364+ ? (Value)rewriter.create<AtenSizeIntOp>(loc, query,
2365+ oneInt)
2366+ : oneInt;
2367+ Value batchSize = rewriter.create<AtenSizeIntOp>(loc, query, zeroInt);
2368+ auto listIntType =
2369+ Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
2370+
2371+ auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
2372+ if (staticDim != Torch::kUnknownSize)
2373+ return ConstantIntOp::create(
2374+ rewriter, loc, rewriter.getI64IntegerAttr(staticDim));
2375+ return fallback;
2376+ };
2377+
2378+ Value scaleFloat;
2379+ if (isa<Torch::NoneType>(op.getScale().getType())) {
2380+ Value sqrtHeadDim = rewriter.create<AtenSqrtIntOp>(loc, headDim);
2381+ Value oneFloat = rewriter.create<ConstantFloatOp>(
2382+ loc, rewriter.getF64FloatAttr(1.0));
2383+ scaleFloat = rewriter.create<AtenDivFloatOp>(loc, oneFloat, sqrtHeadDim);
2384+ } else {
2385+ scaleFloat = op.getScale();
2386+ }
2387+
2388+ Value negTwo =
2389+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-2));
2390+ Value negOne =
2391+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
2392+
2393+ ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
2394+ SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
2395+ if (keyTransposedSizes.size() < 2)
2396+ return rewriter.notifyMatchFailure(
2397+ op, "expected key tensor rank >= 2 for transpose");
2398+ std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
2399+ keyTransposedSizes[keyTransposedSizes.size() - 2]);
2400+ ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
2401+ std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
2402+ Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
2403+ keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
2404+ keyValueTensorType.getOptionalSparsity());
2405+ Value keyTransposed = rewriter.create<AtenTransposeIntOp>(
2406+ loc, keyTransposedType, key, negTwo, negOne);
2407+ SmallVector<Value> keyDims;
2408+ auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
2409+ Value fallback) -> Value {
2410+ return getDimValue(idx < staticDims.size() ? staticDims[idx]
2411+ : Torch::kUnknownSize,
2412+ fallback);
2413+ };
2414+ keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
2415+ if (keyTransposedSizes.size() == 4) {
2416+ keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
2417+ keyDims.push_back(getOrFallback(keyTransposedSizes, 2, seqLen));
2418+ keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
2419+ } else {
2420+ keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
2421+ keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
2422+ }
2423+ Value keyTransposeShapeList = rewriter.create<PrimListConstructOp>(
2424+ loc, listIntType, ValueRange(keyDims));
2425+ keyTransposed = rewriter.create<AtenViewOp>(loc, keyTransposedType,
2426+ keyTransposed,
2427+ keyTransposeShapeList);
2428+
2429+ auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
2430+ if (index < 0)
2431+ index += sizes.size();
2432+ if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
2433+ return Torch::kUnknownSize;
2434+ return sizes[index];
2435+ };
2436+ int64_t queryBatchStatic = getStaticDim(querySizes, 0);
2437+ int64_t querySeqStatic = getStaticDim(querySizes, -2);
2438+ int64_t keySeqStatic = getStaticDim(keySizes, -2);
2439+ int64_t queryHeadsStatic =
2440+ hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
2441+ SmallVector<int64_t, 4> scoresSizes;
2442+ if (hasExplicitHeadDim)
2443+ scoresSizes.assign(
2444+ {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
2445+ else
2446+ scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
2447+ Type scoresType = ValueTensorType::get(
2448+ op->getContext(),
2449+ ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
2450+ queryValueTensorType.getOptionalDtype(),
2451+ queryValueTensorType.getOptionalSparsity());
2452+ Value scores = rewriter.create<AtenMatmulOp>(loc, scoresType, query,
2453+ keyTransposed);
2454+ SmallVector<Value> scoresDims;
2455+ scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
2456+ unsigned seqIndex = 1;
2457+ if (hasExplicitHeadDim) {
2458+ scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
2459+ seqIndex = 2;
2460+ }
2461+ scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
2462+ scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
2463+ Value scoresShapeList = rewriter.create<PrimListConstructOp>(
2464+ loc, listIntType, ValueRange(scoresDims));
2465+ scores = rewriter.create<AtenViewOp>(loc, scoresType, scores,
2466+ scoresShapeList);
2467+ Value scaledScores = rewriter.create<AtenMulScalarOp>(
2468+ loc, scoresType, scores, scaleFloat);
2469+
2470+ SmallVector<int64_t> reducedSizes(scoresSizes.begin(), scoresSizes.end());
2471+ reducedSizes.back() = 1;
2472+ ArrayRef<int64_t> reducedSizesRef(reducedSizes);
2473+ std::optional<ArrayRef<int64_t>> reducedSizesOpt(reducedSizesRef);
2474+ Type reducedValueType =
2475+ ValueTensorType::get(op->getContext(), reducedSizesOpt,
2476+ queryValueTensorType.getOptionalDtype());
2477+ Type reducedIndexType =
2478+ ValueTensorType::get(op->getContext(), reducedSizesOpt,
2479+ IntegerType::get(op->getContext(), 64,
2480+ IntegerType::Signed));
2481+ Value keepDimTrue =
2482+ rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
2483+ auto maxOp = rewriter.create<AtenMaxDimOp>(
2484+ loc, reducedValueType, reducedIndexType, scaledScores, negOne,
2485+ keepDimTrue);
2486+ Value softmaxMax = rewriter.create<TensorStaticInfoCastOp>(
2487+ loc, reducedValueType, maxOp.getValues());
2488+ Value centered = createTensorSub(rewriter, loc, scoresType, scaledScores,
2489+ softmaxMax);
2490+ Value unNormalizedExp =
2491+ rewriter.create<AtenExpOp>(loc, scoresType, centered);
2492+ Value dimList = rewriter.create<PrimListConstructOp>(
2493+ loc, listIntType, ValueRange(negOne));
2494+ Value noneValue = rewriter.create<ConstantNoneOp>(loc);
2495+ Value softmaxDenominator = rewriter.create<AtenSumDimIntListOp>(
2496+ loc, reducedValueType, unNormalizedExp, dimList, keepDimTrue,
2497+ noneValue);
2498+ softmaxDenominator = rewriter.create<TensorStaticInfoCastOp>(
2499+ loc, reducedValueType, softmaxDenominator);
2500+ Value softmax = rewriter.create<AtenDivTensorOp>(
2501+ loc, scoresType, unNormalizedExp, softmaxDenominator);
2502+
2503+ Value output = rewriter.create<AtenMatmulOp>(
2504+ loc, op.getType(), softmax, value);
2505+
2506+ rewriter.replaceOp(op, output);
2507+ return success();
2508+ }
2509+ };
2510+ } // namespace
2511+
22982512// Calculates the softmax function on the given `input` tensor. Softmax(x) =
22992513// exp(x)/sum(exp(x)).
23002514// To avoid overflow we use the following decomposition rule:
@@ -13086,6 +13300,8 @@ class DecomposeComplexOpsPass
1308613300
1308713301 populateTransformerEncoderPatterns(patterns, legalOpsSet);
1308813302
13303+ patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
13304+
1308913305 addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
1309013306 patterns);
1309113307 addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
0 commit comments