From 69cf95b353dac8272c0566d5fd8b6f950a4f5f7b Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Nov 2025 20:57:13 +0000 Subject: [PATCH 1/5] [AWQ] small refactor to use match_modules_set Summary: modified _set_resolved_mappings to get smoothing and balance layers at same time. Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 93 ++++++++++++++----------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 98e53b4e0..3396b5cde 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -7,6 +7,7 @@ from compressed_tensors.utils import ( align_modules, get_execution_device, + match_modules_set, match_named_modules, update_offload_parameter, ) @@ -312,19 +313,22 @@ def _set_resolved_mappings(self, model: Module) -> None: into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. - For each activation in the mapping list, we find the corresponding weight to - balance by searching for the longest substring. For instance, if our balance - weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we - would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and - repeat for model.layer.1 and so on + Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers) + that belong together in the model architecture. """ + # Build a module-to-name mapping for efficient lookups + module_to_name = {module: name for name, module in model.named_modules()} + resolved_mappings: list[ResolvedMapping] = [] for mapping_idx, mapping in enumerate(self.mappings): num_skipped_mappings = 0 - for smooth_name, smooth_layer in ( + # Use match_modules_set to find coherent sets of modules + target_patterns = (mapping.smooth_layer, *mapping.balance_layers) + + for modules_set in ( pbar := tqdm( - match_named_modules(model, [mapping.smooth_layer], self.ignore) + match_modules_set(model, target_patterns, self.ignore) ) ): pbar.set_description( @@ -332,48 +336,55 @@ def _set_resolved_mappings(self, model: Module) -> None: f" ({num_skipped_mappings} skipped)" ) - smooth_parent_name = ".".join(smooth_name.split(".")[:-1]) - smooth_parent = get_layer_by_name(smooth_parent_name, model) + # Unpack the matched set: first is smooth_layer, rest are balance_layers + smooth_layer = modules_set[0] + all_balance_layers = list(modules_set[1:]) - balance_layers, balance_names = [], [] - for balance_regex in mapping.balance_layers: - # find the submodules that match the activation layer - for balance_suffix, balance_layer in match_named_modules( - smooth_parent, [balance_regex], self.ignore - ): - balance_name = f"{smooth_parent_name}.{balance_suffix}" - - # exclude v_proj->o_proj mappings whose shapes are incompatible - # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 - if ( - isinstance(smooth_layer, torch.nn.Linear) - and isinstance(balance_layer, torch.nn.Linear) - and balance_name.endswith(".o_proj") - and ( - ( - smooth_name.endswith(".v_proj") - and smooth_layer.out_features - != balance_layer.in_features - ) - or ( - smooth_name.endswith(".qkv_proj") - and smooth_layer.out_features - != 3 * balance_layer.in_features - ) + # Get names using the pre-built mapping + smooth_name = module_to_name.get(smooth_layer) + if smooth_name is None: + continue + + # Filter balance layers, skipping incompatible ones + balance_layers = [] + balance_names = [] + + for balance_layer in all_balance_layers: + balance_name = module_to_name.get(balance_layer) + if balance_name is None: + continue + + # exclude v_proj->o_proj mappings whose shapes are incompatible + # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 + if ( + isinstance(smooth_layer, torch.nn.Linear) + and isinstance(balance_layer, torch.nn.Linear) + and balance_name.endswith(".o_proj") + and ( + ( + smooth_name.endswith(".v_proj") + and smooth_layer.out_features + != balance_layer.in_features + ) + or ( + smooth_name.endswith(".qkv_proj") + and smooth_layer.out_features + != 3 * balance_layer.in_features ) - ): - num_skipped_mappings += 1 - continue + ) + ): + num_skipped_mappings += 1 + continue - balance_layers.append(balance_layer) - balance_names.append(balance_name) + balance_layers.append(balance_layer) + balance_names.append(balance_name) if len(balance_layers) == 0: continue - elif len(balance_layers) == 1: + if len(balance_layers) == 1: # for single balance layer, parent is the balance layer - parent_name, parent = balance_name, balance_layer + parent_name, parent = balance_names[0], balance_layers[0] else: # for multiple balance layers, find lowest common parent parent_name, parent = get_lowest_common_parent(balance_names, model) From dc8e3aeab4099713d78f830e206da00a1b4894a5 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Nov 2025 21:07:36 +0000 Subject: [PATCH 2/5] formatting Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 3396b5cde..e52ae54b6 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -327,9 +327,7 @@ def _set_resolved_mappings(self, model: Module) -> None: target_patterns = (mapping.smooth_layer, *mapping.balance_layers) for modules_set in ( - pbar := tqdm( - match_modules_set(model, target_patterns, self.ignore) - ) + pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) ): pbar.set_description( f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}" From 351568d913ebac61a52a45721007e1b5b5cea677 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 26 Nov 2025 18:09:44 +0000 Subject: [PATCH 3/5] fixing logic and test update Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 96 +++++++++---------- .../llmcompressor/modifiers/awq/test_base.py | 15 ++- 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index e52ae54b6..b5643d240 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -313,20 +313,21 @@ def _set_resolved_mappings(self, model: Module) -> None: into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. - Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers) - that belong together in the model architecture. + For each activation in the mapping list, we find the corresponding weight to + balance by searching for the longest substring. For instance, if our balance + weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we + would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and + repeat for model.layer.1 and so on """ - # Build a module-to-name mapping for efficient lookups - module_to_name = {module: name for name, module in model.named_modules()} - resolved_mappings: list[ResolvedMapping] = [] + module_to_name = {module: name for name, module in model.named_modules()} for mapping_idx, mapping in enumerate(self.mappings): num_skipped_mappings = 0 # Use match_modules_set to find coherent sets of modules target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - for modules_set in ( + for smooth_layer, *balance_layers in ( pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) ): pbar.set_description( @@ -334,53 +335,21 @@ def _set_resolved_mappings(self, model: Module) -> None: f" ({num_skipped_mappings} skipped)" ) - # Unpack the matched set: first is smooth_layer, rest are balance_layers - smooth_layer = modules_set[0] - all_balance_layers = list(modules_set[1:]) - - # Get names using the pre-built mapping smooth_name = module_to_name.get(smooth_layer) - if smooth_name is None: - continue + balance_names = [ + module_to_name.get(balance_layer) + for balance_layer in balance_layers + ] - # Filter balance layers, skipping incompatible ones - balance_layers = [] - balance_names = [] - - for balance_layer in all_balance_layers: - balance_name = module_to_name.get(balance_layer) - if balance_name is None: - continue - - # exclude v_proj->o_proj mappings whose shapes are incompatible - # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 - if ( - isinstance(smooth_layer, torch.nn.Linear) - and isinstance(balance_layer, torch.nn.Linear) - and balance_name.endswith(".o_proj") - and ( - ( - smooth_name.endswith(".v_proj") - and smooth_layer.out_features - != balance_layer.in_features - ) - or ( - smooth_name.endswith(".qkv_proj") - and smooth_layer.out_features - != 3 * balance_layer.in_features - ) - ) - ): - num_skipped_mappings += 1 - continue - - balance_layers.append(balance_layer) - balance_names.append(balance_name) + all_compatible = _check_layers_are_compatible( + smooth_layer, smooth_name, balance_layers, balance_names + ) - if len(balance_layers) == 0: + # skip mapping if any of the balance layers are incompatible + if not all_compatible or len(balance_layers) == 0: + num_skipped_mappings += 1 continue - - if len(balance_layers) == 1: + elif len(balance_layers) == 1: # for single balance layer, parent is the balance layer parent_name, parent = balance_names[0], balance_layers[0] else: @@ -730,6 +699,35 @@ def _assert_all_activations_consumed(self): raise RuntimeError("Some cached activations were not used") +def _check_layers_are_compatible( + smooth_layer, smooth_name, balance_layers, balance_names +): + """ + returns True if they are all compatible + returns False if any smooth & balance layers are incompatible + """ + for balance_layer, balance_name in zip(balance_layers, balance_names): + # exclude v_proj->o_proj mappings whose shapes are incompatible + # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 + if ( + isinstance(smooth_layer, torch.nn.Linear) + and isinstance(balance_layer, torch.nn.Linear) + and balance_name.endswith(".o_proj") + and ( + ( + smooth_name.endswith(".v_proj") + and smooth_layer.out_features != balance_layer.in_features + ) + or ( + smooth_name.endswith(".qkv_proj") + and smooth_layer.out_features != 3 * balance_layer.in_features + ) + ) + ): + return False + return True + + def _pseudo_quantize_tensor( w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 ): diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 950ab0f51..83eaaa970 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -85,10 +85,12 @@ def test_set_resolved_mappings(): assert set(mapping.balance_names) == {"decoder.mlp.down_proj"} assert mapping.parent_name == "decoder.mlp.down_proj" - # make sure we exclude case where o_proj/v_proj shapes are mismatched awq = AWQModifier( mappings=[ + # make sure we exclude case where o_proj/v_proj shapes are mismatched AWQMapping("re:.*v_proj", ["re:.*o_proj"]), + # make sure we exclude mapping if any balance layers are skipped + AWQMapping("re:.*v_proj", ["re:.*z_proj", "re:.*o_proj"]), ], scheme="W4A16_ASYM", ) @@ -101,6 +103,7 @@ def test_set_resolved_mappings(): "q_proj": torch.nn.Linear(4, 2), "k_proj": torch.nn.Linear(4, 2), "v_proj": torch.nn.Linear(4, 2), + "z_proj": torch.nn.Linear(2, 4), "o_proj": torch.nn.Linear(4, 4), } ) @@ -109,6 +112,16 @@ def test_set_resolved_mappings(): } ) awq._set_resolved_mappings(model) + if len(awq._resolved_mappings) > 0: + assert all( + "o_proj" not in name for name in awq._resolved_mappings[0].balance_names + ), "should have skipped v->o mapping because o is incompatible" + assert all( + "z_proj" not in name for name in awq._resolved_mappings[0].balance_names + ), ( + "should have skipped v->[z,o] mapping because o is incompatible even though" + "z is compatible" + ) assert len(awq._resolved_mappings) == 0 From dea5eabbdbd72a3da23f7690a7bff4d2df6d25cf Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 27 Nov 2025 03:08:53 +0000 Subject: [PATCH 4/5] updates to get_lowest_common_x Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 67 ++++++++++++------- .../llmcompressor/modifiers/awq/test_base.py | 65 +++++++++++------- 2 files changed, 84 insertions(+), 48 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index b5643d240..8ae2c41e9 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -320,21 +320,26 @@ def _set_resolved_mappings(self, model: Module) -> None: repeat for model.layer.1 and so on """ resolved_mappings: list[ResolvedMapping] = [] - module_to_name = {module: name for name, module in model.named_modules()} - for mapping_idx, mapping in enumerate(self.mappings): - num_skipped_mappings = 0 + + module_to_name = {} + for name, module in model.named_modules(): + if module in module_to_name: + logger.info( + f"Warning, {name} and {module_to_name[module]} both " + "share the same module the same module, " + "may have trouble resolving mappings." + ) + module_to_name[module] = name + + + + for mapping in self.mappings: - # Use match_modules_set to find coherent sets of modules target_patterns = (mapping.smooth_layer, *mapping.balance_layers) for smooth_layer, *balance_layers in ( - pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) + match_modules_set(model, target_patterns, self.ignore) ): - pbar.set_description( - f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}" - f" ({num_skipped_mappings} skipped)" - ) - smooth_name = module_to_name.get(smooth_layer) balance_names = [ module_to_name.get(balance_layer) @@ -347,14 +352,18 @@ def _set_resolved_mappings(self, model: Module) -> None: # skip mapping if any of the balance layers are incompatible if not all_compatible or len(balance_layers) == 0: - num_skipped_mappings += 1 + logger.info( + f"skipping AWQ for {smooth_name} for mapping {mapping}" + ( + " because found incompatible balance layers" + if not all_compatible else + f" because no balance layers were found" + ) + ) + continue - elif len(balance_layers) == 1: - # for single balance layer, parent is the balance layer - parent_name, parent = balance_names[0], balance_layers[0] else: # for multiple balance layers, find lowest common parent - parent_name, parent = get_lowest_common_parent(balance_names, model) + parent_name, parent = get_lowest_common_module(balance_names, model) resolved_mappings.append( ResolvedMapping( @@ -788,29 +797,41 @@ def _accumulate_mean( return (prev_sum + sum_added) / new_count, new_count -def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]: +def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]: """ - Given a list of names, returns the lowest-scope common parent. + Given a list of names, returns the lowest-scope common module. - NOTE: function excludes parents of type ModuleList, which don't play + NOTE: function excludes modules of type ModuleList, which don't play nicely with hooks because their forward method is never directly called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts are selected based on router output and their forward method is called. https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 - Returns name of parent and pointer to parent module + Returns name of module and pointer to module Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - s1 = min(names) - s2 = max(names) - parent_name = "" + # adding "." before and after allows for handling a lot of corner + # cases which were previously mishandled ([case]->prefix->result) + # case 0: single module: [.abc.] -> .abc. -> abc + # case 1: substring modules: [.abc., .ab.] -> .ab -> "" + # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab + s1 = min(names) + "." + s2 = max(names) + "." + + # 1) find longest shared prefix + parent_name = "." for i, c in enumerate(s1): if c != s2[i]: - parent_name = s1[:i].rstrip(".") break + parent_name += c + + # 2) throw away module name fragment and leading dot + # ".keep.thro" -> "keep" + parent_name = parent_name[1:parent_name.rfind(".")] + # 3) return first parent that is not a module list while True: if parent_name == "": return "", module diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 83eaaa970..2cb78fb65 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -2,9 +2,9 @@ import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme from pydantic import ValidationError - +from torch.nn import Linear from llmcompressor.modifiers.awq import AWQMapping, AWQModifier -from llmcompressor.modifiers.awq.base import get_lowest_common_parent +from llmcompressor.modifiers.awq.base import get_lowest_common_module from llmcompressor.modifiers.factory import ModifierFactory @@ -40,16 +40,16 @@ def test_set_resolved_mappings(): ) self_attn = torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 4), - "k_proj": torch.nn.Linear(4, 4), - "v_proj": torch.nn.Linear(4, 4), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 4), + "k_proj": Linear(4, 4), + "v_proj": Linear(4, 4), + "o_proj": Linear(4, 4), } ) mlp = torch.nn.ModuleDict( { - "up_proj": torch.nn.Linear(4, 10), - "down_proj": torch.nn.Linear(10, 4), + "up_proj": Linear(4, 10), + "down_proj": Linear(10, 4), } ) model = torch.nn.ModuleDict( @@ -100,11 +100,11 @@ def test_set_resolved_mappings(): { "self_attn": torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 2), - "k_proj": torch.nn.Linear(4, 2), - "v_proj": torch.nn.Linear(4, 2), - "z_proj": torch.nn.Linear(2, 4), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 2), + "k_proj": Linear(4, 2), + "v_proj": Linear(4, 2), + "z_proj": Linear(2, 4), + "o_proj": Linear(4, 4), } ) } @@ -192,15 +192,15 @@ def test_validate(): @pytest.mark.unit -def test_get_lowest_common_parent(): +def test_get_lowest_common_module(): mlp = torch.nn.ModuleDict( { "experts": torch.nn.ModuleList( [ torch.nn.ModuleDict( { - "gate_proj": torch.nn.Linear(4, 2), - "down_proj": torch.nn.Linear(4, 2), + "gate_proj": Linear(4, 2), + "down_proj": Linear(4, 2), } ) for _ in range(10) @@ -210,15 +210,15 @@ def test_get_lowest_common_parent(): ) self_attn = torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 2), - "k_proj": torch.nn.Linear(4, 2), - "v_proj": torch.nn.Linear(4, 2), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 2), + "k_proj": Linear(4, 2), + "v_proj": Linear(4, 2), + "o_proj": Linear(4, 4), } ) model = torch.nn.ModuleDict( { - "embed_tokens": torch.nn.Linear(4, 2), + "embed_tokens": Linear(4, 2), "decoder": torch.nn.ModuleDict( { "self_attn": self_attn, @@ -228,22 +228,37 @@ def test_get_lowest_common_parent(): } ) - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model ) assert parent_name == "decoder.mlp" and parent == mlp - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model ) assert parent_name == "decoder.self_attn" and parent == self_attn - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model ) assert parent_name == "decoder" and parent == model["decoder"] - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["embed_tokens", "decoder.self_attn.v_proj"], model ) assert parent_name == "" and parent == model + + m = torch.nn.ModuleDict( + { + "abc": Linear(3,3), + "ab": torch.nn.ModuleDict({"a": Linear(3,3)}), + "z": Linear(3,3) + } + ) + parent_name, parent = get_lowest_common_module(["abc", "ab"], m) + assert parent_name == "" + parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m) + assert parent_name == "ab" + parent_name, parent = get_lowest_common_module(["z"], m) + assert parent_name == "z" + From 728b8c0803232784f653da4b09a1bedf5b403bf9 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 27 Nov 2025 03:26:33 +0000 Subject: [PATCH 5/5] format Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 24 +++++++++---------- .../llmcompressor/modifiers/awq/test_base.py | 8 +++---- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 8ae2c41e9..420dadf2b 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -320,7 +320,7 @@ def _set_resolved_mappings(self, model: Module) -> None: repeat for model.layer.1 and so on """ resolved_mappings: list[ResolvedMapping] = [] - + module_to_name = {} for name, module in model.named_modules(): if module in module_to_name: @@ -331,14 +331,11 @@ def _set_resolved_mappings(self, model: Module) -> None: ) module_to_name[module] = name - - for mapping in self.mappings: - target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - for smooth_layer, *balance_layers in ( - match_modules_set(model, target_patterns, self.ignore) + for smooth_layer, *balance_layers in match_modules_set( + model, target_patterns, self.ignore ): smooth_name = module_to_name.get(smooth_layer) balance_names = [ @@ -353,10 +350,11 @@ def _set_resolved_mappings(self, model: Module) -> None: # skip mapping if any of the balance layers are incompatible if not all_compatible or len(balance_layers) == 0: logger.info( - f"skipping AWQ for {smooth_name} for mapping {mapping}" + ( - " because found incompatible balance layers" - if not all_compatible else - f" because no balance layers were found" + f"skipping AWQ for {smooth_name} for mapping {mapping}" + + ( + " because found incompatible balance layers" + if not all_compatible + else " because no balance layers were found" ) ) @@ -812,7 +810,7 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - # adding "." before and after allows for handling a lot of corner + # adding "." before and after allows for handling a lot of corner # cases which were previously mishandled ([case]->prefix->result) # case 0: single module: [.abc.] -> .abc. -> abc # case 1: substring modules: [.abc., .ab.] -> .ab -> "" @@ -829,9 +827,9 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod # 2) throw away module name fragment and leading dot # ".keep.thro" -> "keep" - parent_name = parent_name[1:parent_name.rfind(".")] + parent_name = parent_name[1 : parent_name.rfind(".")] - # 3) return first parent that is not a module list + # 3) return first common module that is not a module list while True: if parent_name == "": return "", module diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 2cb78fb65..e8103f9e3 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -3,6 +3,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme from pydantic import ValidationError from torch.nn import Linear + from llmcompressor.modifiers.awq import AWQMapping, AWQModifier from llmcompressor.modifiers.awq.base import get_lowest_common_module from llmcompressor.modifiers.factory import ModifierFactory @@ -250,9 +251,9 @@ def test_get_lowest_common_module(): m = torch.nn.ModuleDict( { - "abc": Linear(3,3), - "ab": torch.nn.ModuleDict({"a": Linear(3,3)}), - "z": Linear(3,3) + "abc": Linear(3, 3), + "ab": torch.nn.ModuleDict({"a": Linear(3, 3)}), + "z": Linear(3, 3), } ) parent_name, parent = get_lowest_common_module(["abc", "ab"], m) @@ -261,4 +262,3 @@ def test_get_lowest_common_module(): assert parent_name == "ab" parent_name, parent = get_lowest_common_module(["z"], m) assert parent_name == "z" -