Skip to content

Commit 32df8ec

Browse files
committed
after rebase
1 parent 661e342 commit 32df8ec

File tree

5 files changed

+32
-24
lines changed

5 files changed

+32
-24
lines changed

src/transformers/core_model_loading.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,13 @@ def build_glob_alternation(
7878
group_name = f"g{i}"
7979
src_group_to_glob[group_name] = src
8080
i += 1
81-
# Convert the glob pattern to a regex with capture groups for wildcards
8281
pattern_with_captures = src.replace("*", r"(.*?)")
8382
group_to_pattern[group_name] = re.compile(f"^{pattern_with_captures}$")
8483
body = src.replace("*", r".*")
8584
branches.append(f"(?P<{group_name}>{body})")
86-
tgt_group_to_glob[group_name] = glob.target_keys[0] if isinstance(glob.target_keys, list) else glob.target_keys
85+
tgt_group_to_glob[group_name] = (
86+
glob.target_patterns[0] if isinstance(glob.target_patterns, list) else glob.target_patterns
87+
)
8788
else:
8889
group_name = f"g{i}"
8990
src_group_to_glob[group_name] = glob
@@ -336,7 +337,11 @@ def __post_init__(self):
336337
branches = []
337338
for i, source_pattern in enumerate(self.source_patterns):
338339
group_name = f"g{i}"
339-
pattern = source_pattern.replace(".*.", r"\..*\.")
340+
# support both glob-style (*) and regex-style (.*) wildcards
341+
if "*" in source_pattern and ".*" not in source_pattern:
342+
pattern = source_pattern.replace("*", r".*")
343+
else:
344+
pattern = source_pattern.replace(".*.", r"\..*\.")
340345
branches.append(f"(?P<{group_name}>{pattern})")
341346
self.compiled_sources = re.compile("|".join(branches))
342347

@@ -364,12 +369,20 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
364369
source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
365370
# If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
366371
replacement = self.target_patterns[0]
367-
# # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
368-
if r"\1" in replacement:
372+
373+
if "*" in replacement and "*" in source_pattern_that_matched:
374+
pattern_with_captures = source_pattern_that_matched.replace("*", r"(.*?)")
375+
pattern_regex = re.compile(f"^{pattern_with_captures}$")
376+
match = pattern_regex.match(source_key)
377+
if match:
378+
groups = match.groups()
379+
replacement = replacement.replace("*", groups[0], 1)
380+
elif r"\1" in replacement:
369381
# The index of the internal group we need to replace is the index of the matched named group as it comes
370382
# inside that matched named group
371383
replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1
372384
replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx))
385+
373386
renamed_key = source_key.replace(match_object.group(0), replacement)
374387

375388
return renamed_key, source_pattern_that_matched
@@ -877,7 +890,15 @@ def convert_and_load_state_dict_in_model(
877890
param_device = "cpu" if param_device == "disk" else param_device
878891
future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
879892

880-
mapping.add_tensor(renamed_key, original_key, source_pattern, future)
893+
concrete_source_pattern = source_pattern
894+
if isinstance(mapping, WeightConverter) and source_pattern is not None and "*" in source_pattern:
895+
pattern_with_captures = source_pattern.replace("*", r"(.*?)")
896+
pattern_regex = re.compile(f"^{pattern_with_captures}$")
897+
concrete_source_pattern = extract_concrete_key_from_regex_pattern(
898+
original_key, source_pattern, pattern_regex
899+
)
900+
901+
mapping.add_tensor(renamed_key, original_key, concrete_source_pattern, future)
881902
elif source_pattern is not None: # add all target keys as unexpected
882903
mapping = pattern_to_converter[source_pattern]
883904
for k in mapping.target_patterns:

src/transformers/integrations/torchao.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def convert(
218218
is_unsafe_serialization = "_weight_" not in list(input_dict.keys())[0]
219219

220220
param_data = {}
221-
layer_name = '.'.join(full_layer_name.split(".")[:-1])
221+
layer_name = ".".join(full_layer_name.split(".")[:-1])
222222
if is_unsafe_serialization:
223223
if isinstance(input_dict["weight"], list):
224224
weight = input_dict["weight"][0]

src/transformers/modeling_utils.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3946,19 +3946,6 @@ def from_pretrained(
39463946

39473947
is_quantized = hf_quantizer is not None
39483948

3949-
weight_conversions: Optional[list[WeightConverter | WeightRenaming]] = None
3950-
model_type = getattr(config, "model_type", None)
3951-
if model_type is not None:
3952-
weight_conversions = get_checkpoint_conversion_mapping(model_type)
3953-
if weight_conversions is None:
3954-
weight_conversions = get_checkpoint_conversion_mapping("legacy")
3955-
if key_mapping is not None:
3956-
weight_conversions.extend(
3957-
[WeightRenaming(source_keys=k, target_keys=v) for k, v in key_mapping.items()]
3958-
)
3959-
if hf_quantizer is not None:
3960-
weight_conversions.extend(hf_quantizer.get_weight_conversions())
3961-
39623949
if gguf_file:
39633950
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
39643951

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343
import torchao
4444

4545
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
46+
from torchao.prototype.awq import AWQConfig
4647
from torchao.prototype.safetensors.safetensors_support import (
4748
flatten_tensor_state_dict,
4849
unflatten_tensor_state_dict,
4950
)
50-
from torchao.prototype.awq import AWQConfig
5151
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
5252

5353

@@ -556,8 +556,8 @@ def get_weight_conversions(self):
556556
if self.pre_quantized:
557557
return [
558558
WeightConverter(
559-
source_keys=["*_weight_*"],
560-
target_keys="*weight",
559+
source_patterns=["*_weight_*"],
560+
target_patterns="*weight",
561561
operations=[TorchAoDeserialize(self)],
562562
),
563563
]

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def test_serialization_expected_output(self):
738738

739739

740740
@require_torchao
741-
@require_torchao_version_greater_or_equal("0.14.0")
741+
@require_torchao_version_greater_or_equal("0.15.0")
742742
class TorchAoSafeSerializationTest(TorchAoSerializationTest):
743743
# called only once for all test in this class
744744
@classmethod

0 commit comments

Comments (0)