This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
10 changes: 10 additions & 0 deletions install/install_requirements.sh
@@ -71,6 +71,9 @@ then
elif [[ -x "$(command -v xpu-smi)" ]];
then
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
elif [[ -x "$(command -v npu-smi)" ]]
then
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/test/cpu"
Contributor
Why are we using a test wheel?

Hey @Jack-Khuu, thanks for your review. IMO, we should use nightly PyTorch wheels like the other backends. What do you think? @hipudding @xuedinge233

But we usually use PyTorch RC versions; I'm not sure everything works fine on nightly.

Contributor Author
torch_npu is currently released against the test (RC) versions of the PyTorch wheels, but after testing, the nightly wheels also work, so the install script has been changed to use the nightly version.

The latest torch_npu release is 2.7.0, so torch is pinned to version 2.7.0.

else
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
fi
@@ -83,6 +86,13 @@ then
torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
#torchtune=="0.7.0" # no 0.6.0 on xpu nightly
)
elif [[ -x "$(command -v npu-smi)" ]];
then
REQUIREMENTS_TO_INSTALL=(
torch=="2.7.0"
torchvision=="0.22.0"
torchtune=="0.6.0"
)
else
REQUIREMENTS_TO_INSTALL=(
torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
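For anyone reproducing this locally, a minimal sanity check of the pinned NPU environment might look like the sketch below. It assumes a torch_npu build matching torch 2.7.0 has been installed separately (the install script above does not install it).

import torch
import torch_npu  # installed separately, e.g. via pip; importing it registers the "npu" backend on torch

# Verify the torch version pinned by the NPU branch of the install script.
assert torch.__version__.startswith("2.7.0"), torch.__version__

# torch.npu only exists after torch_npu has been imported.
print("NPU available:", torch.npu.is_available())
if torch.npu.is_available():
    print("NPU device:", torch.npu.get_device_name(0))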
8 changes: 5 additions & 3 deletions torchchat/cli/builder.py
@@ -29,7 +29,7 @@
from torchchat.utils.build_utils import (
device_sync,
is_cpu_device,
is_cuda_or_cpu_or_xpu_device,
is_supported_device,
name_to_dtype,
)
from torchchat.utils.measure_time import measure_time
@@ -78,6 +78,8 @@ def __post_init__(self):
self.device = "cuda"
elif torch.xpu.is_available():
self.device = "xpu"
elif hasattr(torch, "npu") and torch.npu.is_available():
self.device = "npu"
else:
self.device = "cpu"

@@ -539,7 +541,7 @@ def _initialize_model(
_set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

if builder_args.dso_path:
if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
if not is_supported_device(builder_args.device):
print(
f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
)
@@ -573,7 +575,7 @@ def do_nothing(max_batch_size, max_seq_length):
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")

elif builder_args.aoti_package_path:
if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
if not is_supported_device(builder_args.device):
print(
f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
)
4 changes: 2 additions & 2 deletions torchchat/cli/cli.py
@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
"--device",
type=str,
default=None,
choices=["fast", "cpu", "cuda", "mps", "xpu"],
help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu, npu",
)
model_config_parser.add_argument(
"--attention-backend",
13 changes: 9 additions & 4 deletions torchchat/generate.py
@@ -1213,8 +1213,10 @@ def callback(x, *, done_generating=False):
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
elif self.builder_args.device == "cuda":
print(prof.key_averages().table(sort_by="self_cuda_time_total"))
else:
elif self.builder_args.device == "xpu":
print(prof.key_averages().table(sort_by="self_xpu_time_total"))
elif self.builder_args.device == "npu":
print(prof.key_averages().table(sort_by="self_npu_time_total"))
prof.export_chrome_trace(f"{self.profile}.json")

if start_pos >= max_seq_length:
@@ -1299,8 +1301,10 @@ def callback(x, *, done_generating=False):
)
if torch.cuda.is_available():
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
if torch.xpu.is_available():
elif torch.xpu.is_available():
print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
elif hasattr(torch, "npu") and torch.npu.is_available():
print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")



@@ -1595,7 +1599,6 @@ def sample(

return idx_next, probs


def run_generator(
args,
rank: Optional[int] =None
@@ -1628,8 +1631,10 @@ def run_generator(
)
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
if torch.xpu.is_available():
elif torch.xpu.is_available():
torch.xpu.reset_peak_memory_stats()
elif hasattr(torch, "npu") and torch.npu.is_available():
torch.npu.reset_peak_memory_stats()

for _ in gen.chat(generator_args):
pass
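The hasattr(torch, "npu") guard used in the hunks above is needed because the torch.npu namespace only appears once torch_npu has been imported. A small sketch of the same dispatch pattern, factored into helpers (illustrative only, not part of this PR):

import torch

def reset_peak_memory_stats() -> None:
    # Mirrors the per-backend dispatch in run_generator above.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    elif torch.xpu.is_available():
        torch.xpu.reset_peak_memory_stats()
    elif hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.reset_peak_memory_stats()

def max_memory_reserved_gb() -> float:
    # Mirrors the memory report in generate.py above; returns 0.0 on CPU/MPS.
    if torch.cuda.is_available():
        return torch.cuda.max_memory_reserved() / 1e9
    if torch.xpu.is_available():
        return torch.xpu.max_memory_reserved() / 1e9
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.npu.max_memory_reserved() / 1e9
    return 0.0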
33 changes: 19 additions & 14 deletions torchchat/utils/build_utils.py
@@ -233,6 +233,8 @@ def device_sync(device="cpu"):
torch.cuda.synchronize(device)
elif "xpu" in device:
torch.xpu.synchronize(device)
elif "npu" in device:
torch.npu.synchronize(device)
elif ("cpu" in device) or ("mps" in device):
pass
else:
@@ -275,33 +277,36 @@ def is_mps_available() -> bool:
# MPS, is that you?
return True

def select_device(device) -> str:
if torch.cuda.is_available():
return "cuda"
elif is_mps_available():
return "mps"
elif hasattr(torch, "npu") and torch.npu.is_available():
return "npu"
elif torch.xpu.is_available():
return "xpu"
else:
return "cpu"


def get_device_str(device) -> str:
if isinstance(device, str) and device == "fast":
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if is_mps_available()
else "xpu" if torch.xpu.is_available() else "cpu"
)
device = select_device(device)
return device
else:
return str(device)


def get_device(device) -> str:
if isinstance(device, str) and device == "fast":
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if is_mps_available()
else "xpu" if torch.xpu.is_available() else "cpu"
)
device = select_device(device)
return torch.device(device)


def is_cpu_device(device) -> bool:
return device == "" or str(device) == "cpu"

def is_cuda_or_cpu_or_xpu_device(device) -> bool:
return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
def is_supported_device(device) -> bool:
device_str = str(device)
return is_cpu_device(device) or any(dev in device_str for dev in ('cuda', 'xpu', 'npu'))
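A short usage sketch of the refactored helpers (hypothetical session; results depend on the host and on which backends are installed):

from torchchat.utils.build_utils import get_device_str, is_supported_device

# On an Ascend host with torch_npu installed, "fast" resolves to "npu"
# (select_device checks cuda, then mps, then npu, then xpu, then cpu).
print(get_device_str("fast"))

# is_supported_device is a substring match over cpu/cuda/xpu/npu device strings.
print(is_supported_device("npu:0"))  # True
print(is_supported_device("mps"))    # False, so AOTI artifacts fall back to CPU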
5 changes: 3 additions & 2 deletions torchchat/utils/device_info.py
@@ -9,12 +9,11 @@

import torch


def get_device_info(device: str) -> str:
"""Returns a human-readable description of the hardware based on a torch.device.type

Args:
device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
device: A torch.device.type string: one of {"cpu", "cuda", "xpu", "npu"}.
Returns:
str: A human-readable description of the hardware or an empty string if the device type is unhandled.

@@ -46,4 +45,6 @@ def get_device_info(device: str) -> str:
.split("\n")[0]
.split("Device Name:")[1]
)
if device == "npu":
return torch.npu.get_device_name(0)
return ""
2 changes: 1 addition & 1 deletion torchchat/utils/quantize.py
@@ -123,7 +123,7 @@ def quantize_model(
raise RuntimeError(f"unknown quantizer {quantizer} specified")
else:
# Use tensor subclass API for int4 weight only.
if (device == "cuda" or device == "xpu") and quantizer == "linear:int4":
if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":
quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
if not support_tensor_subclass:
unwrap_tensor_subclass(model)