Skip to content

Commit 4fc9043

Browse files
committed
WIP: reload
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 749c91c commit 4fc9043

File tree

3 files changed

+108
-189
lines changed

3 files changed

+108
-189
lines changed

vllm/model_executor/model_loader/default_loader.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import time
77
from collections.abc import Generator, Iterable
8-
from typing import cast
8+
from typing import cast, Optional
99

1010
import torch
1111
from torch import nn
@@ -272,44 +272,21 @@ def download_model(self, model_config: ModelConfig) -> None:
272272
allow_patterns_overrides=None,
273273
)
274274

275-
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
275+
def load_weights(self, model: nn.Module, model_config: ModelConfig, weights_iterator: Optional[Iterable[tuple[str, torch.Tensor]]] = None) -> None:
276276
if model_config.quantization == "torchao" and torchao_version_at_least(
277277
"0.14.0"
278278
):
279279
self.load_config.safetensors_load_strategy = "torchao"
280-
weights_to_load = {name for name, _ in model.named_parameters()}
281-
282-
# if we don't have `model.weight_metadata_and_attr_saved` defined and
283-
# set to True, it means that this is either offline quantization case
284-
# or the first run of online quantization
285-
# see online_quantization.py for detailed notes
286-
offline_quantization_or_first_run_of_online_quantization = not getattr(
287-
model, "weight_metadata_and_attr_saved", False
288-
)
280+
281+
# use provided weights or load from disk
282+
if weights_iterator is None:
283+
weights_iterator = self.get_all_weights(model_config, model)
289284

290-
if model_config.quantization is None:
291-
# model is not quantized
292-
loaded_weights = model.load_weights(
293-
self.get_all_weights(model_config, model)
294-
)
295-
elif offline_quantization_or_first_run_of_online_quantization:
296-
# case 1: offline quantized checkpoint
297-
# case 2: Step I1 first run of weight loading with
298-
# online quantization
299-
# see online_quantization.py for detailed notes
300-
loaded_weights = model.load_weights(
301-
self.get_all_weights(model_config, model)
302-
)
303-
else:
304-
# to avoid circular dependency
305-
from vllm.model_executor.model_loader.online_quantization import (
306-
load_weights_and_online_quantize,
307-
)
308-
309-
# subsequent runs of weight loading with online
310-
# quantization
311-
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
285+
# load weights into model
286+
weights_to_load = {name for name, _ in model.named_parameters()}
287+
loaded_weights = model.load_weights(weights_iterator)
312288

289+
# logging and validation
313290
self.counter_after_loading_weights = time.perf_counter()
314291
logger.info_once(
315292
"Loading weights took %.2f seconds",

vllm/model_executor/model_loader/online_quantization.py

Lines changed: 71 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
from torch import nn
8+
from copy import deepcopy
89

910
from vllm.config import ModelConfig
1011
from vllm.logger import init_logger
@@ -13,6 +14,11 @@
1314

1415
logger = init_logger(__name__)
1516

17+
SUPPORTED_QUANT_CONFIGS = {
18+
"torchao",
19+
"fp8",
20+
}
21+
1622
# Notes for Online Quantization
1723
# In terms of state of checkpoints, quantization config and their
1824
# correspondance to online quantization:
@@ -64,161 +70,74 @@
6470
def maybe_save_metadata_and_attributes_for_weight_reloading(
6571
model: nn.Module, model_config: ModelConfig
6672
):
67-
# following is to support on the fly quantization, currently only supported
68-
# for torchao
69-
if model_config.quantization != "torchao":
70-
return
71-
72-
if getattr(model, "process_weights_after_loading_already_called", False):
73-
# In case `process_weights_after_loading` is called multiple times
74-
# we'll skip it at later times
75-
logger.warning(
76-
"process_weights_after_loading already called for model %s", model
77-
)
78-
return
79-
73+
# assume this is called right after weight loading and before/ at the start of process_weights_after_loading
8074
from vllm.model_executor.model_loader.weight_utils import get_quant_config
8175

8276
quant_config = get_quant_config(model_config, None)
83-
84-
# If checkpoint is already torchao serialized, this means it's
85-
# pre-quantized quantization case, we'll skip saving the metadata
86-
# Otherwise, this is Step I2 of initialization steps of
87-
# online quantization
88-
# This step record the weights metadata and weight attributes so we can
89-
# restore the bfloat16 model weights during the relad step (R1 and R2)
90-
# see Notes in online_quantization.py for more details
91-
if not (
92-
hasattr(quant_config, "is_checkpoint_torchao_serialized")
93-
and not quant_config.is_checkpoint_torchao_serialized
94-
):
77+
if quant_config.get_name() not in SUPPORTED_QUANT_CONFIGS:
9578
return
96-
97-
# This is the I2 step of online quantiztion that saves
98-
# metadata and attributes of weights so they can be used in R1 and
99-
# R2 step, note that we only save these during initialization
100-
101-
# Includes two things
102-
# 1. save floating point metadata (shape, dtype, device) for init
103-
# 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init
104-
105-
if getattr(model, "weight_metadata_and_attr_saved", False):
106-
return
107-
108-
# save the dtype, shape and device for model parameter, used for
109-
# restoring the model high precision parameters before
110-
# reloading the weights
111-
assert not hasattr(model, "original_weights_rebuild_keys")
112-
model.original_weights_rebuild_keys = {}
113-
for name, p in model.named_parameters():
114-
model.original_weights_rebuild_keys[name] = {
115-
"shape": p.shape,
116-
"dtype": p.dtype,
117-
"device": p.device,
118-
}
119-
120-
# record the weight attributes (loader functions etc.)
121-
# so these can be recovered later when we reload the weights
122-
# structure: {"weight_name": {"weight_attr_key": attr}}
123-
assert not hasattr(model, "recorded_weight_attr")
124-
model.recorded_weight_attr = {}
125-
for name, param in model.named_parameters():
126-
model.recorded_weight_attr[name] = {}
127-
for key in param.__dict__:
128-
if hasattr(param, key):
129-
attr = getattr(param, key)
130-
if not callable(attr):
131-
model.recorded_weight_attr[name][key] = attr
132-
elif hasattr(attr, "__self__") and param is attr.__self__:
133-
# if attr is a bonded method for an instance, and
134-
# attr.__self__ points to the instance (param)
135-
# we'll record the underlying function object
136-
model.recorded_weight_attr[name][key] = attr.__func__
137-
else:
138-
model.recorded_weight_attr[name][key] = attr
139-
# mark the metadata and attributes saved so we don't run it again
140-
model.weight_metadata_and_attr_saved = True
141-
142-
143-
def _bond_method_to_cls(func, obj):
144-
if hasattr(func, "__self__") or not callable(func):
145-
# If the function is already bound to an instance, return it as is
146-
return func
147-
else:
148-
return types.MethodType(func, obj)
149-
150-
151-
def load_weights_and_online_quantize(
152-
model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
153-
) -> set[str]:
154-
# online quantization, right now only enabled for
155-
# torchao
156-
# R1, R2, R3, R4 in the Notes
157-
158-
# TODO: Add fp8 support
159-
assert model_config.quantization == "torchao", (
160-
"online quantization is only enabled for torchao currently"
79+
80+
if not hasattr(model, "weight_loading_metadata"):
81+
setattr(model, "weight_loading_metadata", {
82+
name: _copy_to_meta_tensor(param)
83+
for name, param in model.named_parameters()
84+
})
85+
86+
return getattr(model, "weight_loading_metadata")
87+
88+
89+
def restore_weights_for_loading(model: nn.Module):
90+
assert hasattr(model, "weight_loading_metadata")
91+
metadata: dict[str, torch.Tensor] = getattr(model, "weight_loading_metadata")
92+
model_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
93+
94+
# remove parameters which were not present at load time
95+
params_to_remove = model_param_names - metadata.keys()
96+
for param_fqn in params_to_remove:
97+
module_name, param_name = param_fqn.rsplit(".", 1)
98+
module = model.get_submodule(module_name)
99+
delattr(module, param_name)
100+
101+
# restore parameters that were present at load time
102+
for param_fqn, meta_tensor in metadata.items():
103+
module_name, param_name = param_fqn.rsplit(".", 1)
104+
module = model.get_submodule(module_name)
105+
106+
# for faster runtime, skip materialization if the tensors match
107+
original_tensor = getattr(module, param_name, None)
108+
if _tensors_alike(original_tensor, meta_tensor):
109+
continue
110+
111+
param = _materialize_meta_tensor(meta_tensor)
112+
setattr(module, param)
113+
114+
115+
def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
116+
new_tensor = tensor.to("meta")
117+
new_tensor.__class__ = tensor.__class__
118+
new_tensor.__dict__ = deepcopy(tensor.__dict__)
119+
new_tensor._original_device = tensor.device
120+
return new_tensor
121+
122+
123+
def _tensors_alike(tensor: torch.Tensor | None, meta: torch.Tensor) -> bool:
124+
if tensor is None:
125+
return False
126+
127+
return (
128+
tensor.device == meta._original_device
129+
and tensor.dtype == meta.dtype
130+
and tensor.shape == meta.shape
131+
and tensor.__dict__ == meta.__dict__
161132
)
162-
# TODO: use create_weights to restore the weights to original state
163-
164-
# Step R1: First restore the quantized weights to original bfloat16
165-
# weights, with original metadata (shape, dtype, device)
166-
# and attributes, so that bfloat16 weights can be loaded properly
167-
existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
168-
named_modules = dict(model.named_modules(remove_duplicate=False))
169-
model_device = None
170-
171-
# Step R2: recover the parameter to the state before first loading
172-
for name, d in model.original_weights_rebuild_keys.items():
173-
_shape = d["shape"]
174-
_dtype = d["dtype"]
175-
_device = d["device"]
176-
if model_device is not None:
177-
assert model_device == _device, (
178-
"Expecting all weights "
179-
"to be in the same device for now, got both: "
180-
f"{model_device} and {_device}"
181-
)
182-
else:
183-
model_device = _device
184-
185-
if name in existing_param_names:
186-
module_name, weight_name = name.rsplit(".", 1)
187-
module = named_modules[module_name]
188-
setattr(
189-
module,
190-
weight_name,
191-
torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
192-
)
193-
194-
# recorded_weight_attr is
195-
# {"weight_name": {"weight_attr_key": attr}}
196-
# e.g.
197-
# {
198-
# {
199-
# "layer.0.weight": {
200-
# "weight_loader": weight_loader_function_object,
201-
# "input_dim": 0, ...
202-
# },
203-
# "layer.1.weight": ...,
204-
# }
205-
# }
206-
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
207-
for attr_name, attr in weight_attr_dict.items():
208-
module_name, weight_name = full_weight_name.rsplit(".", 1)
209-
module = named_modules[module_name]
210-
weight = getattr(module, weight_name)
211-
if not hasattr(weight, attr_name):
212-
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
213-
214-
# Step I1: reload bfloat16 / high precision weights
215-
loaded_weights = model.load_weights(
216-
model_loader.get_all_weights(model_config, model)
133+
134+
135+
136+
def _materialize_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
137+
return torch.empty_strided(
138+
size=tuple(tensor.size()),
139+
stride=tuple(tensor.stride()),
140+
dtype=tensor.dtype,
141+
device=tensor._original_device,
142+
requires_grad=False, # set below to match input
217143
)
218-
219-
# Step I2: online quantize the weights
220-
# manually process weights after loading
221-
model.process_weights_after_loading_already_called = False
222-
process_weights_after_loading(model, model_config, model_device)
223-
model.process_weights_after_loading_already_called = True
224-
return loaded_weights

vllm/v1/worker/gpu_model_runner.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from copy import deepcopy
1111
from functools import reduce
1212
from itertools import product
13-
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
13+
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast, Optional, Iterable
1414

1515
import numpy as np
1616
import torch
@@ -3177,13 +3177,36 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
31773177

31783178
return None
31793179

3180-
def reload_weights(self) -> None:
3180+
def reload_weights(
3181+
self, weights_iterator: Optional[Iterable[tuple[str, torch.Tensor]]] = None,
3182+
process_weights_after_loading: bool = True) -> None:
3183+
from vllm.model_executor.model_loader.utils import process_weights_after_loading as _process
3184+
from vllm.model_executor.model_loader.online_quantization import (
3185+
restore_weights_for_loading,
3186+
)
3187+
31813188
assert getattr(self, "model", None) is not None, (
31823189
"Cannot reload weights before model is loaded."
31833190
)
3184-
model_loader = get_model_loader(self.load_config)
3191+
model = self.get_model()
3192+
3193+
# for select quant configs, regenerate weights for proper weight loading
3194+
if process_weights_after_loading and hasattr("weight_loading_metadata"):
3195+
restore_weights_for_loading(model)
3196+
31853197
logger.info("Reloading weights inplace...")
3186-
model_loader.load_weights(self.get_model(), model_config=self.model_config)
3198+
model_loader = get_model_loader(self.load_config)
3199+
model_loader.load_weights(model, model_config=self.model_config, weights_iterator=weights_iterator)
3200+
3201+
if process_weights_after_loading:
3202+
device_config = self.vllm_config.device_config
3203+
load_config = self.vllm_config.load_config
3204+
load_device = (
3205+
device_config.device if load_config.device is None else load_config.device
3206+
)
3207+
_process(model, self.model_config, load_device)
3208+
3209+
# TODO: logging total reload time
31873210

31883211
def save_tensorized_model(
31893212
self,

0 commit comments

Comments
 (0)