@@ -133,7 +133,6 @@ def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
133133 def forward (self , x ):
134134 assert self .q_handle is not None , (
135135 "module.post_init() must be called before module.forward(). "
136- "Use exllamav2_post_init() on the whole model."
137136 )
138137 if exlv2_ext is None :
139138 raise ModuleNotFoundError ("External ExLlamaV2 kernels are not properly installed." + msg )
@@ -160,47 +159,3 @@ def forward(self, x):
160159 out .add_ (self .bias )
161160
162161 return out .view (out_shape )
163-
164-
class ScratchSpace:
    """Pre-allocated fp16 workspace shared by ExLlamaV2 kernels on one device.

    The buffer is sized in bytes but backed by a half-precision tensor,
    so it stores ``scratch_bytes // 2`` fp16 elements.
    """

    def __init__(self, scratch_bytes, dev):
        self.scratch_bytes = scratch_bytes
        # Each fp16 element occupies two bytes.
        half_count = self.scratch_bytes // 2
        self.scratch = torch.empty(half_count, dtype=torch.float16, device=dev)

    def get_slice(self, size_bytes):
        """Return a zero-offset view over the first ``size_bytes`` bytes.

        The requested size is rounded up to a multiple of 128 bytes before
        being converted to an fp16 element count; no copy is made.
        """
        size_halfs = next_multiple(size_bytes, 128) // 2
        return self.scratch.narrow(0, 0, size_halfs)
179-
180-
def exllamav2_post_init(model, max_input_len: int = 2048, max_batch_size: int = 8):
    """Allocate per-device scratch space and finish initializing every
    ``AwqExllamaV2QuantLinear`` submodule of *model*.

    A single persistent ``ScratchSpace`` is created per device, sized to the
    largest requirement reported by any quantized layer on that device, then
    handed to each layer via ``post_init``. Returns *model* for chaining.
    """
    # First pass: find the maximum scratch requirement (in bytes) per device.
    fixed_bytes = {}
    for _, submodule in model.named_modules():
        if not isinstance(submodule, AwqExllamaV2QuantLinear):
            continue
        dev = submodule.qweight.device
        needed = submodule.scratch_space_fixed(
            max_input_len=max_input_len, max_batch_size=max_batch_size
        )
        fixed_bytes[dev] = max(fixed_bytes.get(dev, 0), needed)

    # One model-persistent ScratchSpace per device.
    model.scratch_spaces = {
        dev: ScratchSpace(nbytes, dev) for dev, nbytes in fixed_bytes.items()
    }

    # Second pass: hand each quantized layer the scratch space for its device.
    for _, submodule in model.named_modules():
        if isinstance(submodule, AwqExllamaV2QuantLinear):
            dev = submodule.qweight.device
            submodule.post_init(scratch_space=model.scratch_spaces[dev])

    return model
203-
204-
def next_multiple(x, multiple):
    """Round *x* up to the nearest multiple of *multiple*."""
    quotient = (x + multiple - 1) // multiple
    return quotient * multiple
0 commit comments