140 changes: 83 additions & 57 deletions src/llmcompressor/modifiers/awq/base.py
@@ -7,6 +7,7 @@
from compressed_tensors.utils import (
align_modules,
get_execution_device,
match_modules_set,
match_named_modules,
update_offload_parameter,
)
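
For readers unfamiliar with the new import: this PR consumes `match_modules_set` as an iterator that yields one tuple of modules per transformer block, ordered like the patterns (smooth layer first, then each balance layer). Below is a minimal sketch of that contract, assuming plain PyTorch and a simplified, siblings-only stand-in named `iter_module_sets` (a hypothetical helper, not the real compressed_tensors API):

```python
import re
from torch.nn import Linear, Module, ModuleDict


def iter_module_sets(model: Module, patterns: tuple[str, ...]):
    """Yield one tuple of matched modules per parent, ordered like `patterns`.

    Simplified stand-in for match_modules_set: assumes all matches of a set
    are siblings under one parent, which the real utility does not require.
    """
    groups: dict[str, list] = {}
    for name, module in model.named_modules():
        for slot, pattern in enumerate(patterns):
            if re.fullmatch(pattern.removeprefix("re:"), name):
                parent = name.rsplit(".", 1)[0]
                groups.setdefault(parent, [None] * len(patterns))[slot] = module
    for members in groups.values():
        if all(m is not None for m in members):  # only yield complete sets
            yield tuple(members)


model = ModuleDict(
    {"attn": ModuleDict({"v_proj": Linear(4, 4), "o_proj": Linear(4, 4)})}
)
(smooth, balance), = iter_module_sets(model, ("re:.*v_proj", "re:.*o_proj"))
assert smooth is model["attn"]["v_proj"] and balance is model["attn"]["o_proj"]
```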
@@ -319,64 +320,48 @@ def _set_resolved_mappings(self, model: Module) -> None:
repeat for model.layer.1 and so on
"""
resolved_mappings: list[ResolvedMapping] = []
for mapping_idx, mapping in enumerate(self.mappings):
num_skipped_mappings = 0

for smooth_name, smooth_layer in (
pbar := tqdm(
match_named_modules(model, [mapping.smooth_layer], self.ignore)
module_to_name = {}
for name, module in model.named_modules():
if module in module_to_name:
logger.info(
f"Warning, {name} and {module_to_name[module]} both "
"share the same module the same module, "
"may have trouble resolving mappings."
)
module_to_name[module] = name

for mapping in self.mappings:
target_patterns = (mapping.smooth_layer, *mapping.balance_layers)

for smooth_layer, *balance_layers in match_modules_set(
model, target_patterns, self.ignore
):
pbar.set_description(
f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
f" ({num_skipped_mappings} skipped)"
smooth_name = module_to_name.get(smooth_layer)
balance_names = [
module_to_name.get(balance_layer)
for balance_layer in balance_layers
]

all_compatible = _check_layers_are_compatible(
smooth_layer, smooth_name, balance_layers, balance_names
)

smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
smooth_parent = get_layer_by_name(smooth_parent_name, model)

balance_layers, balance_names = [], []
for balance_regex in mapping.balance_layers:
# find the submodules that match the activation layer
for balance_suffix, balance_layer in match_named_modules(
smooth_parent, [balance_regex], self.ignore
):
balance_name = f"{smooth_parent_name}.{balance_suffix}"

# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
if (
isinstance(smooth_layer, torch.nn.Linear)
and isinstance(balance_layer, torch.nn.Linear)
and balance_name.endswith(".o_proj")
and (
(
smooth_name.endswith(".v_proj")
and smooth_layer.out_features
!= balance_layer.in_features
)
or (
smooth_name.endswith(".qkv_proj")
and smooth_layer.out_features
!= 3 * balance_layer.in_features
)
)
):
num_skipped_mappings += 1
continue

balance_layers.append(balance_layer)
balance_names.append(balance_name)
# skip mapping if any of the balance layers are incompatible
if not all_compatible or len(balance_layers) == 0:
logger.info(
f"skipping AWQ for {smooth_name} for mapping {mapping}"
+ (
" because found incompatible balance layers"
if not all_compatible
else " because no balance layers were found"
)
)

if len(balance_layers) == 0:
continue

elif len(balance_layers) == 1:
# for single balance layer, parent is the balance layer
parent_name, parent = balance_name, balance_layer
else:
# for multiple balance layers, find lowest common parent
parent_name, parent = get_lowest_common_parent(balance_names, model)
parent_name, parent = get_lowest_common_module(balance_names, model)

resolved_mappings.append(
ResolvedMapping(
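
Since `match_modules_set` yields modules rather than names, the new code recovers names through the reverse `module_to_name` map built up front. A small self-contained sketch of that map and of the shared-module warning (plain PyTorch only; note that `named_modules()` deduplicates shared modules by default, so this toy passes `remove_duplicate=False` to make the collision visible):

```python
import torch

shared = torch.nn.Linear(4, 4)
model = torch.nn.ModuleDict(
    {"attn": torch.nn.ModuleDict({"v_proj": shared}), "alias": shared}
)

module_to_name: dict[torch.nn.Module, str] = {}
for name, module in model.named_modules(remove_duplicate=False):
    if module in module_to_name:
        # mirrors the logger.info warning in _set_resolved_mappings
        print(f"Warning: {name} and {module_to_name[module]} share one module")
    module_to_name[module] = name

# lookups resolve to the most recently seen name for a shared module
assert module_to_name[shared] == "alias"
```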
@@ -721,6 +706,35 @@ def _assert_all_activations_consumed(self):
raise RuntimeError("Some cached activations were not used")


def _check_layers_are_compatible(
smooth_layer, smooth_name, balance_layers, balance_names
):
"""
Returns True if the smooth layer is shape-compatible with every balance
layer; returns False if any smooth & balance pair is incompatible.
"""
for balance_layer, balance_name in zip(balance_layers, balance_names):
# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
if (
isinstance(smooth_layer, torch.nn.Linear)
and isinstance(balance_layer, torch.nn.Linear)
and balance_name.endswith(".o_proj")
and (
(
smooth_name.endswith(".v_proj")
and smooth_layer.out_features != balance_layer.in_features
)
or (
smooth_name.endswith(".qkv_proj")
and smooth_layer.out_features != 3 * balance_layer.in_features
)
)
):
return False
return True


def _pseudo_quantize_tensor(
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
):
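
The extracted helper makes the v_proj -> o_proj exclusion testable in isolation. Here is a compact runnable restatement of the shape rule (the `shapes_compatible` name and pairwise signature are illustrative; the helper above takes lists): a v_proj can only be smoothed into an o_proj whose input it feeds, and a fused qkv_proj must supply three projections' worth of features.

```python
from torch.nn import Linear


def shapes_compatible(smooth: Linear, smooth_name: str,
                      balance: Linear, balance_name: str) -> bool:
    """Pairwise version of the v/qkv -> o_proj shape rule above."""
    if balance_name.endswith(".o_proj"):
        if smooth_name.endswith(".v_proj"):
            return smooth.out_features == balance.in_features
        if smooth_name.endswith(".qkv_proj"):
            return smooth.out_features == 3 * balance.in_features
    return True


# GQA-style attention: v_proj emits 2 features but o_proj consumes 4 -> skip
assert not shapes_compatible(Linear(4, 2), "attn.v_proj", Linear(4, 4), "attn.o_proj")
# fused qkv projection: 12 == 3 * 4 -> compatible
assert shapes_compatible(Linear(4, 12), "attn.qkv_proj", Linear(4, 4), "attn.o_proj")
```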
@@ -781,29 +795,41 @@ def _accumulate_mean(
return (prev_sum + sum_added) / new_count, new_count


def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]:
"""
Given a list of names, returns the lowest-scope common parent.
Given a list of names, returns the lowest-scope common module.

NOTE: function excludes parents of type ModuleList, which don't play
NOTE: function excludes modules of type ModuleList, which don't play
nicely with hooks because their forward method is never directly
called. See Qwen3MoeSparseMoeBlock for an example: experts are selected
based on router output and their forward methods are called directly.
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

Returns name of parent and pointer to parent module
Returns name of module and pointer to module

Implementation is a small alteration of os.path.commonprefix
https://docs.python.org/3/library/os.path.html#os.path.commonprefix
"""
s1 = min(names)
s2 = max(names)
parent_name = ""
# adding "." before and after allows for handling a lot of corner
# cases which were previously mishandled ([case]->prefix->result)
# case 0: single module: [.abc.] -> .abc. -> abc
# case 1: substring modules: [.abc., .ab.] -> .ab -> ""
# case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab
s1 = min(names) + "."
s2 = max(names) + "."

# 1) find longest shared prefix
parent_name = "."
for i, c in enumerate(s1):
if c != s2[i]:
parent_name = s1[:i].rstrip(".")
break
parent_name += c

# 2) throw away module name fragment and leading dot
# ".keep.thro" -> "keep"
parent_name = parent_name[1 : parent_name.rfind(".")]

# 3) return first common module that is not a module list
while True:
if parent_name == "":
return "", module
76 changes: 52 additions & 24 deletions tests/llmcompressor/modifiers/awq/test_base.py
@@ -2,9 +2,10 @@
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from pydantic import ValidationError
from torch.nn import Linear

from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
from llmcompressor.modifiers.awq.base import get_lowest_common_parent
from llmcompressor.modifiers.awq.base import get_lowest_common_module
from llmcompressor.modifiers.factory import ModifierFactory


@@ -40,16 +41,16 @@ def test_set_resolved_mappings():
)
self_attn = torch.nn.ModuleDict(
{
"q_proj": torch.nn.Linear(4, 4),
"k_proj": torch.nn.Linear(4, 4),
"v_proj": torch.nn.Linear(4, 4),
"o_proj": torch.nn.Linear(4, 4),
"q_proj": Linear(4, 4),
"k_proj": Linear(4, 4),
"v_proj": Linear(4, 4),
"o_proj": Linear(4, 4),
}
)
mlp = torch.nn.ModuleDict(
{
"up_proj": torch.nn.Linear(4, 10),
"down_proj": torch.nn.Linear(10, 4),
"up_proj": Linear(4, 10),
"down_proj": Linear(10, 4),
}
)
model = torch.nn.ModuleDict(
@@ -85,10 +86,12 @@ def test_set_resolved_mappings():
assert set(mapping.balance_names) == {"decoder.mlp.down_proj"}
assert mapping.parent_name == "decoder.mlp.down_proj"

# make sure we exclude case where o_proj/v_proj shapes are mismatched
awq = AWQModifier(
mappings=[
# make sure we exclude case where o_proj/v_proj shapes are mismatched
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
# make sure we exclude mapping if any balance layers are skipped
AWQMapping("re:.*v_proj", ["re:.*z_proj", "re:.*o_proj"]),
],
scheme="W4A16_ASYM",
)
@@ -98,17 +101,28 @@
{
"self_attn": torch.nn.ModuleDict(
{
"q_proj": torch.nn.Linear(4, 2),
"k_proj": torch.nn.Linear(4, 2),
"v_proj": torch.nn.Linear(4, 2),
"o_proj": torch.nn.Linear(4, 4),
"q_proj": Linear(4, 2),
"k_proj": Linear(4, 2),
"v_proj": Linear(4, 2),
"z_proj": Linear(2, 4),
"o_proj": Linear(4, 4),
}
)
}
)
}
)
awq._set_resolved_mappings(model)
if len(awq._resolved_mappings) > 0:
assert all(
"o_proj" not in name for name in awq._resolved_mappings[0].balance_names
), "should have skipped v->o mapping because o is incompatible"
assert all(
"z_proj" not in name for name in awq._resolved_mappings[0].balance_names
), (
"should have skipped v->[z,o] mapping because o is incompatible even though"
"z is compatible"
)
assert len(awq._resolved_mappings) == 0


@@ -179,15 +193,15 @@ def test_validate():


@pytest.mark.unit
def test_get_lowest_common_parent():
def test_get_lowest_common_module():
mlp = torch.nn.ModuleDict(
{
"experts": torch.nn.ModuleList(
[
torch.nn.ModuleDict(
{
"gate_proj": torch.nn.Linear(4, 2),
"down_proj": torch.nn.Linear(4, 2),
"gate_proj": Linear(4, 2),
"down_proj": Linear(4, 2),
}
)
for _ in range(10)
@@ -197,15 +211,15 @@ def test_get_lowest_common_parent():
)
self_attn = torch.nn.ModuleDict(
{
"q_proj": torch.nn.Linear(4, 2),
"k_proj": torch.nn.Linear(4, 2),
"v_proj": torch.nn.Linear(4, 2),
"o_proj": torch.nn.Linear(4, 4),
"q_proj": Linear(4, 2),
"k_proj": Linear(4, 2),
"v_proj": Linear(4, 2),
"o_proj": Linear(4, 4),
}
)
model = torch.nn.ModuleDict(
{
"embed_tokens": torch.nn.Linear(4, 2),
"embed_tokens": Linear(4, 2),
"decoder": torch.nn.ModuleDict(
{
"self_attn": self_attn,
@@ -215,22 +229,36 @@
}
)

parent_name, parent = get_lowest_common_parent(
parent_name, parent = get_lowest_common_module(
["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model
)
assert parent_name == "decoder.mlp" and parent == mlp

parent_name, parent = get_lowest_common_parent(
parent_name, parent = get_lowest_common_module(
["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model
)
assert parent_name == "decoder.self_attn" and parent == self_attn

parent_name, parent = get_lowest_common_parent(
parent_name, parent = get_lowest_common_module(
["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model
)
assert parent_name == "decoder" and parent == model["decoder"]

parent_name, parent = get_lowest_common_parent(
parent_name, parent = get_lowest_common_module(
["embed_tokens", "decoder.self_attn.v_proj"], model
)
assert parent_name == "" and parent == model

m = torch.nn.ModuleDict(
{
"abc": Linear(3, 3),
"ab": torch.nn.ModuleDict({"a": Linear(3, 3)}),
"z": Linear(3, 3),
}
)
parent_name, parent = get_lowest_common_module(["abc", "ab"], m)
assert parent_name == ""
parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m)
assert parent_name == "ab"
parent_name, parent = get_lowest_common_module(["z"], m)
assert parent_name == "z"