102 changes: 98 additions & 4 deletions torchao/_models/llama/generate.py
@@ -242,18 +242,93 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):

def _load_model(checkpoint_path, device, precision):
    checkpoint = torch.load(
-        str(checkpoint_path), mmap=True, weights_only=True, map_location="cpu"
+        str(checkpoint_path), mmap=True, weights_only=True, map_location=device
    )
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]

    with torch.device("meta"):
        model = Transformer.from_name(checkpoint_path)

    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)

    return model.eval()


+def _load_model_with_streaming_quantization(checkpoint_path, device, precision, quant_config):
+    """Load the model with streaming quantization to avoid OOM.
+
+    Strategy:
+    1. Create the model on the meta device (no memory allocated).
+    2. Load and quantize one transformer layer at a time.
+    3. Each layer: load weights to device -> quantize -> clear cache.
+    """
+    import gc
+
+    from torchao.quantization import quantize_
+
+    # Load checkpoint with mmap for lazy loading
+    checkpoint = torch.load(
+        str(checkpoint_path), mmap=True, weights_only=True, map_location="cpu"
+    )
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
+
+    # Create the model on the meta device (no memory allocated)
+    with torch.device("meta"):
+        model = Transformer.from_name(checkpoint_path)
+
+    # Load non-layer parameters (tok_embeddings, norm, output) directly to device
+    non_layer_keys = [k for k in checkpoint.keys() if not k.startswith("layers.")]
+    for key in non_layer_keys:
+        parts = key.split(".")
+        target = model
+        for part in parts[:-1]:
+            target = getattr(target, part)
+        param_value = checkpoint[key].to(dtype=precision, device=device)
+        setattr(target, parts[-1], torch.nn.Parameter(param_value, requires_grad=False))
+
+    n_layers = model.config.n_layer
+    print(f" Loading {n_layers} transformer layers with streaming quantization...")
+
+    # Process each transformer layer
+    for layer_idx in range(n_layers):
+        layer_prefix = f"layers.{layer_idx}."
+        layer = model.layers[layer_idx]
+
+        # Collect all checkpoint keys belonging to this layer
+        layer_keys = [k for k in checkpoint.keys() if k.startswith(layer_prefix)]
+
+        # Materialize each parameter of this layer directly on the target device
+        for key in layer_keys:
+            param_name = key[len(layer_prefix):]  # Remove the "layers.X." prefix
+            parts = param_name.split(".")
+            target = layer
+            for part in parts[:-1]:
+                target = getattr(target, part)
+
+            param_value = checkpoint[key].to(dtype=precision, device=device)
+            setattr(target, parts[-1], torch.nn.Parameter(param_value, requires_grad=False))
+
+        # Quantize this layer on device
+        quantize_(layer, quant_config)
+
+        # Clear the allocator cache after each layer
+        if "xpu" in str(device):
+            torch.xpu.empty_cache()
+        elif "cuda" in str(device):
+            torch.cuda.empty_cache()
+        gc.collect()
+
+        if (layer_idx + 1) % 10 == 0:
+            print(f" Processed {layer_idx + 1}/{n_layers} layers")
+
+    print(f" All {n_layers} layers loaded and quantized")
+    return model.eval()


B_INST, E_INST = "[INST]", "[/INST]"


@@ -308,7 +383,25 @@ def main(

print("Loading model ...")
t0 = time.time()
model = _load_model(checkpoint_path, device, precision)

# For int4wo quantization on XPU, use HuggingFace's TorchAoConfig approach
# This quantizes layer-by-layer during loading to avoid OOM
_int4wo_done = False
if quantization and "int4wo" in quantization and "xpu" in device:
from torchao.quantization import Int4WeightOnlyConfig

use_hqq = "hqq" in quantization
group_size = int(quantization.split("-")[1])
assert group_size in [32, 64, 128, 256], (
f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
)
quant_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, int4_packing_format="plain_int32")

# Use streaming quantization: load and quantize one layer at a time
model = _load_model_with_streaming_quantization(checkpoint_path, device, precision, quant_config)
_int4wo_done = True
else:
model = _load_model(checkpoint_path, device, precision)

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")
@@ -417,7 +510,8 @@ def ffn_or_attn_only(mod, fqn):
                )
            else:
                quantize_(model, Int8DynamicActivationInt8WeightConfig())
-        if "int4wo" in quantization:
+        # Skip if int4wo was already applied via streaming quantization at load time
+        if "int4wo" in quantization and not _int4wo_done:
            use_hqq = False
            if "hqq" in quantization:
                use_hqq = True
@@ -432,7 +526,7 @@ def ffn_or_attn_only(mod, fqn):
            )
            quantize_(
                model,
-                Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1),
+                Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq),
            )
        elif "int4dq-" in quantization:
            from torchao.dtypes import CutlassInt4PackedLayout