@@ -2250,66 +2250,91 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         Value zeropoint = operands[2];
 
         auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
-
-        auto operandETy = operandTy.getDtype();
         auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
         if (!scaleTy || !scaleTy.hasSizes())
           return rewriter.notifyMatchFailure(binder.op, "requires known rank");
         if (!resultType.hasDtype())
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");
 
-        bool rank0 = scaleTy.getSizes().size() == 0;
-        bool length1 =
-            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
-
-        if (!rank0 && !length1)
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "unimplemented: non-scalar scale");
+        int64_t scaleRank = scaleTy.getSizes().size();
+        if (scaleRank > 1)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: only per-tensor or per-axis "
+                         "quantization supported");
         auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
         if (!qTensorTy) {
           return rewriter.notifyMatchFailure(binder.op,
                                              "unsupported result dtype");
         }
 
-        scale = rewriter.create<Torch::AtenItemOp>(
-            loc, rewriter.getType<Torch::FloatType>(), scale);
-
+        auto operandETy = operandTy.getDtype();
         bool fpOperand = isa<mlir::FloatType>(operandETy);
-        Type zeropointTy = rewriter.getType<Torch::IntType>();
-        if (fpOperand)
-          zeropointTy = rewriter.getType<Torch::FloatType>();
-
-        zeropoint =
-            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
-
-        if (fpOperand) {
-          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
-          auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
-          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      static_cast<int64_t>(tyVal)));
-          Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
-              loc, resultType, operand, tyConst,
-              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
-              /*memory_format=*/none);
-
-          Value one = rewriter.create<Torch::ConstantFloatOp>(
-              loc, rewriter.getF64FloatAttr(1.0));
-          Value sub = rewriter.create<Torch::AtenSubScalarOp>(
-              loc, resultType, toDtype, zeropoint, one);
-          rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
-              binder.op, resultType, sub, scale);
+        bool isPerTensorQuantization = false;
+        if (scaleRank == 0 ||
+            llvm::all_of(scaleTy.getSizes(), [](int64_t s) { return s == 1; }))
+          isPerTensorQuantization = true;
+
+        // (TODO) Case: Per-Channel Quantization for floating point input.
+        if (scaleRank == 1 && fpOperand)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: support for per-Channel Quantization "
+                         "for floating point input not present");
+
+        if (isPerTensorQuantization) {
+          scale = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::FloatType>(), scale);
+
+          Type zeropointTy = rewriter.getType<Torch::IntType>();
+          if (fpOperand)
+            zeropointTy = rewriter.getType<Torch::FloatType>();
+          zeropoint =
+              rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
+        }
+
+        if (!fpOperand) {
+          Value quantize;
+          // Case 1: Per-Tensor Quantization for non-floating point input.
+          if (isPerTensorQuantization) {
+            quantize =
+                rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+                    loc, qTensorTy, operand, scale, zeropoint);
+          } else {
+            // Case 2: Per-Channel Quantization for non-floating point input.
+            int64_t axis;
+            if (binder.s64IntegerAttr(axis, "axis", 1))
+              return failure();
+
+            Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(axis));
+            quantize =
+                rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+                    loc, qTensorTy, operand, scale, zeropoint, cstAxis);
+          }
+          rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
+              binder.op, resultType, quantize);
           return success();
         }
 
-        auto quantize =
-            rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-                loc, qTensorTy, operand, scale, zeropoint);
-        rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
-            binder.op, resultType, quantize);
+        // Case 3: Per-Tensor Quantization for floating point input.
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
+        Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    static_cast<int64_t>(tyVal)));
+        Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
+            loc, resultType, operand, tyConst,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/none);
+
+        Value one = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1.0));
+        Value sub = rewriter.create<Torch::AtenSubScalarOp>(
+            loc, resultType, toDtype, zeropoint, one);
+        rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
+            binder.op, resultType, sub, scale);
         return success();
       });
   patterns.onOp("Div", 7,