This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 63be93a

zhenyanzhang (zhenyan-zhang-meta) authored and committed
Simplify TokenizerArgs.__post_init__ with Enum Tokenizer Type
Summary: Simplify `TokenizerArgs.__post_init__` with an enum tokenizer type, since only one of the tokenizer types can be true. We want to touch as little code outside of `__post_init__` as possible at the moment.

Test Plan: python torchchat.py generate llama2|llama3|granite-code

Reviewers: @Jack-Khuu

Subscribers:

Issue: #1518
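For context, here is a minimal, self-contained sketch of the enum-based tokenizer-type pattern the summary describes (not the actual `torchchat/cli/builder.py` code; `TokenizerArgsSketch` is a hypothetical name used only for illustration). The idea is that a single `TokenizerType` field replaces three mutually exclusive boolean flags, so at most one tokenizer type can be "true" at a time:

```python
# Hypothetical sketch of the enum-based tokenizer type approach.
from dataclasses import dataclass
from enum import Enum


class TokenizerType(Enum):
    NONE = 0
    TIKTOKEN = 1
    SENTENCEPIECE = 2
    HF_TOKENIZER = 3


@dataclass
class TokenizerArgsSketch:
    tokenizer_path: str = ""
    tokenizer_type: TokenizerType = TokenizerType.NONE

    def is_tiktoken(self) -> bool:
        return self.tokenizer_type == TokenizerType.TIKTOKEN

    def is_sentencepiece(self) -> bool:
        return self.tokenizer_type == TokenizerType.SENTENCEPIECE

    def is_hf_tokenizer(self) -> bool:
        return self.tokenizer_type == TokenizerType.HF_TOKENIZER


if __name__ == "__main__":
    # Setting the field once keeps all is_* helpers consistent automatically.
    args = TokenizerArgsSketch(
        tokenizer_path="tokenizer.model",
        tokenizer_type=TokenizerType.TIKTOKEN,
    )
    assert args.is_tiktoken() and not args.is_sentencepiece()
```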
1 parent 98eaf8f commit 63be93a

File tree

1 file changed: +48 -49 lines changed

torchchat/cli/builder.py

Lines changed: 48 additions & 49 deletions
@@ -16,14 +16,14 @@
 import torch._dynamo.config
 import torch._inductor.config
 import torch.distributed as dist
+from torchchat.distributed.logging_utils import SingletonLogger

-from torchchat.distributed.utils import(
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger

 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -36,7 +36,6 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model

-
 from torchtune.models.convert_weights import meta_to_tune

 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -188,15 +187,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -238,11 +241,6 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args

-class TokenizerType(Enum):
-    NONE = 0
-    TIKTOKEN = 1
-    SENTENCEPIECE = 2
-    HF_TOKENIZER = 3

 @dataclass
 class TokenizerArgs:
@@ -278,19 +276,12 @@ def __post_init__(self):
         except:
             pass

-        self.tokenizer_type = TokenizerType.NONE
+        self.is_tiktoken = False
+        self.is_sentencepiece = False
+        self.is_hf_tokenizer = False
         self.t = None
         return

-    def is_tiktoken(self) -> bool:
-        return self.tokenizer_type == TokenizerType.TIKTOKEN
-
-    def is_sentencepiece(self) -> bool:
-        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
-
-    def is_hf_tokenizer(self) -> bool:
-        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
-
     def validate_model(
         self,
         model: Optional[Model],
@@ -299,22 +290,20 @@ def validate_model(
         if model is None:
             return

-
-        is_tiktoken = self.is_tiktoken()
-        is_sentencepiece = self.is_sentencepiece()
-        is_hf_tokenizer = self.is_hf_tokenizer()
-
-        if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1:
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

+        is_tiktoken = self.is_tiktoken
+        is_sentencepiece = self.is_sentencepiece
+        is_hf_tokenizer = self.is_hf_tokenizer
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -512,6 +501,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -569,6 +559,7 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing

         model.forward = torch._export.aot_load(
@@ -606,6 +597,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing

         model.forward = aoti_compiled_model
@@ -657,12 +649,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend (dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels ans custom operators
@@ -680,7 +675,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()

     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )

     # Model-level config
     if builder_args.params_table:
@@ -691,20 +688,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")

-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights

     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0

     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -733,7 +726,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -747,7 +746,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill=1024
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
