7 changes: 7 additions & 0 deletions install/install_requirements.sh
@@ -83,6 +83,13 @@ then
         torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
         #torchtune=="0.7.0" # no 0.6.0 on xpu nightly
     )
+elif [[ -x "$(command -v npu-smi)" ]];
+then
+    REQUIREMENTS_TO_INSTALL=(
+        torch=="2.7.0.dev20250310+cpu"
+        torchvision=="0.22.0.dev20250310"
+        torchtune=="0.6.0"
+    )
 else
     REQUIREMENTS_TO_INSTALL=(
         torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
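Side note: npu-smi is the device-management CLI that ships with Ascend NPU drivers, so probing for it with command -v mirrors how the script already detects other accelerators. A rough Python equivalent of the probe, for illustration only:

import shutil

# Roughly what `[[ -x "$(command -v npu-smi)" ]]` asks: is the Ascend
# device-management CLI on PATH (and hence is the NPU toolchain installed)?
has_ascend_npu = shutil.which("npu-smi") is not None
print("Ascend NPU toolchain detected" if has_ascend_npu else "no npu-smi on PATH")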
12 changes: 5 additions & 7 deletions torchchat/cli/builder.py
@@ -29,7 +29,7 @@
 from torchchat.utils.build_utils import (
     device_sync,
     is_cpu_device,
-    is_cuda_or_cpu_or_xpu_device,
+    is_supported_device,
     name_to_dtype,
 )
 from torchchat.utils.measure_time import measure_time
@@ -74,10 +74,8 @@ class BuilderArgs:

     def __post_init__(self):
         if self.device is None:
-            if torch.cuda.is_available():
-                self.device = "cuda"
-            elif torch.xpu.is_available():
-                self.device = "xpu"
+            if torch.accelerator.is_available():
+                self.device = torch.accelerator.current_accelerator().type
             else:
                 self.device = "cpu"
@@ -539,7 +537,7 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

     if builder_args.dso_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
             )
@@ -573,7 +571,7 @@ def do_nothing(max_batch_size, max_seq_length):
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")

     elif builder_args.aoti_package_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
             )
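For reference, a minimal sketch of the torch.accelerator selection logic the __post_init__ hunk adopts; it assumes PyTorch 2.6+, where the accelerator API is available, and pick_device is an illustrative name, not code from this PR:

import torch

def pick_device() -> str:
    # torch.accelerator abstracts over cuda/xpu/npu/mps and reports whichever
    # backend the current build detects; fall back to CPU when none is present.
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type
    return "cpu"

print(pick_device())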
4 changes: 2 additions & 2 deletions torchchat/cli/cli.py
@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps", "xpu"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
+        choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu, npu",
     )
     model_config_parser.add_argument(
         "--attention-backend",
13 changes: 9 additions & 4 deletions torchchat/generate.py
@@ -1213,8 +1213,10 @@ def callback(x, *, done_generating=False):
                 print(prof.key_averages().table(sort_by="self_cpu_time_total"))
             elif self.builder_args.device == "cuda":
                 print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-            else:
+            elif self.builder_args.device == "xpu":
                 print(prof.key_averages().table(sort_by="self_xpu_time_total"))
+            elif self.builder_args.device == "npu":
+                print(prof.key_averages().table(sort_by="self_npu_time_total"))
             prof.export_chrome_trace(f"{self.profile}.json")

         if start_pos >= max_seq_length:
@@ -1299,8 +1301,10 @@ def callback(x, *, done_generating=False):
             )
             if torch.cuda.is_available():
                 print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-            if torch.xpu.is_available():
+            elif torch.xpu.is_available():
                 print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+            elif hasattr(torch, "npu") and torch.npu.is_available():
+                print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")



@@ -1595,7 +1599,6 @@ def sample(

     return idx_next, probs

-
 def run_generator(
     args,
     rank: Optional[int] =None
@@ -1628,8 +1631,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()

     for _ in gen.chat(generator_args):
         pass
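A note on the hasattr(torch, "npu") guard used in these hunks: torch.npu is attached by the out-of-tree torch_npu package, so checking for the attribute keeps the code importable on machines without Ascend support. A self-contained sketch of the same dispatch pattern (the function name here is illustrative):

import torch

def reset_peak_memory_stats(device: str) -> None:
    # Dispatch to the matching backend; torch.npu only exists once the
    # separate torch_npu package has been imported.
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    elif device == "xpu" and torch.xpu.is_available():
        torch.xpu.reset_peak_memory_stats()
    elif device == "npu" and hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.reset_peak_memory_stats()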
29 changes: 15 additions & 14 deletions torchchat/utils/build_utils.py
@@ -233,6 +233,8 @@ def device_sync(device="cpu"):
         torch.cuda.synchronize(device)
     elif "xpu" in device:
         torch.xpu.synchronize(device)
+    elif "npu" in device:
+        torch.npu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -275,33 +277,32 @@ def is_mps_available() -> bool:
     # MPS, is that you?
     return True

+def select_device() -> str:
+    if torch.accelerator.is_available():
+        device = torch.accelerator.current_accelerator().type
+        if device == "mps" and not is_mps_available():
+            return "cpu"
+        return device
+    else:
+        return "cpu"
+
 def get_device_str(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
         return device
     else:
         return str(device)


 def get_device(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
     return torch.device(device)


 def is_cpu_device(device) -> bool:
     return device == "" or str(device) == "cpu"

-def is_cuda_or_cpu_or_xpu_device(device) -> bool:
-    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
+def is_supported_device(device) -> bool:
+    device_str = str(device)
+    return is_cpu_device(device) or any(dev in device_str for dev in ('cuda', 'xpu', 'npu'))
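Taken together, the refactored helpers compose like this; an illustrative usage sketch under the assumption that torchchat.utils.build_utils is importable, not code from this PR:

from torchchat.utils.build_utils import get_device, is_supported_device

device = get_device("fast")  # resolves via select_device(): accelerator type, else "cpu"
print(device)

# AOTI artifacts (DSO / PT2 packages) may only be loaded onto these devices:
for d in ("cpu", "cuda:0", "xpu", "npu", "mps"):
    print(d, is_supported_device(d))  # mps -> False, the rest -> True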
5 changes: 3 additions & 2 deletions torchchat/utils/device_info.py
@@ -9,12 +9,11 @@

 import torch

-
 def get_device_info(device: str) -> str:
     """Returns a human-readable description of the hardware based on a torch.device.type

     Args:
-        device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
+        device: A torch.device.type string: one of {"cpu", "cuda", "xpu", "npu"}.
     Returns:
         str: A human-readable description of the hardware or an empty string if the device type is unhandled.

@@ -46,4 +45,6 @@ def get_device_info(device: str) -> str:
             .split("\n")[0]
             .split("Device Name:")[1]
         )
+    if device == "npu":
+        return torch.npu.get_device_name(0)
     return ""
2 changes: 1 addition & 1 deletion torchchat/utils/quantize.py
@@ -123,7 +123,7 @@ def quantize_model(
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
             # Use tensor subclass API for int4 weight only.
-            if (device == "cuda" or device == "xpu") and quantizer == "linear:int4":
+            if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
unwrap_tensor_subclass(model)
Expand Down