
Commit 5b514cb

make safetensors check more robust
1 parent 0ffed5e commit 5b514cb

File tree

2 files changed: +7 -5 lines changed


src/transformers/integrations/torchao.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -210,6 +210,7 @@ def __init__(self, hf_quantizer):
     def convert(
         self,
         input_dict: dict[str, torch.Tensor],
+        source_patterns: list[str] | None = None,
         model: Optional[torch.nn.Module] = None,
         full_layer_name: str | None = None,
         missing_keys=None,
@@ -223,13 +224,13 @@ def convert(
                 "_weight_qdata": torch.Tensor,
                 "_weight_scale": torch.Tensor,
             }
-            full_layer_name: "model.layers.0.self_attn.k_proj"
+            full_layer_name: "model.layers.0.self_attn.k_proj.weight"
 
         Given this, we reconstruct a Float8Tensor instance using the qdata and scale
         and return it as a dictionary with the full_layer_name as the key and the recovered
         Float8Tensor instance as the value.
         """
-        is_unsafe_serialization = "_weight_" not in list(input_dict.keys())[0]
+        is_unsafe_serialization = list(input_dict.keys())[0] not in source_patterns
 
         param_data = {}
         layer_name = ".".join(full_layer_name.split(".")[:-1])
@@ -240,7 +241,10 @@
             weight = input_dict["weight"]
         else:
             for suffix in input_dict.keys():
-                assert len(input_dict[suffix]) == 1
+                if len(input_dict[suffix]) != 1:
+                    raise ValueError(
+                        f"Expected a single tensor for {suffix} but got {len(input_dict[suffix])} tensors instead"
+                    )
                 param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]
 
         # If it's unsafe-serialized (i.e. not safetensors), no need for anything
```
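For reference, a minimal sketch (not the library code) of what the revised check buys: the old test matched the substring `"_weight_"` anywhere in the first key, while the new test requires an exact match against `source_patterns`. The pattern list and the sample key below are hypothetical illustrations only.

```python
# Minimal sketch, assumptions only: source_patterns and the sample key are
# hypothetical, chosen to show a case where the two checks disagree.
source_patterns = ["_weight_qdata", "_weight_scale"]

def is_unsafe_old(input_dict: dict) -> bool:
    # Old check: any first key containing "_weight_" counted as safetensors-style.
    return "_weight_" not in list(input_dict.keys())[0]

def is_unsafe_new(input_dict: dict, patterns: list[str]) -> bool:
    # New check: the first key must exactly match one of the known component names.
    return list(input_dict.keys())[0] not in patterns

sample = {"pos_weight_bias": ["tensor placeholder"]}  # hypothetical key
print(is_unsafe_old(sample))                   # False: substring match succeeds
print(is_unsafe_new(sample, source_patterns))  # True: no exact match
```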

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -234,8 +234,6 @@ def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]
         return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
 
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
-        if "_weight_" in param_name:
-            return True
         if self.pre_quantized:
             return False
         if self.quantization_config.quant_type == "autoquant":
```
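The deleted branch in `param_needs_quantization` relied on the same substring heuristic. A simplified, hypothetical sketch of the before/after control flow (the real method continues with further branches that are elided here):

```python
# Hypothetical, simplified view of the change; not the real method bodies.
def needs_quantization_before(param_name: str, pre_quantized: bool) -> bool:
    if "_weight_" in param_name:  # removed branch: substring short-circuit
        return True
    if pre_quantized:
        return False
    return False  # placeholder for the remaining, elided branches

def needs_quantization_after(param_name: str, pre_quantized: bool) -> bool:
    if pre_quantized:
        return False
    return False  # placeholder for the remaining, elided branches
```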
