[torchao] safetensors #42529
Conversation
Force-pushed 8ee8375 to 32df8ec
Force-pushed 32df8ec to 78f0a39
Force-pushed 78f0a39 to 543b810
Force-pushed 5fa770a to 80af667
MekkCyber left a comment
Thanks a lot for working on this
```python
concrete_source_pattern = source_pattern
if isinstance(mapping, WeightConverter) and source_pattern is not None and "*" in source_pattern:
    pattern_with_captures = source_pattern.replace("*", r"(.*?)")
    pattern_regex = re.compile(f"^{pattern_with_captures}$")
    concrete_source_pattern = extract_concrete_key_from_regex_pattern(
        original_key, source_pattern, pattern_regex
    )
```
why do we need this? could you explain a bit more why we need to change the regex handling here and in the other parts of the code?
for the torchao WeightConverter, the original code had hardcoded tensor data components (i.e. weight_qdata, weight_scale) mapped to the consolidated weight. however, these components change depending on the config used, and some are optional. for maximum generality, i wanted to use wildcard matching (*_weight_* --> *weight*).
but with the current regex handling, the code was renaming the key with the literal "*weight*" rather than matching the regex. my changes extract the prefix and use it for the source/target. lmk if this is an ok approach or if there's a better solution
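a minimal sketch of this prefix-extraction idea, assuming source and target patterns contain the same number of * wildcards; the helper name and key formats below are illustrative, not the actual transformers implementation:

```python
import re

def map_wildcard_key(original_key: str, source_pattern: str, target_pattern: str):
    # turn each "*" in the source pattern into a non-greedy capture group,
    # e.g. "*._weight_*" becomes "^(.*?)\._weight_(.*?)$"
    regex = re.compile("^" + re.escape(source_pattern).replace(r"\*", "(.*?)") + "$")
    match = regex.match(original_key)
    if match is None:
        return None  # key does not belong to this converter
    # splice the captured text into the target pattern's wildcards, so the key
    # is renamed with its concrete prefix/suffix instead of a literal "*"
    pieces = target_pattern.split("*")
    out = pieces[0]
    for piece, group in zip(pieces[1:], match.groups()):
        out += group + piece
    return out

# illustrative key format:
# map_wildcard_key("layer1._weight_qdata", "*._weight_*", "*.weight:*")
# -> "layer1.weight:qdata"
```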
Force-pushed 80af667 to f83b74c
Force-pushed 5f770e1 to 0a4e634
@MekkCyber to avoid adding torchao-specific logic to the core model loader and unblock landing, i reverted the changes to the regex logic and added all the other tensor data names. in the future, maybe we can implement a strategy that generalizes this matching in a better way? cc @jerryzh168
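for reference, a minimal sketch of what the enumeration approach amounts to; only qdata and scale are named in this thread, so the mapping below is illustrative rather than the full set added in the PR:

```python
# illustrative only: enumerate the component keys explicitly instead of
# wildcard matching, mapping each onto the consolidated weight
COMPONENTS = ("qdata", "scale")
component_to_weight = {f"weight:{c}": "weight" for c in COMPONENTS}
# -> {"weight:qdata": "weight", "weight:scale": "weight"}
```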
```python
# check if the param_name is not in self.modules_to_not_convert
if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
    return False
elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
```
for this one, maybe just change it to (f"_{x}" for x in self.full_ao_keys) to be safe; at least with this we are confident it's correct, if there's no better general way to detect safetensors
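a self-contained sketch of the suffix check under discussion, assuming full_ao_keys holds component names like qdata and scale (the key names below are made up):

```python
full_ao_keys = ["qdata", "scale"]  # illustrative values

def is_ao_component(param_name: str) -> bool:
    # the diff matches ":{x}"; the suggestion above is "_{x}" instead,
    # depending on which separator the serialized keys actually use
    return any(param_name.endswith(f":{x}") for x in full_ao_keys)

assert is_ao_component("model.layers.0.mlp.up_proj.weight:qdata")
assert not is_ao_component("model.layers.0.mlp.up_proj.weight")
```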
i think this is achieving the same thing as L237, so we only need one check
Force-pushed b58354a to d59bb71
MekkCyber left a comment
Very nice alternative solution! Thanks a lot
```python
else:
    param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"]
for suffix in input_dict.keys():
    assert len(input_dict[suffix]) == 1
```
let's do an if/else and raise an error to follow the same pattern we use in transformers. but i'm not sure it's necessary, since if we have the suffix in input_dict it means we collected at least one tensor for that suffix
will add an if/else! we wanted to add this assert since there also shouldn't be more than one tensor per component (qdata, scale, etc.)
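a sketch of the agreed if/raise replacement for the assert, keeping the one-tensor-per-component invariant; variable names follow the diff above:

```python
for suffix, tensors in input_dict.items():
    # exactly one tensor should have been collected per component
    # (qdata, scale, ...), so raise instead of asserting
    if len(tensors) != 1:
        raise ValueError(
            f"expected exactly one tensor for component {suffix!r}, got {len(tensors)}"
        )
```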
Force-pushed d59bb71 to cc9070a
[For maintainers] Suggested jobs to run (before merge): run-slow: torchao_integration
Force-pushed cc9070a to 5b514cb
@MekkCyber @SunMarc can you help me merge this, thanks!
Summary
Updating the safetensors logic to support all tensor subclasses generically.

Test
Verified integration with vLLM.
Unit tests pass:

```
python tests/quantization/torchao_integration/test_torchao.py -k TorchAoSafeSerializationTest
```