
Commit 805ba66

Merge branch 'main' into remove-deprecated-download-from-file
2 parents a179143 + 5169c23 commit 805ba66

File tree: 10 files changed (+436, -67 lines)


src/transformers/core_model_loading.py
Lines changed: 39 additions & 8 deletions

@@ -115,6 +115,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
         **kwargs,
     ) -> dict[str, list[torch.Tensor]]:

@@ -138,6 +140,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, list[torch.Tensor]]:
         tensors = next(iter(value.values()))

@@ -163,6 +167,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, torch.Tensor]:
         if len(target_keys) != 1:

@@ -191,6 +197,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, torch.Tensor]:
         merged: dict[str, torch.Tensor] = {}

@@ -220,6 +228,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, list[torch.Tensor]]:
         if len(value) != len(self.sizes):

@@ -258,6 +268,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, list[torch.Tensor]]:
         self.config = config

@@ -298,21 +310,28 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu
 class WeightRenaming(WeightTransform):
     # Special case of WeightTransform that only renames keys without any conversion.

-    def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
+    def convert(
+        self,
+        layer_name: str,
+        model=None,
+        config=None,
+        hf_quantizer=None,
+        missing_keys: Optional[MutableSet[str]] = None,
+    ):
         misc = {}
         for pattern, futures in self.collected_tensors.items():
             self.collected_tensors[pattern] = [future.result() for future in futures]

         collected_tensors = self.collected_tensors
-        if quantizer is not None and self.quantization_operation is not None:
+        if hf_quantizer is not None and self.quantization_operation is not None:
             with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
                 collected_tensors = self.quantization_operation.convert(
                     self.collected_tensors,
                     source_keys=self.source_keys,
                     target_keys=self.target_keys,
                     full_layer_name=layer_name,
+                    model=model,
                     config=config,
-                    quant_config=quantizer.quantization_config,
                     missing_keys=missing_keys,
                 )

@@ -332,7 +351,14 @@ def __post_init__(self):
         if not self.operations:
             raise ValueError("WeightConverter requires at least one operation.")

-    def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
+    def convert(
+        self,
+        layer_name: str,
+        model=None,
+        config=None,
+        hf_quantizer=None,
+        missing_keys: Optional[MutableSet[str]] = None,
+    ):
         misc = {}
         for pattern, futures in self.collected_tensors.items():
             self.collected_tensors[pattern] = [future.result() for future in futures]

@@ -345,17 +371,19 @@ def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Op
                     source_keys=self.source_keys,
                     target_keys=self.target_keys,
                     full_layer_name=layer_name,
+                    model=model,
                     config=config,
+                    missing_keys=missing_keys,
                 )
-        if quantizer is not None and self.quantization_operation is not None:
+        if hf_quantizer is not None and self.quantization_operation is not None:
             with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_keys=self.source_keys,
                     target_keys=self.target_keys,
                     full_layer_name=layer_name,
                     config=config,
-                    quant_config=quantizer.quantization_config,
+                    model=model,
                     missing_keys=missing_keys,
                 )
         return collected_tensors, misc

@@ -626,7 +654,6 @@ def convert_and_load_state_dict_in_model(
     ```

     """
-
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}

@@ -750,7 +777,11 @@ def convert_and_load_state_dict_in_model(
         pbar.refresh()
         try:
             realized_value, misc = mapping.convert(
-                first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
+                first_param_name,
+                model=model,
+                config=model.config,
+                hf_quantizer=hf_quantizer,
+                missing_keys=missing_keys,
             )
             for target_name, param in realized_value.items():
                 param = param[0] if isinstance(param, list) else param
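
Taken together, these hunks thread two extra arguments, model and missing_keys, through every conversion op's convert(), and rename the quantizer keyword to hf_quantizer while dropping the old quant_config kwarg from the op call. The sketch below shows what a custom op written against the updated signature could look like; the class name NoOpConversion and its pass-through body are illustrative only and are not part of this commit.

import torch


class NoOpConversion:
    # Illustrative only: returns the collected tensors unchanged, but accepts
    # the full keyword set that WeightConverter/WeightRenaming now forward.
    def convert(
        self,
        value: dict[str, list[torch.Tensor]],
        source_keys: list[str],
        target_keys: list[str],
        full_layer_name: str,
        model,          # newly forwarded: the nn.Module being loaded
        missing_keys,   # newly forwarded: mutable set of still-missing keys
        config,
        **kwargs,
    ) -> dict[str, list[torch.Tensor]]:
        return value

At the call site, as the last hunk shows, a mapping is now invoked as mapping.convert(first_param_name, model=model, config=model.config, hf_quantizer=hf_quantizer, missing_keys=missing_keys).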

src/transformers/integrations/accelerate.py
Lines changed: 1 addition & 1 deletion

@@ -241,7 +241,7 @@ def all_tensors():
         if name in tied_keys:
             continue
         if hf_quantizer is not None:
-            dtype_size = hf_quantizer.param_element_size(model, name)
+            dtype_size = hf_quantizer.param_element_size(model, name, param)
         else:
            dtype_size = param.element_size()
         size = param.numel() * dtype_size
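
Passing the tensor itself to param_element_size lets a quantizer size a parameter from its dtype and shape rather than from its name alone. The snippet below is a purely hypothetical quantizer-side implementation that would be compatible with the new call; it is not code from transformers.

import torch


class Hypothetical4BitQuantizer:
    # Not part of this commit: shows why receiving `param` is useful when
    # answering hf_quantizer.param_element_size(model, name, param).
    def param_element_size(self, model, name: str, param: torch.Tensor) -> float:
        # Weights that will be packed to 4 bits take half a byte per element;
        # everything else keeps its native element size.
        if name.endswith(".weight") and param.dim() == 2:
            return 0.5
        return param.element_size()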

src/transformers/integrations/bitsandbytes.py
Lines changed: 12 additions & 2 deletions

@@ -36,7 +36,11 @@ def __init__(self, hf_quantizer):
         self.hf_quantizer = hf_quantizer

     def convert(
-        self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
+        self,
+        input_dict: dict[str, list[torch.Tensor]],
+        model: Optional[torch.nn.Module] = None,
+        missing_keys=None,
+        **kwargs,
     ) -> dict[str, torch.Tensor]:
         """
         we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that it stored in hf_quantizer for now as we can't save it in the op since we create an op per tensor.

@@ -59,6 +63,7 @@ def convert(
                 # remove missing keys that were create when initializing Params4bit
                 for key in new_value.quant_state.as_dict(packed=True).keys():
                     missing_keys.discard(f"{full_name}.{key}")
+                module._is_hf_initialized = True
                 return {target_key: new_value}
             else:
                 module_name = target_key.rsplit(".", 1)[0]

@@ -77,6 +82,7 @@ def convert(
                     device=value.device,
                     module=module,
                 )
+                module._is_hf_initialized = True
                 del self.hf_quantizer.param_quant_stats[module_name]
                 return {target_key: new_value}
         return {}

@@ -87,7 +93,11 @@ def __init__(self, hf_quantizer):
         self.hf_quantizer = hf_quantizer

     def convert(
-        self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
+        self,
+        input_dict: dict[str, list[torch.Tensor]],
+        model: Optional[torch.nn.Module] = None,
+        missing_keys=None,
+        **kwargs,
     ) -> dict[str, torch.Tensor]:
         target_key, value = tuple(input_dict.items())[0]
         value = value[0] if isinstance(value, list) else value
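
Besides widening the input_dict annotation to the dict of tensor lists the loader actually passes, both bitsandbytes paths now set module._is_hf_initialized = True once the quantized weight has been materialized, marking the module as already holding real parameters so later initialization passes can skip it. A minimal sketch of that kind of guard is shown below, assuming a generic init loop; this is not code from the commit.

import torch
from torch import nn


def init_remaining_weights(model: nn.Module) -> None:
    # Sketch of an init pass that honours the flag set by the conversion ops:
    # modules already populated (e.g. with bnb Params4bit) are left untouched.
    for module in model.modules():
        if getattr(module, "_is_hf_initialized", False):
            continue
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()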

src/transformers/integrations/tensor_parallel.py
Lines changed: 26 additions & 21 deletions

@@ -20,9 +20,13 @@
 from functools import partial, reduce
 from typing import Optional

-import torch
-import torch.distributed as dist
-from torch import nn
+from ..utils.import_utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+    import torch.distributed as dist
+    from torch import nn

 from ..distributed import DistributedConfig
 from ..utils import is_torch_greater_or_equal, logging

@@ -31,12 +35,12 @@

 logger = logging.get_logger(__name__)

-# Cache this result has it's a C FFI call which can be pretty time-consuming
-_torch_distributed_available = torch.distributed.is_available()
-
+if is_torch_available():
+    # Cache this result has it's a C FFI call which can be pretty time-consuming
+    _torch_distributed_available = torch.distributed.is_available()

-if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-    from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
+    if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
+        from torch.distributed.tensor import DTensor, Placement, Replicate, Shard


 def initialize_tensor_parallelism(

@@ -169,19 +173,20 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
     return None


-str_to_dtype = {
-    "BOOL": torch.bool,
-    "U8": torch.uint8,
-    "I8": torch.int8,
-    "I16": torch.int16,
-    "F16": torch.float16,
-    "BF16": torch.bfloat16,
-    "I32": torch.int32,
-    "F32": torch.float32,
-    "F64": torch.float64,
-    "I64": torch.int64,
-    "F8_E4M3": torch.float8_e4m3fn,
-}
+if is_torch_available():
+    str_to_dtype = {
+        "BOOL": torch.bool,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "I16": torch.int16,
+        "F16": torch.float16,
+        "BF16": torch.bfloat16,
+        "I32": torch.int32,
+        "F32": torch.float32,
+        "F64": torch.float64,
+        "I64": torch.int64,
+        "F8_E4M3": torch.float8_e4m3fn,
+    }


 def get_packed_weights(param, empty_param, device_mesh, rank, dim):
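
The tensor_parallel changes are purely import hygiene: the torch imports, the _torch_distributed_available cache, and the module-level str_to_dtype table are now created only when is_torch_available() returns true, so importing the module in a torch-free environment no longer fails at import time. The same guard pattern, reduced to a self-contained sketch (the dtype_from_tag helper and the shortened table are illustrative, not part of the file):

from transformers.utils.import_utils import is_torch_available

if is_torch_available():
    import torch

    # Anything that touches torch at module import time must sit under the guard.
    str_to_dtype = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16}


def dtype_from_tag(tag: str):
    # Illustrative helper: fail with a clear message when torch is absent.
    if not is_torch_available():
        raise ImportError("torch is required to map safetensors dtype tags")
    return str_to_dtype[tag]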
