@@ -13,6 +13,34 @@ bool IsOptionalNodeUnitIODef(const NodeUnitIODef& node_io_def) {
1313 const NodeArg& arg = node_io_def.node_arg ;
1414 return !arg.Exists () || arg.Name ().empty ();
1515}
16+
17+ // Function to check whether we should skip processing null input which has 0 dim in shape.
18+ // Such null inputs often exist in models saved from PyTorch, especially for Concat.
19+ bool DoesConcatInputShapeContainZero (QnnModelWrapper& qnn_model_wrapper,
20+ const NodeUnit& node_unit,
21+ const NodeUnitIODef& node_io_def,
22+ const logging::Logger& logger) {
23+ // Although the 0 dim issue should be handled for all op types, restricting in Concat for now since current cases
24+ // only happen on one of Concat inputs. One may rename the function and relax the checking here to extend for other
25+ // ops.
26+ if (node_unit.OpType () != " Concat" ) {
27+ return false ;
28+ }
29+
30+ std::vector<uint32_t > input_shape;
31+ if (!qnn_model_wrapper.GetOnnxShape (node_io_def.node_arg , input_shape)) {
32+ return false ;
33+ }
34+
35+ for (const uint32_t & dim : input_shape) {
36+ if (dim == 0 ) {
37+ LOGS (logger, WARNING) << " Tensor has 0 dim, ignore this input: " << node_io_def.node_arg .Name ();
38+ return true ;
39+ }
40+ }
41+
42+ return false ;
43+ }
1644} // namespace
1745
1846std::string BaseOpBuilder::GetOpBuilderType () const {
@@ -126,7 +154,9 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
126154 const auto & inputs = node_unit.Inputs ();
127155 const auto input_count = GetInputCountQnnRequired (node_unit);
128156 for (size_t input_i = 0 ; input_i < input_count; ++input_i) {
129- ORT_RETURN_IF_ERROR (ProcessInput (qnn_model_wrapper, inputs[input_i], logger, input_names));
157+ if (!DoesConcatInputShapeContainZero (qnn_model_wrapper, node_unit, inputs[input_i], logger)) {
158+ ORT_RETURN_IF_ERROR (ProcessInput (qnn_model_wrapper, inputs[input_i], logger, input_names));
159+ }
130160 }
131161
132162 return Status::OK ();
0 commit comments