Skip to content

Commit b56f3cf

Browse files
committed
tag scales for external data
1 parent 92bf722 commit b56f3cf

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def get_per_channel_dtype(
275275
return dtype
276276

277277
def get_quant_params(
278-
self, quant_params: QuantParams, xnn_graph: XNNGraph
278+
self, quant_params: QuantParams, xnn_graph: XNNGraph, external_tag: str = None
279279
) -> XNNQuantParams:
280280
if quant_params.per_channel:
281281
scale = cast(torch.Tensor, quant_params.scale)
@@ -291,13 +291,18 @@ def get_quant_params(
291291
ctypes.POINTER(ctypes.c_char * num_bytes),
292292
).contents
293293
scale_name = hashlib.sha256(bytes(scale_array)).hexdigest()
294+
scale_name = "scale_" + scale_name
294295
xnn_graph.constant_data.append(
295296
ConstantDataOffset(
296297
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
297298
)
298299
)
300+
if external_tag is not None:
301+
logging.info(
302+
f"Adding constant data with name, key {scale_name} and external_tag {external_tag} to named_data_store"
303+
)
299304
self._named_data_store.add_named_data(
300-
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
305+
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT, external_tag
301306
)
302307

303308
if quant_params.per_channel_group:
@@ -470,13 +475,19 @@ def define_tensor( # noqa: C901
470475
assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : {quant_params.axis}, expecting 0 / 1."
471476

472477
# Serialize tensor value
478+
custom_meta = tensor.meta.get("custom", None)
479+
external_tag = (
480+
custom_meta.get("delegate_constant_tag", None) if custom_meta else None
481+
)
473482
ser_val = (
474483
XValue(xvalue_union=tvalue)
475484
if quant_params is None
476485
else XValue(
477486
xvalue_union=XNNQuantizedTensorValue(
478487
tensor_value=tvalue,
479-
quant_params=self.get_quant_params(quant_params, xnn_graph),
488+
quant_params=self.get_quant_params(
489+
quant_params, xnn_graph, external_tag
490+
),
480491
)
481492
)
482493
)
@@ -614,7 +625,7 @@ def get_serialized_buffer_index(
614625
f"Serializing constant data node {tensor} but tensor value has no bytes",
615626
)
616627
sha256_hash = hashlib.sha256(bytes(array))
617-
named_key = sha256_hash.hexdigest()
628+
named_key = tensor.name + "_" + sha256_hash.hexdigest()
618629

619630
size = const_val.untyped_storage().nbytes()
620631
xnn_graph.constant_data.append(
@@ -626,7 +637,6 @@ def get_serialized_buffer_index(
626637
custom_meta.get("delegate_constant_tag", None) if custom_meta else None
627638
)
628639
if external_tag is not None:
629-
external_tag = custom_meta.get("delegate_constant_tag", None)
630640
logging.info(
631641
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
632642
)

0 commit comments

Comments
 (0)