diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index fc3d371139..4e44939e0c 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -242,18 +242,93 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):

 def _load_model(checkpoint_path, device, precision):
     checkpoint = torch.load(
-        str(checkpoint_path), mmap=True, weights_only=True, map_location="cpu"
+        str(checkpoint_path), mmap=True, weights_only=True, map_location=device
     )
     if "model" in checkpoint and "stories" in str(checkpoint_path):
         checkpoint = checkpoint["model"]
+
     with torch.device("meta"):
         model = Transformer.from_name(checkpoint_path)
+
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(device=device, dtype=precision)
     return model.eval()


+def _load_model_with_streaming_quantization(checkpoint_path, device, precision, quant_config):
+    """Load model with streaming quantization to avoid OOM.
+
+    Strategy:
+    1. Create model on meta device
+    2. Load and quantize one transformer layer at a time
+    3. Each layer: load weights to device -> quantize -> clear cache
+    """
+    from torchao.quantization import quantize_
+    import gc
+
+    # Load checkpoint with mmap for lazy loading
+    checkpoint = torch.load(
+        str(checkpoint_path), mmap=True, weights_only=True, map_location="cpu"
+    )
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
+
+    # Create model on meta device (no memory allocated)
+    with torch.device("meta"):
+        model = Transformer.from_name(checkpoint_path)
+
+    # First, load non-layer parameters (tok_embeddings, norm, output) directly to device
+    non_layer_keys = [k for k in checkpoint.keys() if not k.startswith("layers.")]
+    for key in non_layer_keys:
+        parts = key.split(".")
+        target = model
+        for part in parts[:-1]:
+            target = getattr(target, part)
+        param_value = checkpoint[key].to(dtype=precision, device=device)
+        setattr(target, parts[-1], torch.nn.Parameter(param_value, requires_grad=False))
+
+    # Get number of layers
+    n_layers = model.config.n_layer
+    print(f" Loading {n_layers} transformer layers with streaming quantization...")
+
+    # Process each transformer layer
+    for layer_idx in range(n_layers):
+        layer_prefix = f"layers.{layer_idx}."
+        layer = model.layers[layer_idx]
+
+        # Collect all keys for this layer
+        layer_keys = [k for k in checkpoint.keys() if k.startswith(layer_prefix)]
+
+        # Load all parameters for this layer onto the device
+        for key in layer_keys:
+            param_name = key[len(layer_prefix):]  # Remove the "layers.X." prefix
+            parts = param_name.split(".")
+            target = layer
+            for part in parts[:-1]:
+                target = getattr(target, part)
+
+            # Load to device with precision
+            param_value = checkpoint[key].to(dtype=precision, device=device)
+            setattr(target, parts[-1], torch.nn.Parameter(param_value, requires_grad=False))
+
+        # Quantize this layer on device
+        quantize_(layer, quant_config)
+
+        # Clear cache after each layer
+        if "xpu" in str(device):
+            torch.xpu.empty_cache()
+        elif "cuda" in str(device):
+            torch.cuda.empty_cache()
+        gc.collect()
+
+        if (layer_idx + 1) % 10 == 0:
+            print(f" Processed {layer_idx + 1}/{n_layers} layers")
+
+    print(f" All {n_layers} layers loaded and quantized")
+    return model.eval()
+
+
 B_INST, E_INST = "[INST]", "[/INST]"

@@ -308,7 +383,25 @@ def main(

     print("Loading model ...")
     t0 = time.time()
-    model = _load_model(checkpoint_path, device, precision)
+
+    # For int4wo quantization on XPU, load and quantize one transformer layer
+    # at a time during checkpoint loading to avoid OOM
+    _int4wo_done = False
+    if quantization and "int4wo" in quantization and "xpu" in device:
+        from torchao.quantization import Int4WeightOnlyConfig
+
+        use_hqq = "hqq" in quantization
+        group_size = int(quantization.split("-")[1])
+        assert group_size in [32, 64, 128, 256], (
+            f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
+        )
+        quant_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, int4_packing_format="plain_int32")
+
+        # Use streaming quantization: load and quantize one layer at a time
+        model = _load_model_with_streaming_quantization(checkpoint_path, device, precision, quant_config)
+        _int4wo_done = True
+    else:
+        model = _load_model(checkpoint_path, device, precision)

     device_sync(device=device)  # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
@@ -417,7 +510,8 @@ def ffn_or_attn_only(mod, fqn):
             )
         else:
             quantize_(model, Int8DynamicActivationInt8WeightConfig())
-        if "int4wo" in quantization:
+        if "int4wo" in quantization and not _int4wo_done:
+            # Skip if already quantized during the streaming load above
             use_hqq = False
             if "hqq" in quantization:
                 use_hqq = True
@@ -432,7 +526,7 @@ def ffn_or_attn_only(mod, fqn):
             )
             quantize_(
                 model,
-                Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1),
+                Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
             )
         elif "int4dq-" in quantization:
             from torchao.dtypes import CutlassInt4PackedLayout
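For context, a minimal usage sketch of the new streaming path, mirroring the call site this patch adds in `main()`. It assumes it runs inside `torchao/_models/llama/generate.py`, where `_load_model_with_streaming_quantization` is defined; the checkpoint path, dtype, and group size below are illustrative placeholders, while the `Int4WeightOnlyConfig(..., int4_packing_format="plain_int32")` construction is taken directly from the diff.

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig

# Illustrative values only; generate.py normally derives these from its CLI arguments.
checkpoint_path = "checkpoints/meta-llama/Llama-2-7b/model.pth"  # hypothetical path
device = "xpu"
precision = torch.bfloat16

# Same config the patch builds for an "int4wo-128" quantization string on XPU.
quant_config = Int4WeightOnlyConfig(
    group_size=128,
    use_hqq=False,
    int4_packing_format="plain_int32",
)

# Loads and quantizes one transformer layer at a time, so peak device memory
# stays close to one unquantized layer plus the already-quantized layers.
model = _load_model_with_streaming_quantization(
    checkpoint_path, device, precision, quant_config
)
```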