@@ -78,12 +78,13 @@ def build_glob_alternation(
7878 group_name = f"g{ i } "
7979 src_group_to_glob [group_name ] = src
8080 i += 1
81- # Convert the glob pattern to a regex with capture groups for wildcards
8281 pattern_with_captures = src .replace ("*" , r"(.*?)" )
8382 group_to_pattern [group_name ] = re .compile (f"^{ pattern_with_captures } $" )
8483 body = src .replace ("*" , r".*" )
8584 branches .append (f"(?P<{ group_name } >{ body } )" )
86- tgt_group_to_glob [group_name ] = glob .target_keys [0 ] if isinstance (glob .target_keys , list ) else glob .target_keys
85+ tgt_group_to_glob [group_name ] = (
86+ glob .target_patterns [0 ] if isinstance (glob .target_patterns , list ) else glob .target_patterns
87+ )
8788 else :
8889 group_name = f"g{ i } "
8990 src_group_to_glob [group_name ] = glob
@@ -336,7 +337,11 @@ def __post_init__(self):
336337 branches = []
337338 for i , source_pattern in enumerate (self .source_patterns ):
338339 group_name = f"g{ i } "
339- pattern = source_pattern .replace (".*." , r"\..*\." )
340+ # support both glob-style (*) and regex-style (.*) wildcards
341+ if "*" in source_pattern and ".*" not in source_pattern :
342+ pattern = source_pattern .replace ("*" , r".*" )
343+ else :
344+ pattern = source_pattern .replace (".*." , r"\..*\." )
340345 branches .append (f"(?P<{ group_name } >{ pattern } )" )
341346 self .compiled_sources = re .compile ("|" .join (branches ))
342347
@@ -364,12 +369,20 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
364369 source_pattern_that_matched = self .source_patterns [int (matching_group_name [1 :])]
365370 # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
366371 replacement = self .target_patterns [0 ]
367- # # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
368- if r"\1" in replacement :
372+
373+ if "*" in replacement and "*" in source_pattern_that_matched :
374+ pattern_with_captures = source_pattern_that_matched .replace ("*" , r"(.*?)" )
375+ pattern_regex = re .compile (f"^{ pattern_with_captures } $" )
376+ match = pattern_regex .match (source_key )
377+ if match :
378+ groups = match .groups ()
379+ replacement = replacement .replace ("*" , groups [0 ], 1 )
380+ elif r"\1" in replacement :
369381 # The index of the internal group we need to replace is the index of the matched named group as it comes
370382 # inside that matched named group
371383 replaced_group_idx = self .compiled_sources .groupindex [matching_group_name ] + 1
372384 replacement = replacement .replace (r"\1" , match_object .group (replaced_group_idx ))
385+
373386 renamed_key = source_key .replace (match_object .group (0 ), replacement )
374387
375388 return renamed_key , source_pattern_that_matched
@@ -877,7 +890,15 @@ def convert_and_load_state_dict_in_model(
877890 param_device = "cpu" if param_device == "disk" else param_device
878891 future = spawn_materialize (thread_pool , tensor , param_device , _dtype )
879892
880- mapping .add_tensor (renamed_key , original_key , source_pattern , future )
893+ concrete_source_pattern = source_pattern
894+ if isinstance (mapping , WeightConverter ) and source_pattern is not None and "*" in source_pattern :
895+ pattern_with_captures = source_pattern .replace ("*" , r"(.*?)" )
896+ pattern_regex = re .compile (f"^{ pattern_with_captures } $" )
897+ concrete_source_pattern = extract_concrete_key_from_regex_pattern (
898+ original_key , source_pattern , pattern_regex
899+ )
900+
901+ mapping .add_tensor (renamed_key , original_key , concrete_source_pattern , future )
881902 elif source_pattern is not None : # add all target keys as unexpected
882903 mapping = pattern_to_converter [source_pattern ]
883904 for k in mapping .target_patterns :
0 commit comments