File tree Expand file tree Collapse file tree 3 files changed +50
-2
lines changed
lib/Conversion/TorchToStablehlo
python/torch_mlir_e2e_test/test_suite Expand file tree Collapse file tree 3 files changed +50
-2
lines changed Original file line number Diff line number Diff line change @@ -876,14 +876,17 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
876876 // Tensors with integer types need to be converted to signless integer
877877 // element type. All tensors with element types other than integer can reuse
878878 // existing elements attribute.
879- // TODO: what about unsigned integer?
880879 if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr ())) {
881880 Type builtinTensorElemTy = resultType.getElementType ();
882881 unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth ();
882+ bool isUnsigned =
883+ cast<IntegerType>(builtinTensorElemTy).isUnsignedInteger ();
883884
884885 DenseElementsAttr valueAttr =
885886 elements.mapValues (builtinTensorElemTy, [&](const APInt &v) {
886- return APInt (bitWidth, v.getSExtValue ());
887+ APInt intValue =
888+ isUnsigned ? v.zextOrTrunc (bitWidth) : v.sextOrTrunc (bitWidth);
889+ return intValue;
887890 });
888891 rewriter.replaceOpWithNewOp <stablehlo::ConstantOp>(op, resultType,
889892 valueAttr);
Original file line number Diff line number Diff line change 17711771 # raise TimeoutError(self.error_message)
17721772 # TimeoutError: Timeout
17731773 "BertModule_basic" ,
1774+ "UInt8Tensor_basic" ,
1775+ "BoolTensor_basic" ,
17741776}
17751777
17761778# Write the TOSA set as a "passing" set as it is very early in development
Original file line number Diff line number Diff line change @@ -2102,3 +2102,46 @@ def forward(self, a):
21022102 @register_test_case (module_factory = lambda : AtenDiagEmbedNonDefault4DDiag ())
21032103 def AtenDiagEmbedNonDefault4DDiag_basic (module , tu : TestUtils ):
21042104 module .forward (tu .rand (2 , 3 , 4 , 5 ))
2105+
2106+
2107+ # ==============================================================================
2108+
2109+
2110+ class UInt8Tensor (torch .nn .Module ):
2111+ def __init__ (self ):
2112+ super ().__init__ ()
2113+
2114+ @export
2115+ @annotate_args (
2116+ [
2117+ None ,
2118+ ]
2119+ )
2120+ def forward (self ):
2121+ x = torch .tensor ([128 ], dtype = torch .uint8 )
2122+ return torch .ops .aten .to (x , dtype = torch .float32 )
2123+
2124+
2125+ @register_test_case (module_factory = lambda : UInt8Tensor ())
2126+ def UInt8Tensor_basic (module , tu : TestUtils ):
2127+ module .forward ()
2128+
2129+
2130+ class BoolTensor (torch .nn .Module ):
2131+ def __init__ (self ):
2132+ super ().__init__ ()
2133+
2134+ @export
2135+ @annotate_args (
2136+ [
2137+ None ,
2138+ ]
2139+ )
2140+ def forward (self ):
2141+ x = torch .tensor ([True ], dtype = torch .bool )
2142+ return torch .ops .aten .to (x , dtype = torch .float32 )
2143+
2144+
2145+ @register_test_case (module_factory = lambda : BoolTensor ())
2146+ def BoolTensor_basic (module , tu : TestUtils ):
2147+ module .forward ()
You can’t perform that action at this time.
0 commit comments