Skip to content

Commit 089addc

Browse files
committed
quantize lora linears
1 parent b56f3cf commit 089addc

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,27 @@ def quantize( # noqa C901
159159
from torchao.utils import unwrap_tensor_subclass
160160

161161
def filter_fn(m, fqn):
    """Decide whether torchao should quantize this module's weight.

    Eligible modules are plain ``nn.Linear`` layers and (when the
    executorch LoRA module is importable) ``LoRALinear`` layers, whose
    base weight parameter is quantizable. A module is accepted only if
    its weight's input dimension is a whole multiple of ``group_size``
    (closure variable of the enclosing ``quantize``), as required for
    group-wise quantization.
    """
    # LoRALinear lives in an optional executorch module; treat it as
    # absent rather than failing when the import is unavailable.
    lora_cls = None
    try:
        from executorch.examples.models.llama.lora import LoRALinear

        lora_cls = LoRALinear
    except ImportError:
        pass

    eligible = isinstance(m, nn.Linear) or (
        lora_cls is not None and isinstance(m, lora_cls)
    )
    # Short-circuit keeps us from touching m.weight on ineligible
    # modules, matching the original's guarded shape check.
    return eligible and m.weight.shape[1] % group_size == 0
169183

170184
quantize_(
171185
model,

0 commit comments

Comments
 (0)