Commit 3aa6696 — "modify unflatten for vllm"
1 parent: 6c78c4d

File tree

4 files changed: +22 lines added, −1 line removed

benchmarks/benchmark_uintx.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,11 +6,11 @@
 from copy import deepcopy
 
 import torch
-
 from torchao.prototype.uintx import (
     uintx_affine_weight_only,
     unpack_cpu,
 )
+
 from torchao.quantization.quant_api import quantize_
```

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -77,6 +77,7 @@ def test_safetensors(self, config, act_pre_scale=False):
         reconstructed_dict = unflatten_tensor_state_dict(
             tensors_data_dict, metadata
         )
+        assert not tensors_data_dict
 
         model = torch.nn.Sequential(
             torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
```

test/test_low_bit_optim.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,6 +30,7 @@
 common_utils.SEED = 1234
 
 from packaging.version import Version
+
 from torchao import optim
 from torchao.optim.quant_utils import (
     _fp32_to_bf16_sr,
```

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -68,14 +68,21 @@ def unflatten_tensor_state_dict(
     result = {}
 
     for tensor_name in tensor_names:
+        to_be_deleted = []
+
         module_fqn, weight_name = tensor_name.rsplit(".", 1)
 
         prefix = f"{module_fqn}._{weight_name}_"
         tensor_tensors = {}
+
         for key, value in combined_data.items():
             if key.startswith(prefix):
                 # Remove the prefix
                 tensor_tensors[key[len(prefix) :]] = value
+                full_tensor_name_in_state_dict = key
+                to_be_deleted.append(
+                    full_tensor_name_in_state_dict
+                )  # for tensor subclass
 
         tensor_metadata = json.loads(metadata.get(tensor_name))
         tensor_type = tensor_metadata.get("_type")
@@ -89,9 +96,21 @@
             tensor_metadata["_data"].update(tensor_tensors)
             result[tensor_name] = object_from_dict(tensor_metadata)
         elif tensor_type == torch.Tensor.__name__:
+            if tensor_name not in tensors_data_dict.keys():
+                # we allow the option of loading in state_dict info for a single tensor
+                # if tensor state dict info is not loaded in yet, we wait for it to be provided
+                # in a future call
+                continue
             result[tensor_name] = tensors_data_dict[tensor_name]
+            to_be_deleted.append(
+                tensor_name
+            )  # add here because key for torch.Tensor has no prefix
         else:
             raise ValueError(f"Unsupported tensor type: {tensor_type}")
+
+        for tensor_name in to_be_deleted:
+            del tensors_data_dict[tensor_name]
+
     return result
```
(NOTE(review): the indentation of the added `for tensor_name in to_be_deleted:` cleanup loop cannot be recovered from the scraped page; it is shown one level inside the outer loop because `to_be_deleted` is re-initialized on every outer iteration — confirm against the upstream commit.)
96115

97116

0 commit comments