Skip to content

Commit 21e4877

Browse files
committed
[QNN EP] Handle 0-dim tensor for Concat.
There may exsit tensor with 0-dim in shape, especially for Concat's inputs. Modify the base op builder to ignore such tensor during construction. Test: UT.
1 parent 790018d commit 21e4877

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@ bool IsOptionalNodeUnitIODef(const NodeUnitIODef& node_io_def) {
1313
const NodeArg& arg = node_io_def.node_arg;
1414
return !arg.Exists() || arg.Name().empty();
1515
}
16+
17+
// Function to check whether we should skip processing null input which has 0 dim in shape.
18+
// Such null inputs often exist in models saved from PyTorch, especially for Concat.
19+
bool DoesConcatInputShapeContainZero(QnnModelWrapper& qnn_model_wrapper,
20+
const NodeUnit& node_unit,
21+
const NodeUnitIODef& node_io_def,
22+
const logging::Logger& logger) {
23+
// Although the 0 dim issue should be handled for all op types, restricting in Concat for now since current cases
24+
// only happen on one of Concat inputs. One may rename the function and relax the checking here to extend for other
25+
// ops.
26+
if (node_unit.OpType() != "Concat") {
27+
return false;
28+
}
29+
30+
std::vector<uint32_t> input_shape;
31+
if (!qnn_model_wrapper.GetOnnxShape(node_io_def.node_arg, input_shape)) {
32+
return false;
33+
}
34+
35+
for (const uint32_t& dim : input_shape) {
36+
if (dim == 0) {
37+
LOGS(logger, WARNING) << "Tensor has 0 dim, ignore this input: " << node_io_def.node_arg.Name();
38+
return true;
39+
}
40+
}
41+
42+
return false;
43+
}
1644
} // namespace
1745

1846
std::string BaseOpBuilder::GetOpBuilderType() const {
@@ -126,7 +154,9 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
126154
const auto& inputs = node_unit.Inputs();
127155
const auto input_count = GetInputCountQnnRequired(node_unit);
128156
for (size_t input_i = 0; input_i < input_count; ++input_i) {
129-
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names));
157+
if (!DoesConcatInputShapeContainZero(qnn_model_wrapper, node_unit, inputs[input_i], logger)) {
158+
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names));
159+
}
130160
}
131161

132162
return Status::OK();

onnxruntime/test/providers/qnn/simple_op_htp_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ TEST_F(QnnCPUBackendTests, DISABLED_UnaryOp_Relu) {
128128
ExpectedEPNodeAssignment::All);
129129
}
130130

131+
TEST_F(QnnCPUBackendTests, Concat_EmptyInput) {
132+
RunOpTestOnCPU("Concat",
133+
{TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
134+
TestInputDef<float>({1, 0, 4, 4}, false, {})},
135+
{utils::MakeAttribute("axis", static_cast<int64_t>(1))},
136+
13,
137+
ExpectedEPNodeAssignment::All);
138+
}
139+
131140
#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
132141

133142
// Tests the accuracy of a QDQ model on QNN EP by comparing to CPU EP, which runs both the fp32 model

0 commit comments

Comments
 (0)