Skip to content

Commit ada7c2f

Browse files
committed
[TORCH] Transformer encoder decomposition
- Add a dedicated DecomposeTransformerEncoder pass to expand encoder ops into primitive Torch patterns. - Extend shared lowering helpers (ReduceOpVariants.cpp, Utils.h) so the new pass can reuse reduction utilities during decomposition. - Register the pass in the Torch Transform pipeline so it runs as part of the decomposition flow. - Expand e2e coverage with new transformer encoder tests to validate the lowering path. Signed-off-by: Cathal Corbett <[email protected]> Change-Id: I6bcda53569cf7b06df4cb97c624bbf512d8fecb7
1 parent b834f94 commit ada7c2f

File tree

11 files changed

+691
-1
lines changed

11 files changed

+691
-1
lines changed

include/torch-mlir/Dialect/Torch/Transforms/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Pass/Pass.h"
15+
#include "llvm/ADT/StringSet.h"
1516

1617
#include <memory>
1718

@@ -157,6 +158,10 @@ static const char kTorchOpPrefix[] = R"(torch.)";
157158
void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
158159
MLIRContext *context);
159160

161+
void populateTransformerEncoderPatterns(
162+
RewritePatternSet &patterns, const llvm::StringSet<> &legalOpsSet);
163+
164+
160165
std::unique_ptr<OperationPass<func::FuncOp>>
161166
createRestructureNonConstantAxesPass();
162167

lib/Dialect/Torch/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_library(TorchMLIRTorchPasses
22
AdjustCallingConventions.cpp
33
DecomposeComplexOps.cpp
4+
DecomposeTransformerEncoder.cpp
45
DropAbstractInterpCalculations.cpp
56
EraseModuleInitializer.cpp
67
FuseQuantizedOps.cpp

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13084,6 +13084,8 @@ class DecomposeComplexOpsPass
1308413084
legalOpsSet.clear();
1308513085
legalOpsSet.insert(legalOps.begin(), legalOps.end());
1308613086

13087+
populateTransformerEncoderPatterns(patterns, legalOpsSet);
13088+
1308713089
addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
1308813090
patterns);
1308913091
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);

0 commit comments

Comments
 (0)