@@ -464,13 +464,19 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
464464 if (binder.tensorOperandsList (operands) ||
465465 binder.tensorResultType (resultType))
466466 return failure ();
467+
468+ if (operands.size () != 8 )
469+ return rewriter.notifyMatchFailure (
470+ binder.op , " Unimplemented: expected 8 input operands" );
471+
467472 Value a = operands[0 ];
468473 Value aScale = operands[1 ];
469474 Value aZp = operands[2 ];
470475 Value b = operands[3 ];
471476 Value bScale = operands[4 ];
472477 Value bZp = operands[5 ];
473478 Value cScale = operands[6 ];
479+ Value cZp = operands[7 ];
474480
475481 auto check = [](Value v) {
476482 auto vTy = cast<Torch::ValueTensorType>(v.getType ());
@@ -480,7 +486,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
480486 return true ;
481487 };
482488 if (!check (aScale) || !check (aZp) || !check (bScale) || !check (bZp) ||
483- !check (cScale))
489+ !check (cScale) || ! check (cZp) )
484490 return rewriter.notifyMatchFailure (
485491 binder.op , " Unsupported per-tensor quantization" );
486492
@@ -508,19 +514,7 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
508514
509515 aZp = extract (aZp);
510516 bZp = extract (bZp);
511-
512- Value cZp;
513- if (operands.size () == 8 ) {
514- cZp = operands[7 ];
515- if (!check (cZp))
516- return rewriter.notifyMatchFailure (
517- binder.op ,
518- " Unsupported c_zero_point for per-tensor quantization" );
519- cZp = extract (cZp);
520- } else {
521- cZp = rewriter.create <Torch::ConstantIntOp>(
522- loc, rewriter.getI64IntegerAttr (0 ));
523- }
517+ cZp = extract (cZp);
524518
525519 aScale = extract (aScale);
526520 bScale = extract (bScale);
@@ -590,6 +584,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
590584 binder.f32FloatAttr (alpha, " alpha" ))
591585 return failure ();
592586
587+ if (operands.size () != 5 )
588+ return rewriter.notifyMatchFailure (
589+ binder.op , " Unimplemented: expected 5 input operands" );
590+
593591 Value x = operands[0 ];
594592 Value xScale = operands[1 ];
595593 Value xZp = operands[2 ];
@@ -760,6 +758,12 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
760758 binder.s64IntegerAttr (channelsLast, " channels_last" ))
761759 return failure ();
762760
761+ // TODO: Add support for channels_last attribute.
762+ if (channelsLast)
763+ return rewriter.notifyMatchFailure (
764+ binder.op ,
765+ " Unimplemented: support not present for channels_last attribute" );
766+
763767 Value x = operands[0 ];
764768 Value xScale, xZp, yScale, yZp;
765769
@@ -880,6 +884,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
880884 binder.tensorResultType (resultType))
881885 return failure ();
882886
887+ if (operands.size () != 5 )
888+ return rewriter.notifyMatchFailure (
889+ binder.op , " Unimplemented: expected 5 input operands" );
890+
883891 Value x = operands[0 ];
884892 Value xScale, xZp, yScale, yZp;
885893
@@ -946,6 +954,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
946954 binder.op ,
947955 " Unimplemented: support not present for channels_last attribute" );
948956
957+ if (operands.size () != 5 )
958+ return rewriter.notifyMatchFailure (
959+ binder.op , " Unimplemented: expected 5 input operands" );
960+
949961 Value x = operands[0 ];
950962 Value xScale, xZp, yScale, yZp;
951963
0 commit comments