
Conversation

@liangel-02 (Contributor) commented Dec 1, 2025

Summary
Updates the safetensors logic to support all tensor subclasses generically.

Test
Verified integration with vLLM.
Unit tests pass: python tests/quantization/torchao_integration/test_torchao.py -k TorchAoSafeSerializationTest

@liangel-02 force-pushed the ao_safetensors branch 5 times, most recently from 8ee8375 to 32df8ec on December 1, 2025 18:52
@liangel-02 (Contributor, Author)

cc @SunMarc @MekkCyber

@liangel-02 force-pushed the ao_safetensors branch 3 times, most recently from 5fa770a to 80af667 on December 1, 2025 21:41
@MekkCyber (Contributor) left a comment

Thanks a lot for working on this

Comment on lines 890 to 896
concrete_source_pattern = source_pattern
if isinstance(mapping, WeightConverter) and source_pattern is not None and "*" in source_pattern:
    pattern_with_captures = source_pattern.replace("*", r"(.*?)")
    pattern_regex = re.compile(f"^{pattern_with_captures}$")
    concrete_source_pattern = extract_concrete_key_from_regex_pattern(
        original_key, source_pattern, pattern_regex
    )
@MekkCyber (Contributor) Dec 2, 2025

Why do we need this? Could you explain a bit more why we need to change the regex handling here and in the other parts of the code?

@liangel-02 (Contributor, Author) Dec 2, 2025

For the torchao WeightConverter, the original code hardcodes the tensor data components (i.e. weight_qdata, weight_scale) that map to the consolidated weight. However, these components change depending on the config used, and some are optional. For maximum generality, I wanted to use wildcard matching (*_weight_* --> *weight*).

But with the current regex handling, the code was renaming the key with the literal "*weight*" rather than matching the regex. My changes extract the prefix and use it for the source/target. Let me know if this is an OK approach or if there's a better solution.
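
A minimal sketch of the wildcard-to-regex idea described above (the function and keys here are illustrative, not the actual transformers helpers):

import re

def concrete_key_for_pattern(original_key, source_pattern):
    # Turn a wildcard source pattern (e.g. "*_weight_qdata") into a regex with
    # captures, then rebuild the concrete key from the captured prefix.
    pattern_regex = re.compile("^" + source_pattern.replace("*", r"(.*?)") + "$")
    match = pattern_regex.match(original_key)
    if match is None:
        return None
    concrete = source_pattern
    for group in match.groups():
        concrete = concrete.replace("*", group, 1)
    return concrete

# Recovers the concrete key when the wildcard pattern matches:
# prints "model.layers.0.self_attn.k_proj_weight_qdata"
print(concrete_key_for_pattern("model.layers.0.self_attn.k_proj_weight_qdata", "*_weight_qdata"))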

@liangel-02 force-pushed the ao_safetensors branch 2 times, most recently from 13b1d0e to 5f770e1 on December 3, 2025 00:18
@liangel-02 (Contributor, Author) commented Dec 3, 2025

@MekkCyber To avoid adding torchao-specific logic to the core model loader and to unblock landing, I reverted the changes to the regex logic and added all the other tensor data names. In the future, maybe we can implement a strategy that generalizes this matching in a better way?

cc @jerryzh168

# check if the param_name is not in self.modules_to_not_convert
if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
    return False
elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
Contributor

For this one, maybe just change it to f"_{x}" for x in self.full_ao_keys to be safe; at least with this we are confident it's correct, if there is no better general way to detect safetensors.

Contributor Author

I think this is achieving the same thing as L237; we only need one check.

@github-actions github-actions bot commented Dec 3, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: torchao_integration

@MekkCyber (Contributor) left a comment

Very nice alternative solution! Thanks a lot

Comment on lines +539 to +548
# TODO: increase flexibility by generalizing the source patterns to match the format of "_weight_"
# note that the matching logic is greedy, so, for example, if _weight_scale came before _weight_scale_and_zero in this list, it would always match _weight_scale (which is incorrect)
# thus, the order of source_patterns is intentional
source_patterns=[
    "_weight_qdata",
    "_weight_scale_and_zero",
    "_weight_scale",
    "_weight_zero_point",
    "_weight_act_pre_scale",
],
Contributor

Yes, it is greedy, but to match for example _weight_scale and not _weight_scale_and_zero you can just do something like _weight_scale$; ordering the keys works as well 👍
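
A small illustration of the two approaches (the key here is made up for the example):

import re

key = "model.layers.0.self_attn.k_proj._weight_scale_and_zero"

# Ordered, greedy suffix matching: the first pattern that matches wins, so
# "_weight_scale_and_zero" has to come before "_weight_scale" in the list.
ordered_patterns = ["_weight_qdata", "_weight_scale_and_zero", "_weight_scale"]
print(next(p for p in ordered_patterns if p in key))  # _weight_scale_and_zero

# Anchored alternative: "_weight_scale$" only matches keys that actually end
# with "_weight_scale", so ordering no longer matters.
print(bool(re.search(r"_weight_scale$", key)))                      # False
print(bool(re.search(r"_weight_scale$", "k_proj._weight_scale")))   # True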

Comment on lines -2107 to 2108
-    print(name)
    setattr(self, name, decoder)
Contributor

Thanks for catching this

else:
    param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"]
for suffix in input_dict.keys():
    assert len(input_dict[suffix]) == 1
Contributor

Let's do an if/else and raise an error to follow the same pattern we use in transformers, but I'm not sure it's necessary, since if we have the suffix in input_dict it means we collected at least one tensor for that suffix.
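
A minimal sketch of the if/raise version of that check, assuming input_dict maps each suffix to the list of tensors collected for it (the data here is illustrative):

# Illustrative data: each suffix should map to exactly one collected tensor.
input_dict = {"weight:qdata": ["qdata_tensor"], "weight:scale": ["scale_tensor"]}

for suffix, tensors in input_dict.items():
    if len(tensors) != 1:
        raise ValueError(
            f"Expected exactly one tensor collected for suffix '{suffix}', got {len(tensors)}"
        )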

Comment on lines +227 to +228
full_layer_name: "model.layers.0.self_attn.k_proj"
Contributor

Suggested change:
-    full_layer_name: "model.layers.0.self_attn.k_proj"
+    full_layer_name: "model.layers.0.self_attn.k_proj.weight"
