This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit b896262

Rollback to 98eaf8f
1 parent 379c07b commit b896262

File tree

1 file changed: +33 -50 lines changed


torchchat/cli/builder.py

Lines changed: 33 additions & 50 deletions
@@ -16,14 +16,14 @@
 import torch._dynamo.config
 import torch._inductor.config
 import torch.distributed as dist
-from torchchat.distributed.logging_utils import SingletonLogger
 
-from torchchat.distributed.utils import (
+from torchchat.distributed.utils import(
     Color as color,
     CUDATrackTime,
-    GPUMemoryMonitor,
     init_distributed,
+    GPUMemoryMonitor,
 )
+from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -36,6 +36,7 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+
 from torchtune.models.convert_weights import meta_to_tune
 
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -187,19 +188,15 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            "math": torch.nn.attention.SDPBackend.MATH,
-            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (
-            args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"
-        ):
-            print(
-                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
-            )
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
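
For reference, here is a minimal standalone sketch (not part of this commit) of what the from_args() logic above does with the --attention-backend flag: it maps the string to a torch.nn.attention.SDPBackend member and falls back to MATH on CPU for the two CUDA-only backends. The helper name resolve_attention_backend is hypothetical; the sketch assumes a PyTorch build where torch.nn.attention.SDPBackend exposes these members.

    import torch
    import torch.nn.attention  # ensure the submodule is loaded

    def resolve_attention_backend(name: str, device: str):
        # Hypothetical helper mirroring the mapping restored above.
        sdp_backend_dict = {
            "math": torch.nn.attention.SDPBackend.MATH,
            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
        }
        backend = sdp_backend_dict[name]
        # efficient_attention and cudnn_attention are CUDA-only, so fall back on CPU.
        if device == "cpu" and name in ("efficient_attention", "cudnn_attention"):
            print(f"Warning: {name} is not supported on CPU. Using math instead.")
            backend = torch.nn.attention.SDPBackend.MATH
        return backend

    # e.g. resolve_attention_backend("cudnn_attention", "cpu") returns SDPBackend.MATH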
@@ -281,6 +278,8 @@ def __post_init__(self):
         except:
             pass
 
+        self.tokenizer_type = TokenizerType.NONE
+        self.t = None
         return
 
     def is_tiktoken(self) -> bool:
@@ -292,13 +291,6 @@ def is_sentencepiece(self) -> bool:
     def is_hf_tokenizer(self) -> bool:
         return self.tokenizer_type == TokenizerType.HF_TOKENIZER
 
-    def is_tokenizer_none(self) -> bool:
-        if self.tokenizer_type != TokenizerType.NONE:
-            return False
-
-        assert self.t is None, "tokenizer_type is NONE but t is not None"
-        return True
-
     def validate_model(
         self,
         model: Optional[Model],
@@ -307,21 +299,22 @@ def validate_model(
         if model is None:
             return
 
-        if self.is_tokenizer_none():
-            raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken()
         is_sentencepiece = self.is_sentencepiece()
         is_hf_tokenizer = self.is_hf_tokenizer()
 
+        if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1:
+            raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
+
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken)
-            or (is_hf_tokenizer and not use_hf_tokenizer)
-            or (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
+            (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -519,7 +512,6 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
-
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -577,7 +569,6 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -615,7 +606,6 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
@@ -667,15 +657,12 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(
-                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
-            )
+            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-
-        set_backend(dso=True, pte=False, aoti_package=False)
-        if model.config != config:
+        set_backend (dso=True, pte=False, aoti_package=False)
+        if (model.config != config):
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels ans custom operators
@@ -693,9 +680,7 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(
-        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
-    )
+    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
     # Model-level config
     if builder_args.params_table:
@@ -706,16 +691,20 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    # TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import load_model_weights
+    #TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import (
+        load_model_weights,
+    )
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
-    )
+        "cuda",
+        (pp_degree, tp_degree),
+        mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -744,13 +733,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(
-            model,
-            builder_args.distribution_path,
-            device,
-            config,
-            builder_args.chpt_from,
-        )
+        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -764,7 +747,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024
+    seqlen_prefill=1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
0 commit comments