Skip to content

Commit dee5158

Browse files
authored
[TorchToStableHLO] Fix unsigned integer conversion in ValueTensorLiteralOp. (#4313)
Previously, all integer values were sign-extended when converting tensors, which corrupted unsigned integer values. This commit fixes the conversion to zero-extend unsigned integers and sign-extend signed integers.
1 parent 72f6da0 commit dee5158

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

lib/Conversion/TorchToStablehlo/Basic.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,14 +876,17 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
876876
// Tensors with integer types need to be converted to signless integer
877877
// element type. All tensors with element types other than integer can reuse
878878
// existing elements attribute.
879-
// TODO: what about unsigned integer?
880879
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
881880
Type builtinTensorElemTy = resultType.getElementType();
882881
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
882+
bool isUnsigned =
883+
cast<IntegerType>(builtinTensorElemTy).isUnsignedInteger();
883884

884885
DenseElementsAttr valueAttr =
885886
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
886-
return APInt(bitWidth, v.getSExtValue());
887+
APInt intValue =
888+
isUnsigned ? v.zextOrTrunc(bitWidth) : v.sextOrTrunc(bitWidth);
889+
return intValue;
887890
});
888891
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
889892
valueAttr);

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,6 +1771,8 @@
17711771
# raise TimeoutError(self.error_message)
17721772
# TimeoutError: Timeout
17731773
"BertModule_basic",
1774+
"UInt8Tensor_basic",
1775+
"BoolTensor_basic",
17741776
}
17751777

17761778
# Write the TOSA set as a "passing" set as it is very early in development

projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,3 +2102,46 @@ def forward(self, a):
21022102
@register_test_case(module_factory=lambda: AtenDiagEmbedNonDefault4DDiag())
21032103
def AtenDiagEmbedNonDefault4DDiag_basic(module, tu: TestUtils):
21042104
module.forward(tu.rand(2, 3, 4, 5))
2105+
2106+
2107+
# ==============================================================================
2108+
2109+
2110+
class UInt8Tensor(torch.nn.Module):
2111+
def __init__(self):
2112+
super().__init__()
2113+
2114+
@export
2115+
@annotate_args(
2116+
[
2117+
None,
2118+
]
2119+
)
2120+
def forward(self):
2121+
x = torch.tensor([128], dtype=torch.uint8)
2122+
return torch.ops.aten.to(x, dtype=torch.float32)
2123+
2124+
2125+
@register_test_case(module_factory=lambda: UInt8Tensor())
2126+
def UInt8Tensor_basic(module, tu: TestUtils):
2127+
module.forward()
2128+
2129+
2130+
class BoolTensor(torch.nn.Module):
2131+
def __init__(self):
2132+
super().__init__()
2133+
2134+
@export
2135+
@annotate_args(
2136+
[
2137+
None,
2138+
]
2139+
)
2140+
def forward(self):
2141+
x = torch.tensor([True], dtype=torch.bool)
2142+
return torch.ops.aten.to(x, dtype=torch.float32)
2143+
2144+
2145+
@register_test_case(module_factory=lambda: BoolTensor())
2146+
def BoolTensor_basic(module, tu: TestUtils):
2147+
module.forward()

0 commit comments

Comments
 (0)