Commit f45b57c: remove unused post_init
Parent: 77ff529

5 files changed: +1 -56 lines


gptqmodel/nn_modules/qlinear/awq_exllamav2.py renamed to gptqmodel/nn_modules/qlinear/exllamav2_awq.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -118,7 +118,6 @@ def post_init(self, scratch_space: ScratchSpace):
     def forward(self, x: torch.Tensor):
         assert self.q_handle is not None, (
             "module.post_init() must be called before module.forward(). "
-            "Use exllamav2_post_init() on the whole model."
         )
         if exlv2_ext is None:
             raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." + msg)
```
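The trimmed assertion still encodes the kernel's initialization contract: forward() may only run after post_init() has produced a q_handle; only the pointer to the now-removed exllamav2_post_init() helper is dropped from the message. A minimal, self-contained sketch of that guard, using a placeholder ToyExlV2Linear class rather than the real gptqmodel module:

```python
# Toy illustration of the post_init-before-forward contract asserted above.
# ToyExlV2Linear and its plain matmul are placeholders, not the real kernel.
import torch
import torch.nn as nn


class ToyExlV2Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.q_handle = None  # set by post_init(), mirroring the real module

    def post_init(self):
        # the real module builds an ExLlamaV2 kernel handle here
        self.q_handle = object()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.q_handle is not None, (
            "module.post_init() must be called before module.forward(). "
        )
        return x @ self.weight.t()


layer = ToyExlV2Linear(4, 4)
layer.post_init()
print(layer(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```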

gptqmodel/quantization/awq/modules/linear/__init__.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -3,9 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: [email protected], x.com/qubitium
 
-from .exllama import WQLinear_Exllama, exllama_post_init
-from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
 from .gemm import WQLinear_GEMM
 from .gemv import WQLinear_GEMV
 from .gemv_fast import WQLinear_GEMVFast
-from .marlin import WQLinear_Marlin, marlin_post_init
+from .marlin import WQLinear_Marlin
```
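For orientation, a sketch of the subpackage's import surface after this hunk, assuming the file contains no exports beyond the lines shown in the diff:

```python
# Remaining public imports of the linear subpackage after this commit
# (the ExLlama/ExLlamaV2 classes and all *_post_init helpers are gone).
from gptqmodel.quantization.awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_GEMVFast,
    WQLinear_Marlin,
)
```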

gptqmodel/quantization/awq/modules/linear/exllama.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -133,10 +133,3 @@ def forward(self, x):
 
         return out.view(out_shape)
 
-
-def exllama_post_init(model):
-    for _, submodule in model.named_modules():
-        if isinstance(submodule, WQLinear_Exllama):
-            submodule.post_init()
-
-    return model
```
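The deleted helper was a thin wrapper over named_modules(). A runnable sketch of the same walk-and-initialize pattern with a stand-in layer class (ToyQuantLinear is an assumption, not a gptqmodel class); callers who relied on exllama_post_init() would have to perform an equivalent loop themselves:

```python
# Reproduces the pattern of the removed exllama_post_init(): walk the model
# and call post_init() on every layer of the target type.
import torch.nn as nn


class ToyQuantLinear(nn.Linear):
    """Stand-in for WQLinear_Exllama; post_init() normally packs kernel buffers."""

    def post_init(self):
        self.initialized = True


def post_init_all(model: nn.Module) -> nn.Module:
    for _, submodule in model.named_modules():
        if isinstance(submodule, ToyQuantLinear):
            submodule.post_init()
    return model


model = post_init_all(nn.Sequential(ToyQuantLinear(8, 8), nn.ReLU()))
assert getattr(model[0], "initialized", False)
```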

gptqmodel/quantization/awq/modules/linear/exllamav2.py

Lines changed: 0 additions & 45 deletions
```diff
@@ -133,7 +133,6 @@ def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
     def forward(self, x):
         assert self.q_handle is not None, (
             "module.post_init() must be called before module.forward(). "
-            "Use exllamav2_post_init() on the whole model."
         )
         if exlv2_ext is None:
             raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." + msg)
@@ -160,47 +159,3 @@ def forward(self, x):
             out.add_(self.bias)
 
         return out.view(out_shape)
-
-
-class ScratchSpace:
-    def __init__(self, scratch_bytes, dev):
-        self.scratch_bytes = scratch_bytes
-        self.scratch = torch.empty(
-            self.scratch_bytes // 2,
-            dtype=torch.float16,
-            device=dev,
-        )
-
-    def get_slice(self, size_bytes):
-        size_halfs = next_multiple(size_bytes, 128) // 2
-        scratch_slice = self.scratch.narrow(0, 0, size_halfs)
-
-        return scratch_slice
-
-
-def exllamav2_post_init(model, max_input_len: int = 2048, max_batch_size: int = 8):
-    # we search for the maximum number of bytes required for each device's scratch space
-    fixed_bytes: Dict[torch.device, int] = {}
-    for _, submodule in model.named_modules():
-        if isinstance(submodule, AwqExllamaV2QuantLinear):
-            device = submodule.qweight.device
-            scratch_fixed = submodule.scratch_space_fixed(
-                max_input_len=max_input_len, max_batch_size=max_batch_size
-            )
-            fixed_bytes[device] = max(fixed_bytes.get(device, 0), scratch_fixed)
-
-    # we allocate a model-persistent scratch space for each device
-    model.scratch_spaces: Dict[torch.device, ScratchSpace] = {}
-    for device, scratch_bytes in fixed_bytes.items():
-        model.scratch_spaces[device] = ScratchSpace(scratch_bytes, device)
-
-    for _, submodule in model.named_modules():
-        if isinstance(submodule, AwqExllamaV2QuantLinear):
-            device = submodule.qweight.device
-            submodule.post_init(scratch_space=model.scratch_spaces[device])
-
-    return model
-
-
-def next_multiple(x, multiple):
-    return ((x + multiple - 1) // multiple) * multiple
```
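The deleted ScratchSpace/exllamav2_post_init pair sized one persistent fp16 scratch buffer per device and handed slices of it to each layer's post_init(). A small sketch of just the sizing arithmetic it used (CPU tensors and the literal byte counts here are only for illustration):

```python
# Sizing logic from the removed helper: requests are rounded up to a multiple
# of 128 bytes, then converted to a float16 element count (2 bytes each), and
# served as a narrow() view over one shared buffer per device.
import torch


def next_multiple(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple


scratch_bytes = 1024
scratch = torch.empty(scratch_bytes // 2, dtype=torch.float16)  # 512 fp16 slots

size_bytes = 300
size_halfs = next_multiple(size_bytes, 128) // 2  # 384 bytes -> 192 elements
scratch_slice = scratch.narrow(0, 0, size_halfs)  # zero-copy view over the buffer

print(size_halfs, tuple(scratch_slice.shape))  # 192 (192,)
```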
