|
5 | 5 |
|
6 | 6 | import torch |
7 | 7 | from torch import nn |
| 8 | +from copy import deepcopy |
8 | 9 |
|
9 | 10 | from vllm.config import ModelConfig |
10 | 11 | from vllm.logger import init_logger |
|
13 | 14 |
|
14 | 15 | logger = init_logger(__name__) |
15 | 16 |
|
| 17 | +SUPPORTED_QUANT_CONFIGS = { |
| 18 | + "torchao", |
| 19 | + "fp8", |
| 20 | +} |
| 21 | + |
16 | 22 | # Notes for Online Quantization |
17 | 23 | # In terms of state of checkpoints, quantization config and their |
18 | 24 | # correspondance to online quantization: |
|
64 | 70 | def maybe_save_metadata_and_attributes_for_weight_reloading( |
65 | 71 | model: nn.Module, model_config: ModelConfig |
66 | 72 | ): |
67 | | - # following is to support on the fly quantization, currently only supported |
68 | | - # for torchao |
69 | | - if model_config.quantization != "torchao": |
70 | | - return |
71 | | - |
72 | | - if getattr(model, "process_weights_after_loading_already_called", False): |
73 | | - # In case `process_weights_after_loading` is called multiple times |
74 | | - # we'll skip it at later times |
75 | | - logger.warning( |
76 | | - "process_weights_after_loading already called for model %s", model |
77 | | - ) |
78 | | - return |
79 | | - |
| 73 | + # assume this is called right after weight loading and before/ at the start of process_weights_after_loading |
80 | 74 | from vllm.model_executor.model_loader.weight_utils import get_quant_config |
81 | 75 |
|
82 | 76 | quant_config = get_quant_config(model_config, None) |
83 | | - |
84 | | - # If checkpoint is already torchao serialized, this means it's |
85 | | - # pre-quantized quantization case, we'll skip saving the metadata |
86 | | - # Otherwise, this is Step I2 of initialization steps of |
87 | | - # online quantization |
88 | | - # This step record the weights metadata and weight attributes so we can |
89 | | - # restore the bfloat16 model weights during the relad step (R1 and R2) |
90 | | - # see Notes in online_quantization.py for more details |
91 | | - if not ( |
92 | | - hasattr(quant_config, "is_checkpoint_torchao_serialized") |
93 | | - and not quant_config.is_checkpoint_torchao_serialized |
94 | | - ): |
| 77 | + if quant_config.get_name() not in SUPPORTED_QUANT_CONFIGS: |
95 | 78 | return |
96 | | - |
97 | | - # This is the I2 step of online quantiztion that saves |
98 | | - # metadata and attributes of weights so they can be used in R1 and |
99 | | - # R2 step, note that we only save these during initialization |
100 | | - |
101 | | - # Includes two things |
102 | | - # 1. save floating point metadata (shape, dtype, device) for init |
103 | | - # 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init |
104 | | - |
105 | | - if getattr(model, "weight_metadata_and_attr_saved", False): |
106 | | - return |
107 | | - |
108 | | - # save the dtype, shape and device for model parameter, used for |
109 | | - # restoring the model high precision parameters before |
110 | | - # reloading the weights |
111 | | - assert not hasattr(model, "original_weights_rebuild_keys") |
112 | | - model.original_weights_rebuild_keys = {} |
113 | | - for name, p in model.named_parameters(): |
114 | | - model.original_weights_rebuild_keys[name] = { |
115 | | - "shape": p.shape, |
116 | | - "dtype": p.dtype, |
117 | | - "device": p.device, |
118 | | - } |
119 | | - |
120 | | - # record the weight attributes (loader functions etc.) |
121 | | - # so these can be recovered later when we reload the weights |
122 | | - # structure: {"weight_name": {"weight_attr_key": attr}} |
123 | | - assert not hasattr(model, "recorded_weight_attr") |
124 | | - model.recorded_weight_attr = {} |
125 | | - for name, param in model.named_parameters(): |
126 | | - model.recorded_weight_attr[name] = {} |
127 | | - for key in param.__dict__: |
128 | | - if hasattr(param, key): |
129 | | - attr = getattr(param, key) |
130 | | - if not callable(attr): |
131 | | - model.recorded_weight_attr[name][key] = attr |
132 | | - elif hasattr(attr, "__self__") and param is attr.__self__: |
133 | | - # if attr is a bonded method for an instance, and |
134 | | - # attr.__self__ points to the instance (param) |
135 | | - # we'll record the underlying function object |
136 | | - model.recorded_weight_attr[name][key] = attr.__func__ |
137 | | - else: |
138 | | - model.recorded_weight_attr[name][key] = attr |
139 | | - # mark the metadata and attributes saved so we don't run it again |
140 | | - model.weight_metadata_and_attr_saved = True |
141 | | - |
142 | | - |
143 | | -def _bond_method_to_cls(func, obj): |
144 | | - if hasattr(func, "__self__") or not callable(func): |
145 | | - # If the function is already bound to an instance, return it as is |
146 | | - return func |
147 | | - else: |
148 | | - return types.MethodType(func, obj) |
149 | | - |
150 | | - |
151 | | -def load_weights_and_online_quantize( |
152 | | - model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig |
153 | | -) -> set[str]: |
154 | | - # online quantization, right now only enabled for |
155 | | - # torchao |
156 | | - # R1, R2, R3, R4 in the Notes |
157 | | - |
158 | | - # TODO: Add fp8 support |
159 | | - assert model_config.quantization == "torchao", ( |
160 | | - "online quantization is only enabled for torchao currently" |
| 79 | + |
| 80 | + if not hasattr(model, "weight_loading_metadata"): |
| 81 | + setattr(model, "weight_loading_metadata", { |
| 82 | + name: _copy_to_meta_tensor(param) |
| 83 | + for name, param in model.named_parameters() |
| 84 | + }) |
| 85 | + |
| 86 | + return getattr(model, "weight_loading_metadata") |
| 87 | + |
| 88 | + |
| 89 | +def restore_weights_for_loading(model: nn.Module): |
| 90 | + assert hasattr(model, "weight_loading_metadata") |
| 91 | + metadata: dict[str, torch.Tensor] = getattr(model, "weight_loading_metadata") |
| 92 | + model_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() |
| 93 | + |
| 94 | + # remove parameters which were not present at load time |
| 95 | + params_to_remove = model_param_names - metadata.keys() |
| 96 | + for param_fqn in params_to_remove: |
| 97 | + module_name, param_name = param_fqn.rsplit(".", 1) |
| 98 | + module = model.get_submodule(module_name) |
| 99 | + delattr(module, param_name) |
| 100 | + |
| 101 | + # restore parameters that were present at load time |
| 102 | + for param_fqn, meta_tensor in metadata.items(): |
| 103 | + module_name, param_name = param_fqn.rsplit(".", 1) |
| 104 | + module = model.get_submodule(module_name) |
| 105 | + |
| 106 | + # for faster runtime, skip materialization if the tensors match |
| 107 | + original_tensor = getattr(module, param_name, None) |
| 108 | + if _tensors_alike(original_tensor, meta_tensor): |
| 109 | + continue |
| 110 | + |
| 111 | + param = _materialize_meta_tensor(meta_tensor) |
| 112 | + setattr(module, param) |
| 113 | + |
| 114 | + |
| 115 | +def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: |
| 116 | + new_tensor = tensor.to("meta") |
| 117 | + new_tensor.__class__ = tensor.__class__ |
| 118 | + new_tensor.__dict__ = deepcopy(tensor.__dict__) |
| 119 | + new_tensor._original_device = tensor.device |
| 120 | + return new_tensor |
| 121 | + |
| 122 | + |
| 123 | +def _tensors_alike(tensor: torch.Tensor | None, meta: torch.Tensor) -> bool: |
| 124 | + if tensor is None: |
| 125 | + return False |
| 126 | + |
| 127 | + return ( |
| 128 | + tensor.device == meta._original_device |
| 129 | + and tensor.dtype == meta.dtype |
| 130 | + and tensor.shape == meta.shape |
| 131 | + and tensor.__dict__ == meta.__dict__ |
161 | 132 | ) |
162 | | - # TODO: use create_weights to restore the weights to original state |
163 | | - |
164 | | - # Step R1: First restore the quantized weights to original bfloat16 |
165 | | - # weights, with original metadata (shape, dtype, device) |
166 | | - # and attributes, so that bfloat16 weights can be loaded properly |
167 | | - existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() |
168 | | - named_modules = dict(model.named_modules(remove_duplicate=False)) |
169 | | - model_device = None |
170 | | - |
171 | | - # Step R2: recover the parameter to the state before first loading |
172 | | - for name, d in model.original_weights_rebuild_keys.items(): |
173 | | - _shape = d["shape"] |
174 | | - _dtype = d["dtype"] |
175 | | - _device = d["device"] |
176 | | - if model_device is not None: |
177 | | - assert model_device == _device, ( |
178 | | - "Expecting all weights " |
179 | | - "to be in the same device for now, got both: " |
180 | | - f"{model_device} and {_device}" |
181 | | - ) |
182 | | - else: |
183 | | - model_device = _device |
184 | | - |
185 | | - if name in existing_param_names: |
186 | | - module_name, weight_name = name.rsplit(".", 1) |
187 | | - module = named_modules[module_name] |
188 | | - setattr( |
189 | | - module, |
190 | | - weight_name, |
191 | | - torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)), |
192 | | - ) |
193 | | - |
194 | | - # recorded_weight_attr is |
195 | | - # {"weight_name": {"weight_attr_key": attr}} |
196 | | - # e.g. |
197 | | - # { |
198 | | - # { |
199 | | - # "layer.0.weight": { |
200 | | - # "weight_loader": weight_loader_function_object, |
201 | | - # "input_dim": 0, ... |
202 | | - # }, |
203 | | - # "layer.1.weight": ..., |
204 | | - # } |
205 | | - # } |
206 | | - for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items(): |
207 | | - for attr_name, attr in weight_attr_dict.items(): |
208 | | - module_name, weight_name = full_weight_name.rsplit(".", 1) |
209 | | - module = named_modules[module_name] |
210 | | - weight = getattr(module, weight_name) |
211 | | - if not hasattr(weight, attr_name): |
212 | | - setattr(weight, attr_name, _bond_method_to_cls(attr, weight)) |
213 | | - |
214 | | - # Step I1: reload bfloat16 / high precision weights |
215 | | - loaded_weights = model.load_weights( |
216 | | - model_loader.get_all_weights(model_config, model) |
| 133 | + |
| 134 | + |
| 135 | + |
| 136 | +def _materialize_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: |
| 137 | + return torch.empty_strided( |
| 138 | + size=tuple(tensor.size()), |
| 139 | + stride=tuple(tensor.stride()), |
| 140 | + dtype=tensor.dtype, |
| 141 | + device=tensor._original_device, |
| 142 | + requires_grad=False, # set below to match input |
217 | 143 | ) |
218 | | - |
219 | | - # Step I2: online quantize the weights |
220 | | - # manually process weights after loading |
221 | | - model.process_weights_after_loading_already_called = False |
222 | | - process_weights_after_loading(model, model_config, model_device) |
223 | | - model.process_weights_after_loading_already_called = True |
224 | | - return loaded_weights |
0 commit comments