@@ -275,7 +275,7 @@ def get_per_channel_dtype(
275275 return dtype
276276
277277 def get_quant_params (
278- self , quant_params : QuantParams , xnn_graph : XNNGraph
278+ self , quant_params : QuantParams , xnn_graph : XNNGraph , external_tag : str = None
279279 ) -> XNNQuantParams :
280280 if quant_params .per_channel :
281281 scale = cast (torch .Tensor , quant_params .scale )
@@ -291,13 +291,18 @@ def get_quant_params(
291291 ctypes .POINTER (ctypes .c_char * num_bytes ),
292292 ).contents
293293 scale_name = hashlib .sha256 (bytes (scale_array )).hexdigest ()
294+ scale_name = "scale_" + scale_name
294295 xnn_graph .constant_data .append (
295296 ConstantDataOffset (
296297 offset = UINT64_MAX , size = num_bytes , named_key = scale_name
297298 )
298299 )
300+ if external_tag is not None :
301+ logging .info (
302+ f"Adding constant data with name, key { scale_name } and external_tag { external_tag } to named_data_store"
303+ )
299304 self ._named_data_store .add_named_data (
300- scale_name , bytes (scale_array ), CONSTANT_TENSOR_ALIGNMENT
305+ scale_name , bytes (scale_array ), CONSTANT_TENSOR_ALIGNMENT , external_tag
301306 )
302307
303308 if quant_params .per_channel_group :
@@ -470,13 +475,19 @@ def define_tensor( # noqa: C901
470475 assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : { quant_params .axis } , expecting 0 / 1."
471476
472477 # Serialize tensor value
478+ custom_meta = tensor .meta .get ("custom" , None )
479+ external_tag = (
480+ custom_meta .get ("delegate_constant_tag" , None ) if custom_meta else None
481+ )
473482 ser_val = (
474483 XValue (xvalue_union = tvalue )
475484 if quant_params is None
476485 else XValue (
477486 xvalue_union = XNNQuantizedTensorValue (
478487 tensor_value = tvalue ,
479- quant_params = self .get_quant_params (quant_params , xnn_graph ),
488+ quant_params = self .get_quant_params (
489+ quant_params , xnn_graph , external_tag
490+ ),
480491 )
481492 )
482493 )
@@ -614,7 +625,7 @@ def get_serialized_buffer_index(
614625 f"Serializing constant data node { tensor } but tensor value has no bytes" ,
615626 )
616627 sha256_hash = hashlib .sha256 (bytes (array ))
617- named_key = sha256_hash .hexdigest ()
628+ named_key = tensor . name + "_" + sha256_hash .hexdigest ()
618629
619630 size = const_val .untyped_storage ().nbytes ()
620631 xnn_graph .constant_data .append (
@@ -626,7 +637,6 @@ def get_serialized_buffer_index(
626637 custom_meta .get ("delegate_constant_tag" , None ) if custom_meta else None
627638 )
628639 if external_tag is not None :
629- external_tag = custom_meta .get ("delegate_constant_tag" , None )
630640 logging .info (
631641 f"Adding constant data with name { tensor .name } , key { named_key } and external_tag { external_tag } to named_data_store"
632642 )
0 commit comments