This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3ad161c

Merge branch 'main' into Tokenizer-New-Type-Onboarding
2 parents: a94afd9 + d59a88d

6 files changed: +260 −152 lines changed

install/.pins/et-pin.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-73740e9268a4a47baeaedc58a1f75597038d2377
+b173722085b3f555d6ba4533d6bbaddfd7c71144

torchchat/cli/builder.py

Lines changed: 51 additions & 38 deletions
@@ -17,13 +17,14 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import(
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -37,15 +38,6 @@
 from torchchat.utils.quantize import quantize_model
 
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -294,9 +290,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -404,6 +400,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
 
 def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
+        from torchtune.models.convert_weights import meta_to_tune
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
             str(builder_args.checkpoint_path), mmap=True, weights_only=True
@@ -456,9 +453,15 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         checkpoint = checkpoint["model"]
 
     if model.config.model_type == ModelType.Flamingo:
+        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+        from torchtune.models.llama3_2_vision._convert_weights import (
+            llama3_vision_meta_to_tune,
+        )
+        from torchtune.training import set_default_dtype
         # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size the memory, with redundancies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
@@ -494,6 +497,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -551,6 +555,7 @@ def _initialize_model(
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = torch._export.aot_load(
@@ -588,6 +593,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = aoti_compiled_model
@@ -639,12 +645,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend (dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels ans custom operators
@@ -662,7 +671,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -673,20 +684,16 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -715,7 +722,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -729,7 +742,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill=1024
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
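Note: beyond formatting cleanups (double quotes, wrapped calls, parenthesized multi-item "with" blocks), the builder.py changes above defer the torchtune imports from module scope into the functions that actually use them, so importing the builder no longer requires torchtune to be installed. A minimal sketch of that deferred-import pattern is below; the helper name load_tune_checkpoint is hypothetical and not part of this commit, while meta_to_tune and the torch.load arguments mirror the diff.

# Sketch only: keep an optional heavy dependency out of module import time by
# importing it inside the one code path that needs it.
import torch

def load_tune_checkpoint(path: str):
    # Deferred import: fails only if this function is called without torchtune installed.
    from torchtune.models.convert_weights import meta_to_tune

    # Load the Meta-format state dict lazily and memory-mapped, as in the diff.
    meta_state_dict = torch.load(path, mmap=True, weights_only=True)
    # Convert Meta-format weights to torchtune's naming convention.
    return meta_to_tune(meta_state_dict)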