@@ -69,7 +69,6 @@ def build_glob_alternation(
6969 """
7070 src_group_to_glob : dict [str , str ] = {}
7171 tgt_group_to_glob : dict [str , str ] = {}
72- group_to_pattern : dict [str , re .Pattern ] = {}
7372 branches : list [str ] = []
7473 i = 0
7574 for glob in globs :
@@ -78,12 +77,11 @@ def build_glob_alternation(
7877 group_name = f"g{ i } "
7978 src_group_to_glob [group_name ] = src
8079 i += 1
81- # Convert the glob pattern to a regex with capture groups for wildcards
82- pattern_with_captures = src .replace ("*" , r"(.*?)" )
83- group_to_pattern [group_name ] = re .compile (f"^{ pattern_with_captures } $" )
8480 body = src .replace ("*" , r".*" )
8581 branches .append (f"(?P<{ group_name } >{ body } )" )
86- tgt_group_to_glob [group_name ] = glob .target_keys [0 ] if isinstance (glob .target_keys , list ) else glob .target_keys
82+ tgt_group_to_glob [group_name ] = (
83+ glob .target_patterns [0 ] if isinstance (glob .target_patterns , list ) else glob .target_patterns
84+ )
8785 else :
8886 group_name = f"g{ i } "
8987 src_group_to_glob [group_name ] = glob
@@ -94,7 +92,7 @@ def build_glob_alternation(
9492 tgt_group_to_glob [group_name ] = glob
9593
9694 alternation = re .compile ("|" .join (branches ))
97- return alternation , src_group_to_glob , tgt_group_to_glob , group_to_pattern
95+ return alternation , src_group_to_glob , tgt_group_to_glob
9896
9997
10098class ConversionOps :
@@ -336,7 +334,11 @@ def __post_init__(self):
336334 branches = []
337335 for i , source_pattern in enumerate (self .source_patterns ):
338336 group_name = f"g{ i } "
339- pattern = source_pattern .replace (".*." , r"\..*\." )
337+ # support both glob-style (*) and regex-style (.*) wildcards
338+ if "*" in source_pattern and ".*" not in source_pattern :
339+ pattern = source_pattern .replace ("*" , r".*" )
340+ else :
341+ pattern = source_pattern .replace (".*." , r"\..*\." )
340342 branches .append (f"(?P<{ group_name } >{ pattern } )" )
341343 self .compiled_sources = re .compile ("|" .join (branches ))
342344
@@ -364,12 +366,20 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
364366 source_pattern_that_matched = self .source_patterns [int (matching_group_name [1 :])]
365367 # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
366368 replacement = self .target_patterns [0 ]
367- # # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
368- if r"\1" in replacement :
369+
370+ if "*" in replacement and "*" in source_pattern_that_matched :
371+ pattern_with_captures = source_pattern_that_matched .replace ("*" , r"(.*?)" )
372+ pattern_regex = re .compile (f"^{ pattern_with_captures } $" )
373+ match = pattern_regex .match (source_key )
374+ if match :
375+ groups = match .groups ()
376+ replacement = replacement .replace ("*" , groups [0 ], 1 )
377+ elif r"\1" in replacement :
369378 # The index of the internal group we need to replace is the index of the matched named group as it comes
370379 # inside that matched named group
371380 replaced_group_idx = self .compiled_sources .groupindex [matching_group_name ] + 1
372381 replacement = replacement .replace (r"\1" , match_object .group (replaced_group_idx ))
382+
373383 renamed_key = source_key .replace (match_object .group (0 ), replacement )
374384
375385 return renamed_key , source_pattern_that_matched
@@ -805,9 +815,9 @@ def convert_and_load_state_dict_in_model(
805815 # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
806816 # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
807817 if tp_plan != {}:
808- tp_plan_alt , tp_plan_by_group_name , _ , _ = build_glob_alternation (list (tp_plan .keys ()))
818+ tp_plan_alt , tp_plan_by_group_name , _ = build_glob_alternation (list (tp_plan .keys ()))
809819 if dtype_plan != {}:
810- dtype_policy_alt , dtype_policy_by_group_name , _ , _ = build_glob_alternation (list (dtype_plan .keys ()))
820+ dtype_policy_alt , dtype_policy_by_group_name , _ = build_glob_alternation (list (dtype_plan .keys ()))
811821
812822 pattern_to_converter = {k : converter for converter in converters for k in converter .source_patterns }
813823
@@ -877,7 +887,15 @@ def convert_and_load_state_dict_in_model(
877887 param_device = "cpu" if param_device == "disk" else param_device
878888 future = spawn_materialize (thread_pool , tensor , param_device , _dtype )
879889
880- mapping .add_tensor (renamed_key , original_key , source_pattern , future )
890+ concrete_source_pattern = source_pattern
891+ if isinstance (mapping , WeightConverter ) and source_pattern is not None and "*" in source_pattern :
892+ pattern_with_captures = source_pattern .replace ("*" , r"(.*?)" )
893+ pattern_regex = re .compile (f"^{ pattern_with_captures } $" )
894+ concrete_source_pattern = extract_concrete_key_from_regex_pattern (
895+ original_key , source_pattern , pattern_regex
896+ )
897+
898+ mapping .add_tensor (renamed_key , original_key , concrete_source_pattern , future )
881899 elif source_pattern is not None : # add all target keys as unexpected
882900 mapping = pattern_to_converter [source_pattern ]
883901 for k in mapping .target_patterns :
0 commit comments