Skip to content

Commit 55c015e

Browse files
committed
tag scales for external data
1 parent 92bf722 commit 55c015e

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def get_per_channel_dtype(
275275
return dtype
276276

277277
def get_quant_params(
278-
self, quant_params: QuantParams, xnn_graph: XNNGraph
278+
self, quant_params: QuantParams, xnn_graph: XNNGraph, external_tag: str = None
279279
) -> XNNQuantParams:
280280
if quant_params.per_channel:
281281
scale = cast(torch.Tensor, quant_params.scale)
@@ -291,13 +291,18 @@ def get_quant_params(
291291
ctypes.POINTER(ctypes.c_char * num_bytes),
292292
).contents
293293
scale_name = hashlib.sha256(bytes(scale_array)).hexdigest()
294+
scale_name = "scale_" + scale_name
294295
xnn_graph.constant_data.append(
295296
ConstantDataOffset(
296297
offset=UINT64_MAX, size=num_bytes, named_key=scale_name
297298
)
298299
)
300+
if external_tag is not None:
301+
logging.info(
302+
f"Adding constant data with name, key {scale_name} and external_tag {external_tag} to named_data_store"
303+
)
299304
self._named_data_store.add_named_data(
300-
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
305+
scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT, external_tag
301306
)
302307

303308
if quant_params.per_channel_group:
@@ -470,13 +475,17 @@ def define_tensor( # noqa: C901
470475
assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : {quant_params.axis}, expecting 0 / 1."
471476

472477
# Serialize tensor value
478+
custom_meta = tensor.meta.get("custom", None)
479+
external_tag = (
480+
custom_meta.get("delegate_constant_tag", None) if custom_meta else None
481+
)
473482
ser_val = (
474483
XValue(xvalue_union=tvalue)
475484
if quant_params is None
476485
else XValue(
477486
xvalue_union=XNNQuantizedTensorValue(
478487
tensor_value=tvalue,
479-
quant_params=self.get_quant_params(quant_params, xnn_graph),
488+
quant_params=self.get_quant_params(quant_params, xnn_graph, external_tag),
480489
)
481490
)
482491
)
@@ -626,7 +635,6 @@ def get_serialized_buffer_index(
626635
custom_meta.get("delegate_constant_tag", None) if custom_meta else None
627636
)
628637
if external_tag is not None:
629-
external_tag = custom_meta.get("delegate_constant_tag", None)
630638
logging.info(
631639
f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
632640
)

0 commit comments

Comments
 (0)