Skip to content

Commit c18625a

Browse files
committed
[QNN EP] Support NonZero.
- Implement NonZero op builder and regsiter QDQ selector. - Implement ShapeNonZero QNN preprocess to fix shape. Test: UTs.
1 parent ee0ffd5 commit c18625a

File tree

9 files changed

+428
-2
lines changed

9 files changed

+428
-2
lines changed

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
4848
// These produce int64 indices output, which can't be quantized, so there's no downstream Q node.
4949
static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() {
5050
return {{"ArgMax", {}},
51-
{"ArgMin", {}}};
51+
{"ArgMin", {}},
52+
{"NonZero", {}}};
5253
}
5354

5455
static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {

onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
223223
{
224224
CreateInverseOpBuilder("Inverse", *this);
225225
}
226+
227+
{
228+
CreateNonZeroOpBuilder("NonZero", *this);
229+
}
226230
}
227231

228232
const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,7 @@ void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
125125

126126
void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
127127

128+
void CreateNonZeroOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
129+
128130
} // namespace qnn
129131
} // namespace onnxruntime

onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ class BaseOpBuilder : public IOpBuilder {
156156
{"Max", QNN_OP_ELEMENT_WISE_MAXIMUM},
157157
{"Min", QNN_OP_ELEMENT_WISE_MINIMUM},
158158
{"Neg", QNN_OP_ELEMENT_WISE_NEG},
159+
{"NonZero", QNN_OP_NON_ZERO},
159160
{"Not", QNN_OP_ELEMENT_WISE_NOT},
160161
{"Or", QNN_OP_ELEMENT_WISE_OR},
161162
{"Pow", QNN_OP_ELEMENT_WISE_POWER},
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <string>
5+
#include <vector>
6+
7+
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
8+
#include "core/providers/qnn/builder/op_builder_factory.h"
9+
#include "core/providers/qnn/builder/qnn_utils.h"
10+
11+
namespace onnxruntime {
12+
namespace qnn {
13+
14+
class NonZeroOpBuilder : public BaseOpBuilder {
15+
public:
16+
NonZeroOpBuilder() : BaseOpBuilder("NonZeroOpBuilder") {}
17+
18+
protected:
19+
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
20+
const NodeUnit& node_unit,
21+
std::vector<std::string>&& input_names,
22+
const logging::Logger& logger,
23+
bool do_op_validation) const override ORT_MUST_USE_RESULT;
24+
};
25+
26+
Status NonZeroOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
27+
const NodeUnit& node_unit,
28+
std::vector<std::string>&& input_names,
29+
const logging::Logger& logger,
30+
bool do_op_validation) const {
31+
// Handle a corner case explicitly, which can pass backend validation but in fact not executable.
32+
const std::vector<uint32_t>& input_shape = qnn_model_wrapper.GetQnnTensorWrapper(input_names[0]).GetTensorDims();
33+
for (const uint32_t& dim : input_shape) {
34+
ORT_RETURN_IF(dim == 0, "QNN does not support NonZero with empty input.");
35+
}
36+
37+
const auto& output = node_unit.Outputs()[0];
38+
const std::string& output_name = output.node_arg.Name();
39+
40+
TensorInfo output_info = {};
41+
Status status = qnn_model_wrapper.GetTensorInfo(output, output_info);
42+
if (!status.IsOK()) {
43+
LOGS(logger, ERROR) << "Encountering NonZero " << node_unit.Name() << " which has dynamically shaped output tensor."
44+
<< "QNN supports NonZero by allocating maximum possible size (i.e., all elements != 0), "
45+
<< "and fills only the detected nonzero elements in the output tensor."
46+
<< "The model must be preproceesed to eliminate the dynamic shapes first for QNN to support.";
47+
return status;
48+
}
49+
50+
// ONNX NonZero has shape [input_rank, #input_elements].
51+
uint32_t rank = output_info.shape[0];
52+
uint32_t num_elements = output_info.shape[1];
53+
54+
// QNN NonZero has shape [#input elements, input_rank], and thus an extra Transpose must be inserted afterwards.
55+
const std::string transpose_input_name = utils::GetUniqueName(output_name, +"_transpose");
56+
const std::vector<uint32_t> transpose_input_shape{num_elements, rank};
57+
QnnTensorWrapper output_tensorwrapper(transpose_input_name,
58+
QNN_TENSOR_TYPE_NATIVE,
59+
output_info.qnn_data_type,
60+
output_info.quant_param.Copy(),
61+
std::vector<uint32_t>(transpose_input_shape));
62+
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
63+
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit),
64+
QNN_OP_PACKAGE_NAME_QTI_AISW,
65+
GetQnnOpType(node_unit.OpType()),
66+
std::move(input_names),
67+
{transpose_input_name},
68+
{},
69+
do_op_validation),
70+
"Failed to add NonZero node.");
71+
72+
// NonZero's output is indices which is INT64 dtype. If it happens to be graph output as well, add a Cast node to
73+
// cast the dtype back to INT64 since wrapper construction implicitly changes the dtype to INT32.
74+
const bool is_cast_required = output_info.qnn_data_type == QNN_DATATYPE_INT_64 &&
75+
qnn_model_wrapper.IsGraphOutput(output_name);
76+
const std::string transpose_output_name = is_cast_required ? utils::GetUniqueName(output_name, "_cast") : output_name;
77+
78+
std::vector<uint32_t> transpose_perm{1, 0};
79+
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
80+
transpose_input_name,
81+
transpose_output_name,
82+
transpose_input_shape,
83+
transpose_perm,
84+
output_info.shape,
85+
output_info.qnn_data_type,
86+
output_info.quant_param,
87+
do_op_validation,
88+
false,
89+
false));
90+
91+
if (is_cast_required) {
92+
QnnTensorWrapper cast_output_tensorwrapper(output_name,
93+
QNN_TENSOR_TYPE_APP_READ,
94+
output_info.qnn_data_type,
95+
output_info.quant_param.Copy(),
96+
std::vector<uint32_t>(output_info.shape));
97+
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output_tensorwrapper)),
98+
"Failed to add tensor.");
99+
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_CAST),
100+
QNN_OP_PACKAGE_NAME_QTI_AISW,
101+
QNN_OP_CAST,
102+
{transpose_output_name},
103+
{output_name},
104+
{},
105+
do_op_validation),
106+
"Failed to add node");
107+
}
108+
109+
return Status::OK();
110+
}
111+
112+
void CreateNonZeroOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
113+
op_registrations.AddOpBuilder(op_type, std::make_unique<NonZeroOpBuilder>());
114+
}
115+
116+
} // namespace qnn
117+
} // namespace onnxruntime

onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ...onnx_model import ONNXModel
1818
from .fusion_lpnorm import FusionLpNormalization
1919
from .fusion_spacetodepth import FusionSpaceToDepth
20+
from .shape_nonzero import ShapeNonZero
2021

2122

2223
def qnn_preprocess_model(
@@ -108,6 +109,9 @@ def qnn_preprocess_model(
108109
if exclude_initializer_from_input:
109110
modified |= remove_initializer_from_input(onnx_model.model)
110111

112+
# Shape dynamic-shaped NonZero.
113+
modified |= ShapeNonZero(onnx_model).apply()
114+
111115
# Fuse Erf sequence into a single Gelu
112116
fusion_gelu = FusionGelu(onnx_model)
113117
if fusion_gelu.apply():
@@ -166,7 +170,7 @@ def qnn_preprocess_model(
166170
if modified:
167171
onnx_model.topological_sort()
168172
onnx.save_model(
169-
model,
173+
onnx_model.model,
170174
model_output,
171175
save_as_external_data=save_as_external_data,
172176
all_tensors_to_one_file=all_tensors_to_one_file,
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
"""Define NonZero shape inference."""
7+
8+
import logging
9+
10+
import numpy as np
11+
import onnx
12+
13+
from ... import fusions, onnx_model
14+
15+
16+
class ShapeNonZero(fusions.Fusion):
17+
"""Shape inference for NonZero.
18+
19+
NonZero node produces dynamically shaped output tensor, causing the tensor shapes of following nodes undetermined
20+
as well. QNN expects NonZero having its shape set to maximum size (i.e., number of total input elements) and let
21+
runtime handle the dynamic shape later.
22+
"""
23+
24+
def __init__(self, model: onnx_model.ONNXModel):
25+
"""Initialize.
26+
Args:
27+
model: An onnx_model.ONNXModel instance.
28+
"""
29+
super().__init__(model, "", "NonZero")
30+
31+
def fuse(
32+
self,
33+
node: onnx.NodeProto,
34+
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
35+
output_name_to_node: dict[str, onnx.NodeProto],
36+
) -> bool:
37+
"""Infer shape for NonZero.
38+
39+
Args:
40+
node: An onnx.NodeProto matching the specified search type (i.e., NonZero).
41+
input_name_to_nodes: A dict mapping tensor name to consumed nodes.
42+
output_name_to_node: A dict mapping tensor name to produced node.
43+
44+
Returns:
45+
A bool indicating whether the node is updated.
46+
"""
47+
logging.warning(
48+
"The model contains a NonZero node which produces a dynamically shaped output tensor."
49+
"Following QNN requirements, its output shape will be deliberately set to the maximum size."
50+
)
51+
52+
if (input_tensor_type := self.model.get_tensor_type(node.input[0])) is None or (
53+
output_tensor_type := self.model.get_tensor_type(node.output[0])
54+
) is None:
55+
return False
56+
57+
if not (input_tensor_shape := self.tensor_shape_to_list(input_tensor_type)):
58+
return False
59+
60+
if not all(isinstance(dim, int) for dim in input_tensor_shape):
61+
return False
62+
63+
output_tensor_type.shape.dim[1].dim_value = np.prod(input_tensor_shape)
64+
return True
65+
66+
def apply(self) -> bool:
67+
"""Apply fusion.
68+
69+
This method is overridden to execute shape inference again since NonZero will have fixed shape.
70+
71+
Returns:
72+
A bool indicating whether the model is updated.
73+
"""
74+
input_name_to_nodes = self.model.input_name_to_nodes()
75+
output_name_to_node = self.model.output_name_to_node()
76+
77+
updated = False
78+
for node in self.model.nodes():
79+
if node.op_type == self.search_op_type:
80+
updated |= self.fuse(node, input_name_to_nodes, output_name_to_node)
81+
82+
if updated:
83+
self.model.model = onnx.shape_inference.infer_shapes(self.model.model)
84+
85+
return updated

0 commit comments

Comments
 (0)