Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from compressed_tensors.utils import (
align_modules,
get_execution_device,
match_modules_set,
match_named_modules,
update_offload_parameter,
)
Expand Down Expand Up @@ -312,68 +313,76 @@ def _set_resolved_mappings(self, model: Module) -> None:
into ResolvedMapping objects, resolving regular expressions.
Result is stored in _resolved_mappings.

For each activation in the mapping list, we find the corresponding weight to
balance by searching for the longest substring. For instance, if our balance
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
repeat for model.layer.1 and so on
Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers)
that belong together in the model architecture.
"""
# Build a module-to-name mapping for efficient lookups
module_to_name = {module: name for name, module in model.named_modules()}

resolved_mappings: list[ResolvedMapping] = []
for mapping_idx, mapping in enumerate(self.mappings):
num_skipped_mappings = 0

for smooth_name, smooth_layer in (
pbar := tqdm(
match_named_modules(model, [mapping.smooth_layer], self.ignore)
)
# Use match_modules_set to find coherent sets of modules
target_patterns = (mapping.smooth_layer, *mapping.balance_layers)

for modules_set in (
pbar := tqdm(match_modules_set(model, target_patterns, self.ignore))
):
pbar.set_description(
f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
f" ({num_skipped_mappings} skipped)"
)

smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
smooth_parent = get_layer_by_name(smooth_parent_name, model)
# Unpack the matched set: first is smooth_layer, rest are balance_layers
smooth_layer = modules_set[0]
all_balance_layers = list(modules_set[1:])

balance_layers, balance_names = [], []
for balance_regex in mapping.balance_layers:
# find the submodules that match the activation layer
for balance_suffix, balance_layer in match_named_modules(
smooth_parent, [balance_regex], self.ignore
):
balance_name = f"{smooth_parent_name}.{balance_suffix}"

# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
if (
isinstance(smooth_layer, torch.nn.Linear)
and isinstance(balance_layer, torch.nn.Linear)
and balance_name.endswith(".o_proj")
and (
(
smooth_name.endswith(".v_proj")
and smooth_layer.out_features
!= balance_layer.in_features
)
or (
smooth_name.endswith(".qkv_proj")
and smooth_layer.out_features
!= 3 * balance_layer.in_features
)
# Get names using the pre-built mapping
smooth_name = module_to_name.get(smooth_layer)
if smooth_name is None:
continue

# Filter balance layers, skipping incompatible ones
balance_layers = []
balance_names = []

for balance_layer in all_balance_layers:
balance_name = module_to_name.get(balance_layer)
if balance_name is None:
continue

# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
if (
isinstance(smooth_layer, torch.nn.Linear)
and isinstance(balance_layer, torch.nn.Linear)
and balance_name.endswith(".o_proj")
and (
(
smooth_name.endswith(".v_proj")
and smooth_layer.out_features
!= balance_layer.in_features
)
or (
smooth_name.endswith(".qkv_proj")
and smooth_layer.out_features
!= 3 * balance_layer.in_features
)
):
num_skipped_mappings += 1
continue
)
):
num_skipped_mappings += 1
continue
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we replace and generalize this to

if any(smooth_layer.weight.size(0) % layer.weight.size(-1) != 0 for layer in balance_layers):
    continue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't that not work for the qkv_proj? I think mlp layers can do similar with gate_up_proj too.


balance_layers.append(balance_layer)
balance_names.append(balance_name)
balance_layers.append(balance_layer)
balance_names.append(balance_name)

if len(balance_layers) == 0:
continue

elif len(balance_layers) == 1:
if len(balance_layers) == 1:
# for single balance layer, parent is the balance layer
parent_name, parent = balance_name, balance_layer
parent_name, parent = balance_names[0], balance_layers[0]
else:
# for multiple balance layers, find lowest common parent
parent_name, parent = get_lowest_common_parent(balance_names, model)
Expand Down
Loading