@@ -210,6 +210,7 @@ def __init__(self, hf_quantizer):
     def convert(
         self,
         input_dict: dict[str, torch.Tensor],
+        source_patterns: list[str] | None = None,
         model: Optional[torch.nn.Module] = None,
         full_layer_name: str | None = None,
         missing_keys=None,
@@ -223,13 +224,13 @@ def convert(
223224 "_weight_qdata": torch.Tensor,
224225 "_weight_scale": torch.Tensor,
225226 }
226- full_layer_name: "model.layers.0.self_attn.k_proj"
227+ full_layer_name: "model.layers.0.self_attn.k_proj.weight "
227228
228229 Given this, we reconstruct a Float8Tensor instance using the qdata and scale
229230 and return it as a dictionary with the full_layer_name as the key and the recovered
230231 Float8Tensor instance as the value.
231232 """
232- is_unsafe_serialization = "_weight_" not in list (input_dict .keys ())[0 ]
233+ is_unsafe_serialization = list (input_dict .keys ())[0 ] not in source_patterns
233234
234235 param_data = {}
235236 layer_name = "." .join (full_layer_name .split ("." )[:- 1 ])
@@ -240,7 +241,10 @@ def convert(
             weight = input_dict["weight"]
         else:
             for suffix in input_dict.keys():
-                assert len(input_dict[suffix]) == 1
+                if len(input_dict[suffix]) != 1:
+                    raise ValueError(
+                        f"Expected a single tensor for {suffix} but got {len(input_dict[suffix])} tensors instead"
+                    )
                 param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]
 
         # If it's unsafe-serialized (i.e. not safetensors), no need for anything
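
For context, a minimal standalone sketch (not part of the diff) of what the changed detection line does: with safetensors serialization the first key of input_dict is one of the known quantized-tensor patterns, while unsafe (torch.save-style) serialization carries a plain "weight" key. The contents of source_patterns below are hypothetical; only the membership test mirrors the new code.

# Hypothetical pattern list; the diff only shows that source_patterns
# is a list[str] passed into convert().
source_patterns = ["_weight_qdata", "_weight_scale"]

# Safetensors-style input: the first key matches a known pattern.
safe_input = {"_weight_qdata": ["qdata"], "_weight_scale": ["scale"]}
print(list(safe_input.keys())[0] not in source_patterns)  # False -> safe

# Unsafe-serialization input: a plain "weight" key, no pattern match.
unsafe_input = {"weight": ["tensor"]}
print(list(unsafe_input.keys())[0] not in source_patterns)  # True -> unsafe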