Skip to content

Commit 189e1ea

Browse files
committed
[QNN EP] Relex check allowzero=1 when concrete shape without 0
1 parent 4665804 commit 189e1ea

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,33 @@ Status ReshapeOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
3737
if (do_op_validation) {
3838
NodeAttrHelper node_helper(node_unit);
3939
auto allowzero = node_helper.Get("allowzero", static_cast<int64_t>(0));
40+
41+
// Only reject allowzero=1 if dynamic shape or the shape actually contains zeros
4042
if (0 != allowzero) {
41-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Reshape doesn't support dynamic shape!");
43+
const auto& inputs = node_unit.Inputs();
44+
const auto& initializer_tensors = qnn_model_wrapper.GetInitializerTensors();
45+
auto shape_tensor_iter = initializer_tensors.find(inputs[1].node_arg.Name());
46+
47+
if (shape_tensor_iter == initializer_tensors.end()) {
48+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
49+
"QNN Reshape requires a constant shape input");
50+
}
51+
52+
// Check if the constant shape contains any zeros
53+
const auto* shape_tensor = shape_tensor_iter->second;
54+
std::vector<uint8_t> unpacked_tensor;
55+
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*shape_tensor, unpacked_tensor));
56+
57+
const int64_t* shape_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
58+
size_t shape_size = unpacked_tensor.size() / sizeof(int64_t);
59+
60+
for (size_t i = 0; i < shape_size; ++i) {
61+
if (shape_data[i] == 0) {
62+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
63+
"QNN Reshape does not support shapes with zero dimensions. "
64+
"The 'allowzero' attribute is not supported by QNN.");
65+
}
66+
}
4267
}
4368
}
4469

onnxruntime/test/providers/qnn/reshape_expand_op_test.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,14 +271,24 @@ TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) {
271271
19); // Opset
272272
}
273273

274-
// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP.
275-
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) {
274+
// Test that QDQ Reshape with allowzero=1 and a shape containing zeros is not supported by QNN EP.
275+
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_WithZeros_Unsupported) {
276+
RunReshapeExpandTestOnHTP<float>("Reshape",
277+
TestInputDef<float>({2, 0, 3}, false, {}),
278+
TestInputDef<int64_t>({2}, true, {6, 0}),
279+
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
280+
ExpectedEPNodeAssignment::None,
281+
19);
282+
}
283+
284+
// Test that QDQ Reshape with allowzero=1 but no zeros in shape IS supported by QNN EP.
285+
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_NoZeros_Supported) {
276286
RunQDQReshapeExpandTestOnHTP<uint8_t>("Reshape",
277-
TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
278-
TestInputDef<int64_t>({2}, true, {1, 48}),
287+
TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
288+
TestInputDef<int64_t>({2}, true, {1, 48}), // concrete shape with no zeros
279289
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
280-
ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP.
281-
19); // Opset
290+
ExpectedEPNodeAssignment::All,
291+
19);
282292
}
283293

284294
// Test 8-bit QDQ Reshape of rank 4 -> rank 2.

0 commit comments

Comments
 (0)