From ada7c2fa22afa41b9959d9d83230edb35099874d Mon Sep 17 00:00:00 2001
From: Cathal Corbett
Date: Mon, 17 Nov 2025 14:41:35 +0100
Subject: [PATCH] [TORCH] Transformer encoder decomposition

- Add a dedicated DecomposeTransformerEncoder pass to expand encoder ops
  into primitive Torch patterns.
- Extend shared lowering helpers (ReduceOpVariants.cpp, Utils.h) so the
  new pass can reuse reduction utilities during decomposition.
- Register the pass in the Torch Transform pipeline so it runs as part
  of the decomposition flow.
- Expand e2e coverage with new transformer encoder tests to validate the
  lowering path.

Signed-off-by: Cathal Corbett
Change-Id: I6bcda53569cf7b06df4cb97c624bbf512d8fecb7
---
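
For reference, a minimal sketch (not part of the patch; shapes and the
tracing approach are illustrative, and the exact fastpath conditions vary
across PyTorch releases) of how the fused op this pass targets arises: an
eval-mode TransformerEncoderLayer with batch_first=True and no masks takes
PyTorch's inference fastpath, which folds the whole layer into a single
_transformer_encoder_layer_fwd call:

    import torch

    layer = torch.nn.TransformerEncoderLayer(
        d_model=8, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True
    ).eval()
    with torch.no_grad():
        traced = torch.jit.trace(layer, torch.randn(1, 4, 8))
    # The traced graph should contain torch._transformer_encoder_layer_fwd,
    # which torch-mlir imports as
    # torch.operator "torch.aten._transformer_encoder_layer_fwd".
    print(traced.graph)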
 .../Dialect/Torch/Transforms/Passes.h         |   5 +
 lib/Dialect/Torch/Transforms/CMakeLists.txt   |   1 +
 .../Torch/Transforms/DecomposeComplexOps.cpp  |   2 +
 .../DecomposeTransformerEncoder.cpp           | 465 ++++++++++++++++++
 .../Torch/Transforms/ReduceOpVariants.cpp     |  13 +-
 lib/Dialect/Torch/Transforms/Utils.h          |  34 ++
 projects/pt1/e2e_testing/xfail_sets.py        |   1 +
 .../test_suite/__init__.py                    |   1 +
 .../test_suite/transformer.py                 |  42 ++
 .../python/transformer_encoder_lowering.py    |  74 +++
 transformer_encoder.mlir                      |  54 ++
 11 files changed, 691 insertions(+), 1 deletion(-)
 create mode 100644 lib/Dialect/Torch/Transforms/DecomposeTransformerEncoder.cpp
 create mode 100644 lib/Dialect/Torch/Transforms/Utils.h
 create mode 100755 projects/pt1/python/torch_mlir_e2e_test/test_suite/transformer.py
 create mode 100644 projects/pt1/test/python/transformer_encoder_lowering.py
 create mode 100644 transformer_encoder.mlir

diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index 13d3a8de9463..6b6ab2238f4a 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/StringSet.h"
 
 #include <memory>
 
@@ -157,6 +158,10 @@ static const char kTorchOpPrefix[] = R"(torch.)";
 void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
                                                MLIRContext *context);
 
+void populateTransformerEncoderPatterns(
+    RewritePatternSet &patterns, const llvm::StringSet<> &legalOpsSet);
+
+
 std::unique_ptr<OperationPass<func::FuncOp>>
 createRestructureNonConstantAxesPass();
 
diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt
index 1ce006fbe913..c6f6c1944b22 100644
--- a/lib/Dialect/Torch/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_library(TorchMLIRTorchPasses
   AdjustCallingConventions.cpp
   DecomposeComplexOps.cpp
+  DecomposeTransformerEncoder.cpp
   DropAbstractInterpCalculations.cpp
   EraseModuleInitializer.cpp
   FuseQuantizedOps.cpp
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 08b25c9b6f60..8e2c22c027cf 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -13084,6 +13084,8 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());
 
+    populateTransformerEncoderPatterns(patterns, legalOpsSet);
+
     addPatternIfTargetOpIsIllegal(
         patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/DecomposeTransformerEncoder.cpp b/lib/Dialect/Torch/Transforms/DecomposeTransformerEncoder.cpp
new file mode 100644
index 000000000000..11c7237dbf90
--- /dev/null
+++ b/lib/Dialect/Torch/Transforms/DecomposeTransformerEncoder.cpp
@@ -0,0 +1,465 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "llvm/ADT/Twine.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+
+#include <cmath>
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+
+static Value createIntConstant(PatternRewriter &rewriter, Location loc,
+                               int64_t value) {
+  return rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(value));
+}
+
+static Value createBoolConstant(PatternRewriter &rewriter, Location loc,
+                                bool value) {
+  return rewriter.create<Torch::ConstantBoolOp>(loc,
+                                                rewriter.getBoolAttr(value));
+}
+
+static Value createFloatConstant(PatternRewriter &rewriter, Location loc,
+                                 double value, Type /*dtype*/) {
+  auto attr = rewriter.getF64FloatAttr(value);
+  return rewriter.create<Torch::ConstantFloatOp>(loc, attr);
+}
+
+static Value createIntList(PatternRewriter &rewriter, Location loc,
+                           ArrayRef<int64_t> values) {
+  SmallVector<Value> elems;
+  elems.reserve(values.size());
+  for (int64_t v : values)
+    elems.push_back(createIntConstant(rewriter, loc, v));
+  auto listType =
+      Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+  return rewriter.create<Torch::PrimListConstructOp>(loc, listType, elems);
+}
+
+static FailureOr<ValueTensorType> expectRankedTensor(Value v, int64_t rank,
+                                                     StringRef name,
+                                                     PatternRewriter &rewriter,
+                                                     Operation *op) {
+  auto tensorType = dyn_cast<ValueTensorType>(v.getType());
+  if (!tensorType || !tensorType.hasSizes() ||
+      tensorType.getSizes().size() != static_cast<size_t>(rank))
+    return rewriter.notifyMatchFailure(
+        op, Twine("expected ") + name + " tensor of rank " + Twine(rank));
+  return tensorType;
+}
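+
+// aten.view expects -1 for dimensions whose static size is unknown, so map
+// Torch::kUnknownSize to -1 when materializing shape lists.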
+static int64_t adaptSizeForView(int64_t size) {
+  return size == Torch::kUnknownSize ? -1 : size;
+}
+
+static FailureOr<ValueTensorType>
+checkTensorShape(Value value, ArrayRef<int64_t> expected, StringRef name,
+                 PatternRewriter &rewriter, Operation *op) {
+  auto type = dyn_cast<ValueTensorType>(value.getType());
+  if (!type || !type.hasSizes())
+    return rewriter.notifyMatchFailure(
+        op, Twine("expected tensor operands with known sizes for ") + name);
+  ArrayRef<int64_t> actual = type.getSizes();
+  if (actual.size() != expected.size())
+    return rewriter.notifyMatchFailure(op, Twine("rank mismatch for ") + name);
+  for (size_t i = 0; i < expected.size(); ++i) {
+    int64_t exp = expected[i];
+    if (exp == Torch::kUnknownSize)
+      continue;
+    if (actual[i] != exp && actual[i] != Torch::kUnknownSize)
+      return rewriter.notifyMatchFailure(
+          op, Twine("dimension mismatch for ") + name + " at index " +
+                  Twine(i));
+  }
+  return type;
+}
+
+struct QkvProjections {
+  Value query;
+  Value key;
+  Value value;
+};
+
+static FailureOr<QkvProjections>
+buildQkvProjections(PatternRewriter &rewriter, Location loc, Value input,
+                    Value qkvWeight, Value qkvBias, int64_t embedDim,
+                    int64_t numHeads, ValueTensorType inputType) {
+  auto dtype = inputType.getOptionalDtype();
+  ArrayRef<int64_t> sizes = inputType.getSizes();
+  int64_t batch = sizes[0];
+  int64_t seqLen = sizes[1];
+  int64_t headDim = embedDim / numHeads;
+
+  SmallVector<int64_t> linearSizes = {batch, seqLen, 3 * embedDim};
+  ValueTensorType linearType = cast<ValueTensorType>(
+      inputType.getWithSizesAndDtype(linearSizes, dtype));
+  Value qkvLinear = rewriter.create<Torch::AtenLinearOp>(loc, linearType,
+                                                         input, qkvWeight,
+                                                         qkvBias);
+
+  SmallVector<int64_t> reshapeSizes = {batch, seqLen, 3, numHeads, headDim};
+  ValueTensorType reshapeType = cast<ValueTensorType>(
+      inputType.getWithSizesAndDtype(reshapeSizes, dtype));
+  SmallVector<int64_t> reshapeShape = {adaptSizeForView(batch),
+                                       adaptSizeForView(seqLen), 3, numHeads,
+                                       headDim};
+  Value reshapeList = createIntList(rewriter, loc, reshapeShape);
+  Value reshaped = rewriter.create<Torch::AtenViewOp>(loc, reshapeType,
+                                                      qkvLinear, reshapeList);
+
+  Value dimTwo = createIntConstant(rewriter, loc, 2);
+  SmallVector<int64_t> selectSizes = {batch, seqLen, numHeads, headDim};
+  ValueTensorType selectType = cast<ValueTensorType>(
+      inputType.getWithSizesAndDtype(selectSizes, dtype));
+
+  Value qIndex = createIntConstant(rewriter, loc, 0);
+  Value kIndex = createIntConstant(rewriter, loc, 1);
+  Value vIndex = createIntConstant(rewriter, loc, 2);
+  Value q = rewriter.create<Torch::AtenSelectIntOp>(loc, selectType, reshaped,
+                                                    dimTwo, qIndex);
+  Value k = rewriter.create<Torch::AtenSelectIntOp>(loc, selectType, reshaped,
+                                                    dimTwo, kIndex);
+  Value v = rewriter.create<Torch::AtenSelectIntOp>(loc, selectType, reshaped,
+                                                    dimTwo, vIndex);
+
+  SmallVector<int64_t> permSizes = {batch, numHeads, seqLen, headDim};
+  ValueTensorType permType =
+      cast<ValueTensorType>(inputType.getWithSizesAndDtype(permSizes, dtype));
+  Value permList = createIntList(rewriter, loc, {0, 2, 1, 3});
+  q = rewriter.create<Torch::AtenPermuteOp>(loc, permType, q, permList);
+  k = rewriter.create<Torch::AtenPermuteOp>(loc, permType, k, permList);
+  v = rewriter.create<Torch::AtenPermuteOp>(loc, permType, v, permList);
+
+  return QkvProjections{q, k, v};
+}
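+
+// Rewrites the fused fastpath operator
+// torch.operator "torch.aten._transformer_encoder_layer_fwd" (20 operands:
+// src, embed_dim, num_heads, packed QKV weight/bias, output projection
+// weight/bias, use_gelu/norm_first flags, eps, the two layer-norm
+// weight/bias pairs, the two feed-forward linears, and the optional
+// mask/mask_type) into primitive Torch ops.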
+class DecomposeAtenTransformerEncoderLayerFwd
+    : public OpRewritePattern<Torch::OperatorOp> {
+public:
+  using OpRewritePattern<Torch::OperatorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Torch::OperatorOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isTransformerEncoderOperator(op))
+      return failure();
+
+    SmallVector<Value> operands(op.getOperands().begin(),
+                                op.getOperands().end());
+    if (operands.size() != 20)
+      return rewriter.notifyMatchFailure(op, "expected 20 operands");
+
+    Value src = operands[0];
+    Value embedDimVal = operands[1];
+    Value numHeadsVal = operands[2];
+    Value qkvWeight = operands[3];
+    Value qkvBias = operands[4];
+    Value projWeight = operands[5];
+    Value projBias = operands[6];
+    Value useGelu = operands[7];
+    Value normFirst = operands[8];
+    Value eps = operands[9];
+    Value norm1Weight = operands[10];
+    Value norm1Bias = operands[11];
+    Value norm2Weight = operands[12];
+    Value norm2Bias = operands[13];
+    Value ffn1Weight = operands[14];
+    Value ffn1Bias = operands[15];
+    Value ffn2Weight = operands[16];
+    Value ffn2Bias = operands[17];
+    Value mask = operands[18];
+    Value maskType = operands[19];
+
+    int64_t embedDim;
+    int64_t numHeads;
+    if (!matchPattern(embedDimVal, m_TorchConstantInt(&embedDim)) ||
+        !matchPattern(numHeadsVal, m_TorchConstantInt(&numHeads))) {
+      return rewriter.notifyMatchFailure(
+          op, "embed_dim and num_heads must be constant integers");
+    }
+    if (numHeads == 0 || embedDim % numHeads != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "embedding dimension must be divisible by number of heads");
+    }
+
+    if (!isa<Torch::NoneType>(mask.getType())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "attention masks are not supported");
+    }
+    if (!isa<Torch::NoneType>(maskType.getType()) &&
+        !maskType.getDefiningOp<Torch::ConstantNoneOp>()) {
+      int64_t maskTypeValue;
+      if (!matchPattern(maskType, m_TorchConstantInt(&maskTypeValue)) ||
+          maskTypeValue != 0) {
+        return rewriter.notifyMatchFailure(
+            op, "mask_type must be None or the constant 0");
+      }
+    }
+
+    FailureOr<ValueTensorType> srcType =
+        expectRankedTensor(src, 3, "src", rewriter, op);
+    if (failed(srcType)) {
+      return failure();
+    }
+
+    bool useGeluBool;
+    if (!matchPattern(useGelu, m_TorchConstantBool(&useGeluBool))) {
+      return rewriter.notifyMatchFailure(op, "use_gelu must be constant");
+    }
+
+    bool normFirstBool;
+    if (!matchPattern(normFirst, m_TorchConstantBool(&normFirstBool))) {
+      return rewriter.notifyMatchFailure(op, "norm_first must be constant");
+    }
+
+    if (failed(checkTensorShape(qkvWeight, {3 * embedDim, embedDim},
+                                "qkv_weight", rewriter, op)) ||
+        failed(checkTensorShape(qkvBias, {3 * embedDim}, "qkv_bias", rewriter,
+                                op)) ||
+        failed(checkTensorShape(projWeight, {embedDim, embedDim},
+                                "proj_weight", rewriter, op)) ||
+        failed(checkTensorShape(projBias, {embedDim}, "proj_bias", rewriter,
+                                op)) ||
+        failed(checkTensorShape(norm1Weight, {embedDim}, "norm1_weight",
+                                rewriter, op)) ||
+        failed(checkTensorShape(norm1Bias, {embedDim}, "norm1_bias", rewriter,
+                                op)) ||
+        failed(checkTensorShape(norm2Weight, {embedDim}, "norm2_weight",
+                                rewriter, op)) ||
+        failed(checkTensorShape(norm2Bias, {embedDim}, "norm2_bias", rewriter,
+                                op))) {
+      return failure();
+    }
+
+    auto ffn1WeightType =
+        checkTensorShape(ffn1Weight, {Torch::kUnknownSize, embedDim},
+                         "ffn1_weight", rewriter, op);
+    if (failed(ffn1WeightType)) {
+      return failure();
+    }
+    auto ffn1BiasType = checkTensorShape(ffn1Bias, {Torch::kUnknownSize},
+                                         "ffn1_bias", rewriter, op);
+    if (failed(ffn1BiasType)) {
+      return failure();
+    }
+    auto ffn2WeightType =
+        checkTensorShape(ffn2Weight, {embedDim, Torch::kUnknownSize},
+                         "ffn2_weight", rewriter, op);
+    if (failed(ffn2WeightType)) {
+      return failure();
+    }
+    if (failed(checkTensorShape(ffn2Bias, {embedDim}, "ffn2_bias", rewriter,
+                                op))) {
+      return failure();
+    }
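+
+    // dim_feedforward is not an explicit operand of the fused op; recover it
+    // from the ffn1/ffn2 parameter shapes and require the sources to agree.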
+    int64_t hiddenDim = (*ffn1WeightType).getSizes()[0];
+    auto enforceHidden = [&](int64_t candidate,
+                             StringRef what) -> LogicalResult {
+      if (candidate == Torch::kUnknownSize)
+        return success();
+      if (hiddenDim == Torch::kUnknownSize) {
+        hiddenDim = candidate;
+        return success();
+      }
+      if (hiddenDim != candidate) {
+        return rewriter.notifyMatchFailure(
+            op, Twine("inconsistent hidden dimension inferred from ") + what);
+      }
+      return success();
+    };
+    if (failed(enforceHidden((*ffn1BiasType).getSizes()[0], "ffn1_bias")) ||
+        failed(
+            enforceHidden((*ffn2WeightType).getSizes()[1], "ffn2_weight"))) {
+      return failure();
+    }
+    if (hiddenDim == Torch::kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "unable to infer feed-forward hidden dimension");
+    }
+
+    Location loc = op.getLoc();
+
+    auto buildLayerNorm = [&](Value input, Value weight,
+                              Value bias) -> FailureOr<Value> {
+      auto inputTensorType = cast<ValueTensorType>(input.getType());
+      Value normalizedShape = createIntList(rewriter, loc, {embedDim});
+      Value cudnnEnable = createBoolConstant(rewriter, loc, true);
+      Value ln = rewriter.create<Torch::AtenLayerNormOp>(
+          loc, inputTensorType, input, normalizedShape, weight, bias, eps,
+          cudnnEnable);
+      return ln;
+    };
+
+    Value attentionInput = src;
+    if (normFirstBool) {
+      auto normed = buildLayerNorm(src, norm1Weight, norm1Bias);
+      if (failed(normed))
+        return failure();
+      attentionInput = *normed;
+    }
+
+    auto attentionInputType =
+        expectRankedTensor(attentionInput, 3, "attention input", rewriter, op);
+    if (failed(attentionInputType))
+      return failure();
+
+    FailureOr<QkvProjections> projections =
+        buildQkvProjections(rewriter, loc, attentionInput, qkvWeight, qkvBias,
+                            embedDim, numHeads, *attentionInputType);
+    if (failed(projections))
+      return failure();
+
+    Type elementType = srcType->getOptionalDtype();
+    if (!elementType)
+      elementType = rewriter.getF32Type();
+    int64_t headDim = embedDim / numHeads;
+
+    Value permIdx = createIntList(rewriter, loc, {0, 1, 3, 2});
+    SmallVector<int64_t> keyTShape = {(*srcType).getSizes()[0], numHeads,
+                                      headDim, (*srcType).getSizes()[1]};
+    ValueTensorType keyTType = cast<ValueTensorType>(
+        srcType->getWithSizesAndDtype(keyTShape, srcType->getOptionalDtype()));
+    Value keyT = rewriter.create<Torch::AtenPermuteOp>(
+        loc, keyTType, projections->key, permIdx);
+
+    SmallVector<int64_t> scoreShape = {(*srcType).getSizes()[0], numHeads,
+                                       (*srcType).getSizes()[1],
+                                       (*srcType).getSizes()[1]};
+    ValueTensorType scoreType = cast<ValueTensorType>(
+        srcType->getWithSizesAndDtype(scoreShape, srcType->getOptionalDtype()));
+    Value scores = rewriter.create<Torch::AtenMatmulOp>(
+        loc, scoreType, projections->query, keyT);
+
+    double scale = 1.0 / std::sqrt(static_cast<double>(headDim));
+    Value scaleConst = createFloatConstant(rewriter, loc, scale, elementType);
+    scores = rewriter.create<Torch::AtenMulScalarOp>(loc, scoreType, scores,
+                                                     scaleConst);
+
+    Value dimLast = createIntConstant(rewriter, loc, -1);
+    Value halfToFloat = createBoolConstant(rewriter, loc, false);
+    Value attnWeights = rewriter.create<Torch::Aten_SoftmaxOp>(
+        loc, scoreType, scores, dimLast, halfToFloat);
+
+    SmallVector<int64_t> contextShape = {(*srcType).getSizes()[0], numHeads,
+                                         (*srcType).getSizes()[1], headDim};
+    ValueTensorType contextType = cast<ValueTensorType>(
+        srcType->getWithSizesAndDtype(contextShape,
+                                      srcType->getOptionalDtype()));
+    Value context = rewriter.create<Torch::AtenMatmulOp>(
+        loc, contextType, attnWeights, projections->value);
+
+    Value mergeIdx = createIntList(rewriter, loc, {0, 2, 1, 3});
+    SmallVector<int64_t> mergedPermShape = {(*srcType).getSizes()[0],
+                                            (*srcType).getSizes()[1], numHeads,
+                                            headDim};
+    ValueTensorType mergedPermType = cast<ValueTensorType>(
+        srcType->getWithSizesAndDtype(mergedPermShape,
+                                      srcType->getOptionalDtype()));
+    Value merged = rewriter.create<Torch::AtenPermuteOp>(loc, mergedPermType,
+                                                         context, mergeIdx);
+
+    Value mergedView = rewriter.create<Torch::AtenViewOp>(
+        loc, src.getType(), merged,
+        createIntList(rewriter, loc,
+                      {adaptSizeForView((*srcType).getSizes()[0]),
+                       adaptSizeForView((*srcType).getSizes()[1]), embedDim}));
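+
+    // Project the re-merged heads back to embed_dim and add the first
+    // residual connection (aten.add.Tensor with alpha = 1).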
+    Value attnOutput = rewriter.create<Torch::AtenLinearOp>(
+        loc, *srcType, mergedView, projWeight, projBias);
+
+    Value oneScalar = createIntConstant(rewriter, loc, 1);
+    Value attnResidual = rewriter.create<Torch::AtenAddTensorOp>(
+        loc, *srcType, src, attnOutput, oneScalar);
+
+    Value postAttn;
+    if (normFirstBool) {
+      postAttn = attnResidual;
+    } else {
+      auto normed = buildLayerNorm(attnResidual, norm1Weight, norm1Bias);
+      if (failed(normed))
+        return failure();
+      postAttn = *normed;
+    }
+
+    Value feedForwardInput = postAttn;
+    if (normFirstBool) {
+      auto normed = buildLayerNorm(attnResidual, norm2Weight, norm2Bias);
+      if (failed(normed))
+        return failure();
+      feedForwardInput = *normed;
+    }
+
+    auto buildFeedForward = [&](Value input) -> FailureOr<Value> {
+      SmallVector<int64_t> hiddenShape = {(*srcType).getSizes()[0],
+                                          (*srcType).getSizes()[1], hiddenDim};
+      ValueTensorType hiddenType = cast<ValueTensorType>(
+          srcType->getWithSizesAndDtype(hiddenShape,
+                                        srcType->getOptionalDtype()));
+      Value ff1 = rewriter.create<Torch::AtenLinearOp>(loc, hiddenType, input,
+                                                       ffn1Weight, ffn1Bias);
+      Value activated;
+      if (useGeluBool) {
+        Value approx = rewriter.create<Torch::ConstantStrOp>(
+            loc, Torch::StringType::get(rewriter.getContext()),
+            rewriter.getStringAttr("none"));
+        activated =
+            rewriter.create<Torch::AtenGeluOp>(loc, hiddenType, ff1, approx);
+      } else {
+        activated = rewriter.create<Torch::AtenReluOp>(loc, hiddenType, ff1);
+      }
+      Value ff2 = rewriter.create<Torch::AtenLinearOp>(
+          loc, *srcType, activated, ffn2Weight, ffn2Bias);
+      return ff2;
+    };
+
+    FailureOr<Value> feedForwardOut = buildFeedForward(feedForwardInput);
+    if (failed(feedForwardOut))
+      return failure();
+
+    Value result;
+    if (normFirstBool) {
+      result = rewriter.create<Torch::AtenAddTensorOp>(
+          loc, *srcType, attnResidual, *feedForwardOut, oneScalar);
+    } else {
+      Value secondResidual = rewriter.create<Torch::AtenAddTensorOp>(
+          loc, *srcType, postAttn, *feedForwardOut, oneScalar);
+      auto normed = buildLayerNorm(secondResidual, norm2Weight, norm2Bias);
+      if (failed(normed))
+        return failure();
+      result = *normed;
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::torch::Torch::populateTransformerEncoderPatterns(
+    RewritePatternSet &patterns, const llvm::StringSet<> &legalOpsSet) {
+  MLIRContext *context = patterns.getContext();
+  DecomposeAtenTransformerEncoderLayerFwd pattern(context);
+  auto opName = pattern.getRootKind();
+  if (!opName)
+    return;
+  StringRef trimmed =
+      opName->getStringRef().ltrim(mlir::torch::Torch::kTorchOpPrefix);
+  if (legalOpsSet.contains(trimmed))
+    return;
+  patterns.add<DecomposeAtenTransformerEncoderLayerFwd>(context);
+}
diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
index 187d234183a3..dfe9c19a55c2 100644
--- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
+++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "Utils.h"
 #include "ReifyAbstractInterpCalculationsUtils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
@@ -265,7 +266,17 @@ void TorchMatchSpecializedBackendOp::populateSpecializedConversions(
       });
 }
 
-bool isSpecializedOperation(Torch::OperatorOp op) { return true; }
+bool isSpecializedOperation(Torch::OperatorOp op) {
+  auto nameAttr = op.getNameAttr();
+  if (!nameAttr)
+    return false;
+
+  StringRef opName = nameAttr.getValue();
+  if (isTransformerEncoderOperatorName(opName))
+    return false;
+
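+  // Only the flash-attention op still routes through the specialized backend
+  // conversion; the fused encoder op is expanded by the new decomposition.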
+  return opName == "torch.aten._scaled_dot_product_flash_attention_for_cpu";
+}
 } // namespace
 
 // Reduce Ops without value semantics but the corresponding without trailing
diff --git a/lib/Dialect/Torch/Transforms/Utils.h b/lib/Dialect/Torch/Transforms/Utils.h
new file mode 100644
index 000000000000..557686fc4bc1
--- /dev/null
+++ b/lib/Dialect/Torch/Transforms/Utils.h
@@ -0,0 +1,34 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+namespace torch {
+namespace Torch {
+
+inline bool isTransformerEncoderOperatorName(llvm::StringRef name) {
+  if (!name.consume_front("torch."))
+    return false;
+  if (!name.consume_front("aten._transformer_encoder_layer_fwd"))
+    return false;
+  return name.empty() || name == ".default";
+}
+
+inline bool isTransformerEncoderOperator(Torch::OperatorOp op) {
+  auto nameAttr = op.getNameAttr();
+  if (!nameAttr)
+    return false;
+  return isTransformerEncoderOperatorName(nameAttr.getValue());
+}
+
+} // namespace Torch
+} // namespace torch
+} // namespace mlir
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 4258547b94ee..cb44832d2235 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1788,6 +1788,7 @@
     "TrilIndicesModule_basic",
     "TrilIndicesOfssetGreaterThanRowModule_basic",
     "TriuIndicesNegativeOffsetModule_basic",
+    "TransformerEncoderModule_basic",
     "BmmFloat16Module_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
     "LinspaceDtypeModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index bd82cd1c11b0..b0406a564e8a 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -63,3 +63,4 @@ def register_all_tests():
     from . import meshgrid
     from . import timeout
     from . import kl_div_loss
+    from . import transformer
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/transformer.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/transformer.py
new file mode 100755
index 000000000000..7b436de698da
--- /dev/null
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/transformer.py
@@ -0,0 +1,42 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
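+
+# The layer is constructed in eval mode with batch_first=True and no masks so
+# that PyTorch's inference fastpath fuses it into a single
+# aten._transformer_encoder_layer_fwd op -- the operator this patch decomposes.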
+
+import torch
+
+from torch_mlir_e2e_test.annotations import annotate_args, export
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+
+
+class TransformerEncoderModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.manual_seed(0)
+        self.layer = torch.nn.TransformerEncoderLayer(
+            d_model=8,
+            nhead=2,
+            dim_feedforward=16,
+            dropout=0.0,
+            activation="gelu",
+            batch_first=True,
+            norm_first=False,
+        )
+        self.train(False)
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 4, 8], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.layer(x)
+
+
+@register_test_case(module_factory=lambda: TransformerEncoderModule())
+def TransformerEncoderModule_basic(module, tu: TestUtils):
+    x = tu.rand(1, 4, 8)
+    module.forward(x)
diff --git a/projects/pt1/test/python/transformer_encoder_lowering.py b/projects/pt1/test/python/transformer_encoder_lowering.py
new file mode 100644
index 000000000000..01ee49970d4f
--- /dev/null
+++ b/projects/pt1/test/python/transformer_encoder_lowering.py
@@ -0,0 +1,74 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+from torch.nn import TransformerEncoderLayer
+
+from torch_mlir import ir
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.dialects import torch as torch_d
+from torch_mlir.extras.fx_decomp_util import get_decomposition_table
+from torch_mlir.extras.fx_importer import FxImporter
+
+
+def lower_transformer(norm_first: bool, activation: str) -> None:
+    layer = TransformerEncoderLayer(
+        d_model=32,
+        nhead=4,
+        dim_feedforward=64,
+        dropout=0.0,
+        activation=activation,
+        layer_norm_eps=1e-5,
+        batch_first=True,
+        norm_first=norm_first,
+        bias=True,
+    ).eval()
+
+    example_input = torch.randn(2, 5, 32)
+    exported = torch.export.export(layer, (example_input,))
+
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        exported = exported.run_decompositions(decomposition_table)
+
+    context = ir.Context()
+    torch_d.register_dialect(context)
+    importer = FxImporter(context=context)
+    importer.import_frozen_program(exported)
+    module = importer.module
+
+    pipeline = """
+        builtin.module(
+            func.func(torch-match-quantized-custom-ops),
+            torchdynamo-export-to-torch-backend-pipeline{extra-library= backend-legal-ops=aten.as_strided},
+            torch-adjust-calling-conventions
+        )
+    """
+    run_pipeline_with_repro_report(
+        module,
+        pipeline,
+        "Lowering TorchFX IR -> Torch Backend IR",
+        enable_ir_printing=False,
+    )
+
+    if "torch.operator" in str(module.operation):
+        raise RuntimeError("Unexpected torch.operator after lowering")
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    if hasattr(torch, "set_deterministic_debug_mode"):
+        torch.set_deterministic_debug_mode("error")
+
+    for activation in ("gelu", "relu"):
+        for norm_first in (True, False):
+            lower_transformer(norm_first, activation)
+            print(
+                f"CHECK: lowered norm_first={norm_first} activation={activation}"
+            )
+
+    print("SUCCESS")
+# CHECK: CHECK: lowered norm_first=True activation=gelu
+# CHECK: CHECK: lowered norm_first=False activation=gelu
+# CHECK: CHECK: lowered norm_first=True activation=relu
+# CHECK: CHECK: lowered norm_first=False activation=relu
+# CHECK: SUCCESS
diff --git a/transformer_encoder.mlir b/transformer_encoder.mlir
new file mode 100644
index 000000000000..f5db55741c14
--- /dev/null
+++ b/transformer_encoder.mlir
@@ -0,0 +1,54 @@
+// RUN: torch-mlir-opt %s -torch-decompose-complex-ops -convert-torch-to-tosa -split-input-file | FileCheck %s
+
+// Verify that lowering a single TransformerEncoderLayer produces the expected
+// TOSA building blocks (QKV projection, attention, feed-forward, and layer norm).
+// The operands are intentionally small so we can keep static shapes throughout.
+module {
+  // CHECK-LABEL: func.func @transformer(
+  // CHECK: %[[QKV:.*]] = tosa.matmul %{{.*}} : (tensor<1x4x8xf32>, tensor<1x8x24xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x24xf32>
+  // CHECK: %[[QKVRESHAPE:.*]] = tosa.reshape %{{.*}} : (tensor<1x4x24xf32>, !tosa.shape<5>) -> tensor<1x4x3x2x4xf32>
+  // CHECK: %[[SCORES:.*]] = tosa.matmul %{{.*}} : (tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x4xf32>
+  // CHECK: %[[SCALE:.*]] = tosa.mul %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1x2x4x4xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x2x4x4xf32>
+  // CHECK: tosa.sub %{{.*}} : (tensor<1x4x16xf32>, tensor<1x1x1xf32>) -> tensor<1x4x16xf32>
+  // CHECK: tosa.mul %{{.*}} : (tensor<1x4x16xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<1x4x16xf32>
+  // CHECK: tosa.erf %{{.*}} : (tensor<1x4x16xf32>) -> tensor<1x4x16xf32>
+  // CHECK: %[[FFN:.*]] = tosa.matmul %{{.*}} : (tensor<1x4x16xf32>, tensor<1x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8xf32>
+  // CHECK: %[[NORMINV:.*]] = tosa.rsqrt %{{.*}} : (tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
+  // CHECK: return %{{.*}} : !torch.vtensor<[1,4,8],f32>
+  func.func @transformer(
+      %arg0: !torch.vtensor<[1,4,8],f32>,
+      %qkv_weight: !torch.vtensor<[24,8],f32>,
+      %qkv_bias: !torch.vtensor<[24],f32>,
+      %proj_weight: !torch.vtensor<[8,8],f32>,
+      %proj_bias: !torch.vtensor<[8],f32>,
+      %norm1_weight: !torch.vtensor<[8],f32>,
+      %norm1_bias: !torch.vtensor<[8],f32>,
+      %norm2_weight: !torch.vtensor<[8],f32>,
+      %norm2_bias: !torch.vtensor<[8],f32>,
+      %ffn1_weight: !torch.vtensor<[16,8],f32>,
+      %ffn1_bias: !torch.vtensor<[16],f32>,
+      %ffn2_weight: !torch.vtensor<[8,16],f32>,
+      %ffn2_bias: !torch.vtensor<[8],f32>) -> !torch.vtensor<[1,4,8],f32> {
+    %embed_dim = torch.constant.int 8
+    %num_heads = torch.constant.int 2
+    %use_gelu = torch.constant.bool true
+    %norm_first = torch.constant.bool false
+    %eps = torch.constant.float 1.000000e-05
+    %none = torch.constant.none
+    %result = torch.operator "torch.aten._transformer_encoder_layer_fwd.default"(
+        %arg0, %embed_dim, %num_heads, %qkv_weight, %qkv_bias, %proj_weight,
+        %proj_bias, %use_gelu, %norm_first, %eps, %norm1_weight, %norm1_bias,
+        %norm2_weight, %norm2_bias, %ffn1_weight, %ffn1_bias, %ffn2_weight,
+        %ffn2_bias, %none, %none
+    ) : (!torch.vtensor<[1,4,8],f32>, !torch.int, !torch.int,
+         !torch.vtensor<[24,8],f32>, !torch.vtensor<[24],f32>,
+         !torch.vtensor<[8,8],f32>, !torch.vtensor<[8],f32>, !torch.bool,
+         !torch.bool, !torch.float, !torch.vtensor<[8],f32>,
+         !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32>,
+         !torch.vtensor<[8],f32>, !torch.vtensor<[16,8],f32>,
+         !torch.vtensor<[16],f32>, !torch.vtensor<[8,16],f32>,
+         !torch.vtensor<[8],f32>, !torch.none, !torch.none)
+        -> !torch.vtensor<[1,4,8],f32>
+    return %result : !torch.vtensor<[1,4,8],f32>
+  }
+}
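+
+// Note: the CHECK lines above assume the TOSA op forms that take explicit
+// zero-point / shift operands (e.g. four-operand tosa.matmul); older TOSA
+// versions of torch-mlir print two-operand tosa.matmul instead, so these
+// expectations may need adjusting against other builds.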