@@ -516,18 +516,17 @@ struct LstmLayerOutput {
 //
 // @return A struct containing the hidden state history, final hidden state,
 // and final cell state.
-LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
-                           Value initial_c, LstmWeights weights,
-                           LstmActivations activations) {
+LstmLayerOutput lstm_layer(ConversionPatternRewriter &rewriter, Location &loc,
+                           Value X, Value initial_h, Value initial_c,
+                           LstmWeights weights, LstmActivations activations) {
 
-  Location loc = b.getLoc();
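+  // Wrap the pattern rewriter in an ImplicitLocOpBuilder so the remaining
+  // b.create<...>() calls below stay as terse as before.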
+  mlir::ImplicitLocOpBuilder b(loc, rewriter);
 
-  auto xTy = cast<ValueTensorType>(X.getType());
   auto hTy = cast<ValueTensorType>(initial_h.getType());
   // these names are snake_case for consistency with onnx.LSTM documentation
-  int64_t seq_len = xTy.getSizes()[0];
-  int64_t batch_size = xTy.getSizes()[1];
-  int64_t input_size = xTy.getSizes()[2];
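+  // These dims of X may be dynamic, so fetch them as runtime Values instead
+  // of static int64_t sizes.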
+  Value seq_len = getTensorDimSize(rewriter, X, 0);
+  Value batch_size = getTensorDimSize(rewriter, X, 1);
+  Value input_size = getTensorDimSize(rewriter, X, 2);
   int64_t hidden_size = hTy.getSizes()[1];
 
   auto cTy = hTy;
@@ -537,19 +536,14 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   Value cstNone = b.create<ConstantNoneOp>();
   Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
   Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
-  Value cstSeqLen =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(seq_len));
-  Value cstBatchSize =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(batch_size));
   Value cstHiddenSize =
       b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(hidden_size));
 
-  auto yTy = b.getType<ValueTensorType>(
-      SmallVector<int64_t>{seq_len, batch_size, hidden_size}, hTy.getDtype());
-
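+  // Derive the result type from the shape Values; getTensorTypeFromShapeValues
+  // presumably keeps constant dims static and marks the rest as dynamic.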
+  auto yTy = getTensorTypeFromShapeValues({seq_len, batch_size, cstHiddenSize},
+                                          hTy.getDtype());
   auto YShapeList = b.create<PrimListConstructOp>(
       b.getType<ListType>(intType),
-      ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize}));
+      ValueRange({seq_len, batch_size, cstHiddenSize}));
 
   int64_t hDtypeInt =
       static_cast<int64_t>(getScalarTypeForType(hTy.getDtype()));
@@ -560,8 +554,7 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
                                           cstNone, cstNone, cstNone);
 
   // Create a for-like PrimLoopOp.
-  Value maxTripCount =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(seq_len));
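+  // seq_len is already an !torch.int Value, so it can serve as the loop trip
+  // count directly.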
+  Value maxTripCount = seq_len;
   Value loopConditionTrue = b.create<ConstantBoolOp>(true);
 
   Type loopIndexType = intType;
@@ -587,16 +580,16 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
     Value C_prev = loopBody->getArgument(3);
 
     auto xTy = cast<ValueTensorType>(X.getType());
-    auto XtType = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{batch_size, input_size}, xTy.getDtype());
+    auto XtType =
+        getTensorTypeFromShapeValues({batch_size, input_size}, xTy.getDtype());
 
     Value Xt = b.create<AtenSelectIntOp>(XtType, X, cstZero, loopIndex);
 
     auto [H_new, C_new] =
         lstm_cell(b, Xt, H_prev, C_prev, weights, activations);
 
-    Type hTyUnsqueezed = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, hTy.getDtype());
+    auto hTyUnsqueezed = getTensorTypeFromShapeValues(
+        {cstOne, batch_size, cstHiddenSize}, hTy.getDtype());
     Value H_new_unsqueezed =
         b.create<AtenUnsqueezeOp>(hTyUnsqueezed, H_new, cstZero);
 
@@ -773,17 +766,12 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
         binder.op, "invalid value of layout attribute, expecting 0 / 1 got " +
                        std::to_string(layout));
 
-  auto XShape = xTy.getSizes();
-  int64_t seq_len, batch_size;
-  if (layout == 0) {
-    seq_len = XShape[0];
-    batch_size = XShape[1];
-  } else {
-    seq_len = XShape[1];
-    batch_size = XShape[0];
-  }
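+  // Per the ONNX LSTM spec, layout == 0 means X is [seq_len, batch, input];
+  // layout == 1 swaps the first two dims.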
+  Value seqLen = getTensorDimSize(rewriter, X, layout == 0 ? 0 : 1);
+  Value batchSize = getTensorDimSize(rewriter, X, layout == 0 ? 1 : 0);
 
-  int64_t input_size = XShape[2];
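+  // Prefer W's input_size dim; X's third dim may be dynamic (kUnknownSize).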
+  int64_t x_input_size = xTy.getSizes()[2];
+  int64_t w_input_size = wTy.getSizes()[2];
+  int64_t input_size = w_input_size;
   if (num_directions != wTy.getSizes()[0])
     return rewriter.notifyMatchFailure(
         binder.op, "num_directions (" + std::to_string(num_directions) +
@@ -795,11 +783,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
         binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
                        ") does not match the second dimension of wTy (" +
                        std::to_string(wTy.getSizes()[1]) + ")");
-  if (wTy.getSizes()[2] != input_size)
-    return rewriter.notifyMatchFailure(
-        binder.op,
-        "The third dimension of wTy (" + std::to_string(wTy.getSizes()[2]) +
-            ") does not match input_size (" + std::to_string(input_size) + ")");
+  if (x_input_size != Torch::kUnknownSize) {
+    if (w_input_size != x_input_size)
+      return rewriter.notifyMatchFailure(
+          binder.op, "The input_size of wTy (" + std::to_string(w_input_size) +
+                         ") does not match input_size of xTy (" +
+                         std::to_string(x_input_size) + ")");
+
+  } else {
+    Value x_input_size = Torch::getTensorDimSize(rewriter, X, 2);
+    Value w_input_size =
+        b.create<ConstantIntOp>(loc, b.getI64IntegerAttr(wTy.getSizes()[2]));
+
+    auto eq = b.create<AtenEqIntOp>(loc, x_input_size, w_input_size);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eq, rewriter.getStringAttr("The input_size of W must equal X."));
+  }
 
   Value W_forward = getDirection(b, 0, W);
   Value R_forward = getDirection(b, 0, R);
@@ -812,25 +811,21 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
     B_reverse = getDirection(b, 1, B);
   }
 
-  auto hTy = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
-      xTy.getDtype());
-
   auto intType = b.getType<IntType>();
 
   Value cstNumDirections =
       b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions));
-  Value cstBatchSize =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(batch_size));
   Value cstHiddenSize =
       b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(hidden_size));
   Value cstNone = b.create<ConstantNoneOp>();
   Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
   Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
 
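+  // hTy moves below the constants so its shape can reference the runtime
+  // batchSize Value.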
+  auto hTy = getTensorTypeFromShapeValues(
+      {cstNumDirections, batchSize, cstHiddenSize}, xTy.getDtype());
   Value hShape = b.create<PrimListConstructOp>(
       b.getType<ListType>(intType),
-      ValueRange({cstNumDirections, cstBatchSize, cstHiddenSize}));
+      ValueRange({cstNumDirections, batchSize, cstHiddenSize}));
 
   Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
 
@@ -986,26 +981,26 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
   std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) =
       sliceIOFC(sliceGateWeightsHH, R_reverse);
 
-  LstmLayerOutput lstmLayerOutput = lstm_layer(
-      b, X, initial_h_forward, initial_c_forward, weights, activations);
+  LstmLayerOutput lstmLayerOutput =
+      lstm_layer(rewriter, loc, X, initial_h_forward, initial_c_forward,
+                 weights, activations);
 
   Value Y_h_result, Y_c_result, Y_result;
 
   // if forward (unidirectional) unsqueeze and output
   auto YallDtype =
       cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype();
-  auto Y_h_Y_c_uni_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, YallDtype);
-  auto Y_uni_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
-      YallDtype);
-  auto Y_h_Y_c_res_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
-      YallDtype);
-  auto Y_res_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
-                                 hidden_size},
-      YallDtype);
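+  // Rebuild all four output types from shape Values so dynamic seq_len and
+  // batch dims survive into the result types.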
+  auto Y_h_Y_c_uni_type = getTensorTypeFromShapeValues(
+      {cstOne, batchSize, cstHiddenSize}, YallDtype);
+
+  auto Y_uni_type = getTensorTypeFromShapeValues(
+      {seqLen, cstOne, batchSize, cstHiddenSize}, YallDtype);
+
+  auto Y_h_Y_c_res_type = getTensorTypeFromShapeValues(
+      {cstNumDirections, batchSize, cstHiddenSize}, YallDtype);
+
+  auto Y_res_type = getTensorTypeFromShapeValues(
+      {seqLen, cstNumDirections, batchSize, cstHiddenSize}, YallDtype);
 
   Value Y_h_forward =
       b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero);
@@ -1034,8 +1029,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
                                                SmallVector<Value>{cstZero});
     X_reverse = b.create<AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
     revLstmLayerOutput =
-        lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse,
-                   weightsRev, activationsRev);
+        lstm_layer(rewriter, loc, X_reverse, initial_h_reverse,
+                   initial_c_reverse, weightsRev, activationsRev);
 
     // unsqueeze Y_rev, Y_h_rev, Y_c_rev
     Y_h_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
@@ -1081,7 +1076,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
                            outputs;
   ValueTensorType resTy;
   for (int i = 0; i < binder.getNumResults(); ++i) {
-    if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) {
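+    // Treat a failed type binding as an absent optional output and emit None.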
+    if (failed(binder.tensorResultTypeAtIndex(resTy, i))) {
       outputs.push_back(cstNone);
     } else {
       outputs.push_back(actualOutputs[i]);