
Commit 881f6ed

[TOSA] MultiheadAttention legalization
- Legalize Torch scaled_dot_product_attention into TOSA by adding the necessary patterns in TorchToTosa.cpp plus backend type-conversion hooks.
- Introduce a detailed decomposition path for multi-head attention within DecomposeComplexOps.cpp, preparing inputs for TOSA lowering.
- Expand the PT1 e2e suite with a dedicated multi-head attention MLIR/Python test and drop the corresponding xfails now that the path works.

Signed-off-by: Cathal Corbett <[email protected]>
Change-Id: I96c17aefd25b979f1cf6e897d91d5a29f0a2fa85
1 parent ada7c2f commit 881f6ed
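
For reference, the decomposition added in DecomposeComplexOps.cpp below handles only the unmasked, dropout-free, non-causal, non-GQA case and lowers it to a plain matmul/softmax pipeline. A sketch of the computation it implements (the symbols Q, K, V, d_head, and s are illustrative, not taken from the patch):

```latex
s =
\begin{cases}
\text{scale} & \text{if the scale operand is provided} \\
1/\sqrt{d_{\text{head}}} & \text{otherwise}
\end{cases}
\qquad
\text{out} = \operatorname{softmax}\!\bigl(s \, Q K^{\top}\bigr)\, V
```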

File tree

6 files changed: +398 -11 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 65 additions & 1 deletion
@@ -4016,8 +4016,28 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
   transposedDims[dim0] = dim1;
   transposedDims[dim1] = dim0;
 
+  Type resultType = getTypeConverter()->convertType(op.getType());
+  if (auto rankedSelf = dyn_cast<RankedTensorType>(selfType)) {
+    SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
+                                         ShapedType::kDynamic);
+    if (rankedSelf.hasStaticShape()) {
+      auto staticShape = llvm::to_vector(
+          makeShapeTorchCompatible(rankedSelf.getShape()));
+      auto dim0Index = static_cast<size_t>(dim0);
+      auto dim1Index = static_cast<size_t>(dim1);
+      if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
+        std::swap(staticShape[dim0Index], staticShape[dim1Index]);
+      for (size_t i = 0; i < staticShape.size(); ++i)
+        transposedShape[i] = staticShape[i];
+    }
+    auto rankedResult = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedShape), rankedSelf.getElementType());
+    if (auto converted = getTypeConverter()->convertType(rankedResult))
+      resultType = converted;
+  }
+
   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
+      op, resultType, adaptor.getSelf(),
       rewriter.getDenseI32ArrayAttr(transposedDims));
 
   return success();
@@ -9393,6 +9413,32 @@ class ConvertTorchToTosa
 };
 } // namespace
 
+namespace {
+class FoldStaticToDynamicTensorCast
+    : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern<tensor::CastOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<RankedTensorType>(adaptor.getSource().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
+    if (!sourceType || !resultType)
+      return failure();
+    if (sourceType.getElementType() != resultType.getElementType())
+      return failure();
+    if (!sourceType.hasStaticShape())
+      return failure();
+    if (!resultType.hasStaticShape())
+      return failure();
+    if (sourceType == resultType)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+} // namespace
+
 void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   // The following ops are never the primary reason why lowering fails.
   // The backend contract only allows functions to return tensors thus there
@@ -9408,6 +9454,22 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   target.addLegalOp<ConstantDeviceOp>();
   target.addLegalOp<PrimListConstructOp>();
   target.addLegalOp<PrimTupleConstructOp>();
+  target.addDynamicallyLegalOp<tensor::CastOp>(
+      [](tensor::CastOp op) -> bool {
+        auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
+        auto resultType = dyn_cast<RankedTensorType>(op.getType());
+        if (!sourceType || !resultType)
+          return true;
+        if (sourceType.getElementType() != resultType.getElementType())
+          return true;
+        if (!sourceType.hasStaticShape())
+          return true;
+        if (!resultType.hasStaticShape())
+          return true;
+        if (sourceType == resultType)
+          return true;
+        return false;
+      });
 }
 
 std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
@@ -9729,6 +9791,8 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_CAST_ATENOP_PATTERN(AtenIntReprOp);
 #undef INSERT_CAST_ATENOP_PATTERN
 
+  patterns.add<FoldStaticToDynamicTensorCast>(typeConverter, context);
+
   return illegalOps;
 }
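
The tensor::CastOp legality callback and the FoldStaticToDynamicTensorCast pattern above repeat the same five checks. A minimal sketch of a shared predicate capturing them, assuming one wanted to factor it out (the helper name isStaticShapeRefiningCast is hypothetical and not part of this patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Mirrors the condition shared by the tensor::CastOp legality callback and
// FoldStaticToDynamicTensorCast: both sides are ranked tensors with the same
// element type, both shapes are fully static, and the two types differ.
static bool isStaticShapeRefiningCast(Type srcTy, Type dstTy) {
  auto src = dyn_cast<RankedTensorType>(srcTy);
  auto dst = dyn_cast<RankedTensorType>(dstTy);
  if (!src || !dst)
    return false;
  if (src.getElementType() != dst.getElementType())
    return false;
  return src.hasStaticShape() && dst.hasStaticShape() && src != dst;
}
```

With such a helper, the legality callback would reduce to `return !isStaticShapeRefiningCast(op.getSource().getType(), op.getType());` and the pattern's match to a single call before `rewriter.replaceOp(op, adaptor.getSource());`.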

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 216 additions & 0 deletions
@@ -2295,6 +2295,220 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace
 
+namespace {
+// Decompose scaled dot product attention into matmul/softmax pipeline when
+// there is no masking, dropout, causal, or GQA behaviour.
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(op,
+                                         "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
+    auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
+    auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
+    if (!queryTensorType || !keyTensorType || !valueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected tensor inputs");
+    if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
+        !valueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op,
+                                         "expected value tensor semantics");
+
+    Value oneInt =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value rank = rewriter.create<AtenDimOp>(loc, query);
+    Value lastDim = rewriter.create<AtenSubIntOp>(loc, rank, oneInt);
+    Value headDim = rewriter.create<AtenSizeIntOp>(loc, query, lastDim);
+    Value seqDimIndex = rewriter.create<AtenSubIntOp>(loc, lastDim, oneInt);
+    Value seqLen = rewriter.create<AtenSizeIntOp>(loc, query, seqDimIndex);
+    Value keySeqLen = rewriter.create<AtenSizeIntOp>(loc, key, seqDimIndex);
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    bool hasExplicitHeadDim = querySizes.size() >= 4;
+    Value numHeadsSize = hasExplicitHeadDim
+                             ? (Value)rewriter.create<AtenSizeIntOp>(loc, query,
+                                                                     oneInt)
+                             : oneInt;
+    Value batchSize = rewriter.create<AtenSizeIntOp>(loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(
+            rewriter, loc, rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = rewriter.create<AtenSqrtIntOp>(loc, headDim);
+      Value oneFloat = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getF64FloatAttr(1.0));
+      scaleFloat = rewriter.create<AtenDivFloatOp>(loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
+    Value negTwo =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    if (keyTransposedSizes.size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected key tensor rank >= 2 for transpose");
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
+    std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
+    Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+        keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+        keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = rewriter.create<AtenTransposeIntOp>(
+        loc, keyTransposedType, key, negTwo, negOne);
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                 : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (keyTransposedSizes.size() == 4) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, seqLen));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList = rewriter.create<PrimListConstructOp>(
+        loc, listIntType, ValueRange(keyDims));
+    keyTransposed = rewriter.create<AtenViewOp>(loc, keyTransposedType,
+                                                keyTransposed,
+                                                keyTransposeShapeList);
+
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t, 4> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores = rewriter.create<AtenMatmulOp>(loc, scoresType, query,
+                                                 keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList = rewriter.create<PrimListConstructOp>(
+        loc, listIntType, ValueRange(scoresDims));
+    scores = rewriter.create<AtenViewOp>(loc, scoresType, scores,
+                                         scoresShapeList);
+    Value scaledScores = rewriter.create<AtenMulScalarOp>(
+        loc, scoresType, scores, scaleFloat);
+
+    SmallVector<int64_t> reducedSizes(scoresSizes.begin(), scoresSizes.end());
+    reducedSizes.back() = 1;
+    ArrayRef<int64_t> reducedSizesRef(reducedSizes);
+    std::optional<ArrayRef<int64_t>> reducedSizesOpt(reducedSizesRef);
+    Type reducedValueType =
+        ValueTensorType::get(op->getContext(), reducedSizesOpt,
+                             queryValueTensorType.getOptionalDtype());
+    Type reducedIndexType =
+        ValueTensorType::get(op->getContext(), reducedSizesOpt,
+                             IntegerType::get(op->getContext(), 64,
+                                              IntegerType::Signed));
+    Value keepDimTrue =
+        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
+    auto maxOp = rewriter.create<AtenMaxDimOp>(
+        loc, reducedValueType, reducedIndexType, scaledScores, negOne,
+        keepDimTrue);
+    Value softmaxMax = rewriter.create<TensorStaticInfoCastOp>(
+        loc, reducedValueType, maxOp.getValues());
+    Value centered = createTensorSub(rewriter, loc, scoresType, scaledScores,
+                                     softmaxMax);
+    Value unNormalizedExp =
+        rewriter.create<AtenExpOp>(loc, scoresType, centered);
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, listIntType, ValueRange(negOne));
+    Value noneValue = rewriter.create<ConstantNoneOp>(loc);
+    Value softmaxDenominator = rewriter.create<AtenSumDimIntListOp>(
+        loc, reducedValueType, unNormalizedExp, dimList, keepDimTrue,
+        noneValue);
+    softmaxDenominator = rewriter.create<TensorStaticInfoCastOp>(
+        loc, reducedValueType, softmaxDenominator);
+    Value softmax = rewriter.create<AtenDivTensorOp>(
+        loc, scoresType, unNormalizedExp, softmaxDenominator);
+
+    Value output = rewriter.create<AtenMatmulOp>(
+        loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:
@@ -13086,6 +13300,8 @@ class DecomposeComplexOpsPass
 
     populateTransformerEncoderPatterns(patterns, legalOpsSet);
 
+    patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
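
Like the existing softmax helper referenced in the context lines above, the new pattern avoids overflow by subtracting the row-wise maximum before exponentiating (the AtenMaxDimOp / AtenExpOp / AtenSumDimIntListOp / AtenDivTensorOp sequence). In symbols, a sketch with x denoting the scaled scores reduced along the last dimension:

```latex
\operatorname{softmax}(x)_i
  = \frac{e^{\,x_i - \max_j x_j}}{\sum_k e^{\,x_k - \max_j x_j}}
```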

lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp

Lines changed: 18 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -40,6 +41,23 @@ static void setupValueTensorToBuiltinTensorConversion(
           return {};
         return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
       });
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &builder, Type type, ValueRange inputs,
+         Location loc) -> Value {
+        if (inputs.size() != 1)
+          return Value();
+        auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
+        auto toType = dyn_cast<RankedTensorType>(type);
+        if (!fromType || !toType)
+          return Value();
+        if (fromType == toType)
+          return inputs[0];
+        if (fromType.getElementType() != toType.getElementType())
+          return Value();
+        if (!toType.hasStaticShape())
+          return Value();
+        return builder.create<tensor::CastOp>(loc, toType, inputs[0]);
+      });
   auto sourceMaterialization = [](OpBuilder &builder,
                                   Torch::ValueTensorType type,
                                   ValueRange inputs, Location loc) -> Value {
