Commit f83b74c

after rebase
1 parent a213c68 commit f83b74c

File tree: 5 files changed (+51 additions, -61 deletions)

src/transformers/core_model_loading.py (30 additions, 12 deletions)
```diff
@@ -69,7 +69,6 @@ def build_glob_alternation(
     """
     src_group_to_glob: dict[str, str] = {}
     tgt_group_to_glob: dict[str, str] = {}
-    group_to_pattern: dict[str, re.Pattern] = {}
    branches: list[str] = []
    i = 0
    for glob in globs:
```
```diff
@@ -78,12 +77,11 @@ def build_glob_alternation(
             group_name = f"g{i}"
             src_group_to_glob[group_name] = src
             i += 1
-            # Convert the glob pattern to a regex with capture groups for wildcards
-            pattern_with_captures = src.replace("*", r"(.*?)")
-            group_to_pattern[group_name] = re.compile(f"^{pattern_with_captures}$")
             body = src.replace("*", r".*")
             branches.append(f"(?P<{group_name}>{body})")
-            tgt_group_to_glob[group_name] = glob.target_keys[0] if isinstance(glob.target_keys, list) else glob.target_keys
+            tgt_group_to_glob[group_name] = (
+                glob.target_patterns[0] if isinstance(glob.target_patterns, list) else glob.target_patterns
+            )
         else:
             group_name = f"g{i}"
             src_group_to_glob[group_name] = glob
```
```diff
@@ -94,7 +92,7 @@ def build_glob_alternation(
             tgt_group_to_glob[group_name] = glob

     alternation = re.compile("|".join(branches))
-    return alternation, src_group_to_glob, tgt_group_to_glob, group_to_pattern
+    return alternation, src_group_to_glob, tgt_group_to_glob


 class ConversionOps:
```
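For reference, a minimal sketch of the alternation technique this function implements, with the `WeightConverter` handling and target bookkeeping stripped out (function name and the example key are illustrative, not from the library):

```python
import re


def build_alternation(globs: list[str]) -> tuple[re.Pattern, dict[str, str]]:
    """Compile many glob patterns into one regex with named groups.

    A single match then reveals *which* glob matched via the group name,
    instead of looping over every pattern per tensor key.
    """
    group_to_glob: dict[str, str] = {}
    branches: list[str] = []
    for i, glob in enumerate(globs):
        group_name = f"g{i}"
        group_to_glob[group_name] = glob
        body = glob.replace("*", r".*")  # glob wildcard -> regex wildcard
        branches.append(f"(?P<{group_name}>{body})")
    return re.compile("|".join(branches)), group_to_glob


alternation, groups = build_alternation(["*.block_sparse_moe.*", "*.mlp.*"])
m = alternation.match("model.layers.0.block_sparse_moe.gate.weight")
assert m is not None and groups[m.lastgroup] == "*.block_sparse_moe.*"
```

Compiling one alternation up front keeps the per-key cost to a single regex match, which is what the `tp_plan` and `dtype_plan` call sites further down rely on.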
```diff
@@ -336,7 +334,11 @@ def __post_init__(self):
         branches = []
         for i, source_pattern in enumerate(self.source_patterns):
             group_name = f"g{i}"
-            pattern = source_pattern.replace(".*.", r"\..*\.")
+            # support both glob-style (*) and regex-style (.*) wildcards
+            if "*" in source_pattern and ".*" not in source_pattern:
+                pattern = source_pattern.replace("*", r".*")
+            else:
+                pattern = source_pattern.replace(".*.", r"\..*\.")
             branches.append(f"(?P<{group_name}>{pattern})")
         self.compiled_sources = re.compile("|".join(branches))
```
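As a standalone illustration of the new branch, a sketch of the two wildcard styles it accepts (example patterns and keys are hypothetical):

```python
import re


def normalize_pattern(source_pattern: str) -> str:
    # Glob-style pattern ("*_weight_*"): every "*" becomes a regex wildcard.
    if "*" in source_pattern and ".*" not in source_pattern:
        return source_pattern.replace("*", r".*")
    # Regex-style pattern ("layers.*.mlp"): escape the dots around the wildcard.
    return source_pattern.replace(".*.", r"\..*\.")


assert re.fullmatch(normalize_pattern("*_weight_*"), "layer.0._weight_scale")
assert re.fullmatch(normalize_pattern("layers.*.mlp"), "layers.0.mlp")
```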

```diff
@@ -364,12 +366,20 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
         source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
         # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
         replacement = self.target_patterns[0]
-        # # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
-        if r"\1" in replacement:
+
+        if "*" in replacement and "*" in source_pattern_that_matched:
+            pattern_with_captures = source_pattern_that_matched.replace("*", r"(.*?)")
+            pattern_regex = re.compile(f"^{pattern_with_captures}$")
+            match = pattern_regex.match(source_key)
+            if match:
+                groups = match.groups()
+                replacement = replacement.replace("*", groups[0], 1)
+        elif r"\1" in replacement:
             # The index of the internal group we need to replace is the index of the matched named group as it comes
             # inside that matched named group
             replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1
             replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx))
+
         renamed_key = source_key.replace(match_object.group(0), replacement)

         return renamed_key, source_pattern_that_matched
```
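A standalone sketch of the wildcard substitution introduced here, assuming a single `*` on each side (the example key is hypothetical):

```python
import re


def substitute_wildcard(source_key: str, source_pattern: str, target_pattern: str) -> str:
    """Carry the text matched by '*' in the source pattern over to the target."""
    pattern_with_captures = source_pattern.replace("*", r"(.*?)")
    match = re.match(f"^{pattern_with_captures}$", source_key)
    if match is None:
        return target_pattern
    # Only the first "*" in the target is filled in, mirroring replace(..., 1).
    return target_pattern.replace("*", match.groups()[0], 1)


# "model.layers.0" is captured by "*" and re-inserted into the target key.
assert substitute_wildcard("model.layers.0._weight_", "*._weight_", "*.weight") == "model.layers.0.weight"
```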
```diff
@@ -805,9 +815,9 @@ def convert_and_load_state_dict_in_model(
     # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
     # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
     if tp_plan != {}:
-        tp_plan_alt, tp_plan_by_group_name, _, _ = build_glob_alternation(list(tp_plan.keys()))
+        tp_plan_alt, tp_plan_by_group_name, _ = build_glob_alternation(list(tp_plan.keys()))
     if dtype_plan != {}:
-        dtype_policy_alt, dtype_policy_by_group_name, _, _ = build_glob_alternation(list(dtype_plan.keys()))
+        dtype_policy_alt, dtype_policy_by_group_name, _ = build_glob_alternation(list(dtype_plan.keys()))

     pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns}
```
```diff
@@ -877,7 +887,15 @@ def convert_and_load_state_dict_in_model(
             param_device = "cpu" if param_device == "disk" else param_device
             future = spawn_materialize(thread_pool, tensor, param_device, _dtype)

-            mapping.add_tensor(renamed_key, original_key, source_pattern, future)
+            concrete_source_pattern = source_pattern
+            if isinstance(mapping, WeightConverter) and source_pattern is not None and "*" in source_pattern:
+                pattern_with_captures = source_pattern.replace("*", r"(.*?)")
+                pattern_regex = re.compile(f"^{pattern_with_captures}$")
+                concrete_source_pattern = extract_concrete_key_from_regex_pattern(
+                    original_key, source_pattern, pattern_regex
+                )
+
+            mapping.add_tensor(renamed_key, original_key, concrete_source_pattern, future)
         elif source_pattern is not None:  # add all target keys as unexpected
             mapping = pattern_to_converter[source_pattern]
             for k in mapping.target_patterns:
```
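`extract_concrete_key_from_regex_pattern` itself is not shown in this diff; below is a hypothetical stand-in consistent with how it is called here, assuming it pins each `*` in the pattern to the text it actually matched for this tensor (the real helper may differ):

```python
import re


def extract_concrete_key(original_key: str, source_pattern: str, pattern_regex: re.Pattern) -> str:
    """Hypothetical stand-in: replace each '*' in the pattern with the text it
    matched in original_key, yielding a concrete per-tensor key."""
    match = pattern_regex.match(original_key)
    if match is None:
        return source_pattern
    concrete = source_pattern
    for captured in match.groups():
        concrete = concrete.replace("*", captured, 1)
    return concrete


pattern = "*.block_sparse_moe.*"
regex = re.compile("^" + pattern.replace("*", r"(.*?)") + "$")
# Each wildcard is pinned to what it matched for this specific tensor.
assert extract_concrete_key("layers.0.block_sparse_moe.gate", pattern, regex) == "layers.0.block_sparse_moe.gate"
```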

src/transformers/integrations/torchao.py (9 additions, 8 deletions)
```diff
@@ -32,7 +32,7 @@

 if is_torchao_available():
     TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
-    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
         from torchao.prototype.safetensors.safetensors_support import (
             unflatten_tensor_state_dict,
         )
```
```diff
@@ -218,27 +218,28 @@ def convert(
         is_unsafe_serialization = "_weight_" not in list(input_dict.keys())[0]

         param_data = {}
-        layer_name = '.'.join(full_layer_name.split(".")[:-1])
+        layer_name = ".".join(full_layer_name.split(".")[:-1])
         if is_unsafe_serialization:
             if isinstance(input_dict["weight"], list):
                 weight = input_dict["weight"][0]
             else:
                 weight = input_dict["weight"]
         else:
             for suffix in input_dict.keys():
-                if isinstance(input_dict[suffix], list):
-                    param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]
-                else:
-                    param_data[f"{layer_name}.{suffix}"] = input_dict[suffix]
+                assert len(input_dict[suffix]) == 1
+                param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]

         # If it's unsafe-serialized (i.e. not safetensors), no need for anything
         if is_unsafe_serialization:
             return {full_layer_name: weight}
         # Sanity check for the new serialization format
         elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
-            raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+            raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.15.0` installed")

-        unflattened_state_dict, _ = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)
+        unflattened_state_dict, leftover_state_dict = unflatten_tensor_state_dict(
+            param_data, self.hf_quantizer.metadata
+        )
+        assert not leftover_state_dict  # there should be no unprocessed tensors
         new_param = unflattened_state_dict[full_layer_name]

         module, _ = get_module_from_name(model, full_layer_name)
```
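To make the key bookkeeping above concrete, a small sketch of the suffix handling (the example key is hypothetical):

```python
# The suffix logic mirrors `convert` above.
full_layer_name = "model.layers.0.self_attn.q_proj._weight_qdata"

# Strip the last dotted component to recover the owning layer's name.
layer_name = ".".join(full_layer_name.split(".")[:-1])
assert layer_name == "model.layers.0.self_attn.q_proj"

# Keys without a "_weight_" component signal the old, non-safetensors path.
is_unsafe_serialization = "_weight_" not in full_layer_name
assert not is_unsafe_serialization
```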

src/transformers/modeling_utils.py (0 additions, 13 deletions)
```diff
@@ -3946,19 +3946,6 @@ def from_pretrained(

         is_quantized = hf_quantizer is not None

-        weight_conversions: Optional[list[WeightConverter | WeightRenaming]] = None
-        model_type = getattr(config, "model_type", None)
-        if model_type is not None:
-            weight_conversions = get_checkpoint_conversion_mapping(model_type)
-        if weight_conversions is None:
-            weight_conversions = get_checkpoint_conversion_mapping("legacy")
-        if key_mapping is not None:
-            weight_conversions.extend(
-                [WeightRenaming(source_keys=k, target_keys=v) for k, v in key_mapping.items()]
-            )
-        if hf_quantizer is not None:
-            weight_conversions.extend(hf_quantizer.get_weight_conversions())
-
         if gguf_file:
             from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
```
src/transformers/quantizers/quantizer_torchao.py (11 additions, 27 deletions)
```diff
@@ -40,14 +40,11 @@
 import torch.nn as nn

 if is_torchao_available():
-    import torchao
-
-    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
         from torchao.prototype.safetensors.safetensors_support import (
             flatten_tensor_state_dict,
             unflatten_tensor_state_dict,
         )
-        from torchao.prototype.awq import AWQConfig
         from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao


```
```diff
@@ -89,16 +86,6 @@ def _linear_extra_repr(self):


 if is_torchao_available():
-    SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
-        torchao.quantization.Float8WeightOnlyConfig,
-        torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
-        torchao.quantization.Int4WeightOnlyConfig,
-        torchao.quantization.IntxWeightOnlyConfig,
-        torchao.quantization.Int8DynamicActivationIntxWeightConfig,
-        torchao.quantization.ModuleFqnToConfig,
-        AWQConfig,
-    ]
-
     TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))


```
```diff
@@ -177,12 +164,12 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool | None = F
         If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
         the safetensors format.
         """
-        if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
-            if TORCHAO_VERSION >= version.parse("0.14.0"):
+        if safe_serialization:
+            if TORCHAO_VERSION >= version.parse("0.15.0"):
                 return flatten_tensor_state_dict(model.state_dict())
             else:
                 raise RuntimeError(
-                    f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+                    f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
                 )
         else:
             return None, {}
```
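The save/load contract these two code paths assume, sketched from the call sites in this diff (the return shapes are inferred from how the results are unpacked, not from torchao documentation):

```python
from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)


def roundtrip(model):
    # Save path: tensor subclasses are flattened into plain "_weight_*" tensors
    # plus metadata, which is all the safetensors format can represent.
    flat_state_dict, metadata = flatten_tensor_state_dict(model.state_dict())

    # Load path: the metadata drives reconstruction; leftover entries would be
    # tensors the metadata did not account for, hence the assert in `convert`.
    unflattened, leftover = unflatten_tensor_state_dict(flat_state_dict, metadata)
    assert not leftover
    return unflattened
```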
```diff
@@ -314,8 +301,8 @@ def create_quantized_param(
             )
             return
         # Sanity check for the new serialization format
-        elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
-            raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+        elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.metadata)):
+            raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.15.0` installed")

         # Save the states for later quantization when they are all gathered
         if not hasattr(self, "ao_params"):
```
```diff
@@ -460,13 +447,10 @@ def _process_model_after_weight_loading(self, model, **kwargs):

     def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
-            _is_torchao_serializable = type(
-                self.quantization_config.quant_type
-            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
-            if not _is_torchao_serializable:
+            _is_torchao_serializable = TORCHAO_VERSION >= version.parse("0.15.0")
+            if not TORCHAO_VERSION >= version.parse("0.15.0"):
                 logger.warning(
-                    f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
-                    and torchao version >= 0.14.0, please set `safe_serialization` to False for \
+                    f"torchao quantized model only supports safe serialization for torchao version >= 0.15.0, please set `safe_serialization` to False for \
                     {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
                 )
             return _is_torchao_serializable
```
```diff
@@ -556,8 +540,8 @@ def get_weight_conversions(self):
         if self.pre_quantized:
             return [
                 WeightConverter(
-                    source_keys=["*_weight_*"],
-                    target_keys="*weight",
+                    source_patterns=["*_weight_*"],
+                    target_patterns="*weight",
                     operations=[TorchAoDeserialize(self)],
                 ),
             ]
```
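A small sketch of what this source/target pattern pair does to a flattened key, reusing the wildcard-capture mechanics from `rename_source_key` above (the example key is hypothetical):

```python
import re

# Every flattened "_weight_" component of a layer is routed back to that
# layer's "weight" parameter; "*" carries the layer prefix across.
source_pattern, target_pattern = "*_weight_*", "*weight"

regex = re.compile("^" + source_pattern.replace("*", "(.*?)") + "$")
m = regex.match("model.layers.0.self_attn.q_proj._weight_qdata")
renamed = target_pattern.replace("*", m.group(1), 1)
assert renamed == "model.layers.0.self_attn.q_proj.weight"
```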

tests/quantization/torchao_integration/test_torchao.py (1 addition, 1 deletion)
```diff
@@ -738,7 +738,7 @@ def test_serialization_expected_output(self):


 @require_torchao
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 class TorchAoSafeSerializationTest(TorchAoSerializationTest):
     # called only once for all test in this class
     @classmethod
```
