Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions backends/xnnpack/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def get_per_channel_dtype(
return dtype

def get_quant_params(
self, quant_params: QuantParams, xnn_graph: XNNGraph
self, quant_params: QuantParams, xnn_graph: XNNGraph, external_tag: str = None
) -> XNNQuantParams:
if quant_params.per_channel:
scale = cast(torch.Tensor, quant_params.scale)
Expand All @@ -291,13 +291,18 @@ def get_quant_params(
ctypes.POINTER(ctypes.c_char * num_bytes),
).contents
scale_name = hashlib.sha256(bytes(scale_array)).hexdigest()
scale_name = "scale_" + scale_name
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add this for debugging purposes

xnn_graph.constant_data.append(
ConstantDataOffset(
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
)
)
if external_tag is not None:
logging.info(
f"Adding constant data with name, key {scale_name} and external_tag {external_tag} to named_data_store"
)
self._named_data_store.add_named_data(
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT, external_tag
)

if quant_params.per_channel_group:
Expand Down Expand Up @@ -470,13 +475,19 @@ def define_tensor( # noqa: C901
assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : {quant_params.axis}, expecting 0 / 1."

# Serialize tensor value
custom_meta = tensor.meta.get("custom", None)
external_tag = (
custom_meta.get("delegate_constant_tag", None) if custom_meta else None
)
ser_val = (
XValue(xvalue_union=tvalue)
if quant_params is None
else XValue(
xvalue_union=XNNQuantizedTensorValue(
tensor_value=tvalue,
quant_params=self.get_quant_params(quant_params, xnn_graph),
quant_params=self.get_quant_params(
quant_params, xnn_graph, external_tag
),
)
)
)
Expand Down Expand Up @@ -614,7 +625,7 @@ def get_serialized_buffer_index(
f"Serializing constant data node {tensor} but tensor value has no bytes",
)
sha256_hash = hashlib.sha256(bytes(array))
named_key = sha256_hash.hexdigest()
named_key = tensor.name + "_" + sha256_hash.hexdigest()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also debugging here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you plan on landing this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@metascroy yeah - adding the tensor name to the hash for general debuggability until we can use names instead of hash


size = const_val.untyped_storage().nbytes()
xnn_graph.constant_data.append(
Expand All @@ -626,7 +637,6 @@ def get_serialized_buffer_index(
custom_meta.get("delegate_constant_tag", None) if custom_meta else None
)
if external_tag is not None:
external_tag = custom_meta.get("delegate_constant_tag", None)
logging.info(
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
)
Expand Down
Loading