diff --git a/install/install_torch.sh b/install/install_torch.sh
index 26b8a830e..cc7bccaf0 100755
--- a/install/install_torch.sh
+++ b/install/install_torch.sh
@@ -66,6 +66,13 @@ then
     torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
     #torchtune=="0.7.0" # no 0.6.0 on xpu nightly
   )
+elif [[ -x "$(command -v npu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.7.0.dev20250310+cpu"
+    torchvision=="0.22.0.dev20250310"
+    torchtune=="0.6.0"
+  )
 else
   REQUIREMENTS_TO_INSTALL=(
     torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index fcc2d5f66..12463dd78 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -29,7 +29,7 @@ from torchchat.utils.build_utils import (
     device_sync,
     is_cpu_device,
-    is_cuda_or_cpu_or_xpu_device,
+    is_supported_device,
     name_to_dtype,
 )
 from torchchat.utils.measure_time import measure_time
 
@@ -74,10 +74,8 @@ class BuilderArgs:
 
     def __post_init__(self):
         if self.device is None:
-            if torch.cuda.is_available():
-                self.device = "cuda"
-            elif torch.xpu.is_available():
-                self.device = "xpu"
+            if torch.accelerator.is_available():
+                self.device = torch.accelerator.current_accelerator().type
             else:
                 self.device = "cpu"
 
@@ -539,7 +537,7 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
 
     if builder_args.dso_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
             )
@@ -573,7 +571,7 @@ def do_nothing(max_batch_size, max_seq_length):
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
 
     elif builder_args.aoti_package_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
             )
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index 2198ac819..007fe0c8f 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps", "xpu"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
+        choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu, npu",
     )
     model_config_parser.add_argument(
         "--attention-backend",
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 53e855483..8555b85bd 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -1213,8 +1213,10 @@ def callback(x, *, done_generating=False):
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                 elif self.builder_args.device == "cuda":
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                else:
+                elif self.builder_args.device == "xpu":
                     print(prof.key_averages().table(sort_by="self_xpu_time_total"))
+                elif self.builder_args.device == "npu":
+                    print(prof.key_averages().table(sort_by="self_npu_time_total"))
                 prof.export_chrome_trace(f"{self.profile}.json")
 
             if start_pos >= max_seq_length:
@@ -1299,8 +1301,10 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-        if torch.xpu.is_available():
+        elif torch.xpu.is_available():
             print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")
 
 
@@ -1595,7 +1599,6 @@ def sample(
     return idx_next, probs
 
 
-
 def run_generator(
     args,
     rank: Optional[int] =None
@@ -1628,8 +1631,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py
index b9c32a7fe..14a955bbf 100644
--- a/torchchat/utils/build_utils.py
+++ b/torchchat/utils/build_utils.py
@@ -233,6 +233,8 @@ def device_sync(device="cpu"):
         torch.cuda.synchronize(device)
     elif "xpu" in device:
         torch.xpu.synchronize(device)
+    elif "npu" in device:
+        torch.npu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -275,15 +277,18 @@ def is_mps_available() -> bool:
     # MPS, is that you?
     return True
 
+
+def select_device() -> str:
+    if torch.accelerator.is_available():
+        device = torch.accelerator.current_accelerator().type
+        if device == "mps" and not is_mps_available():
+            return "cpu"
+        return device
+    else:
+        return "cpu"
 
 def get_device_str(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
         return device
     else:
         return str(device)
@@ -291,17 +296,13 @@ def get_device_str(device) -> str:
 
 def get_device(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
     return torch.device(device)
 
 
 def is_cpu_device(device) -> bool:
     return device == "" or str(device) == "cpu"
 
-def is_cuda_or_cpu_or_xpu_device(device) -> bool:
-    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
+def is_supported_device(device) -> bool:
+    device_str = str(device)
+    return is_cpu_device(device) or any(dev in device_str for dev in ('cuda', 'xpu', 'npu'))
diff --git a/torchchat/utils/device_info.py b/torchchat/utils/device_info.py
index 950c03002..8acc8e14e 100644
--- a/torchchat/utils/device_info.py
+++ b/torchchat/utils/device_info.py
@@ -9,12 +9,11 @@
 
 import torch
 
-
 def get_device_info(device: str) -> str:
     """Returns a human-readable description of the hardware based on a torch.device.type
 
     Args:
-        device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
+        device: A torch.device.type string: one of {"cpu", "cuda", "xpu", "npu"}.
     Returns:
         str: A human-readable description of the hardware or an empty string if
         the device type is unhandled.
@@ -46,4 +45,6 @@ def get_device_info(device: str) -> str:
             .split("\n")[0]
             .split("Device Name:")[1]
         )
+    if device == "npu":
+        return torch.npu.get_device_name(0)
     return ""
diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index 6246f1c05..aca81ef5c 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -123,7 +123,7 @@ def quantize_model(
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
             # Use tensor subclass API for int4 weight only.
-            if (device == "cuda" or device == "xpu") and quantizer == "linear:int4":
+            if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
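A minimal sketch (not part of the patch) of how the reworked device helpers behave once this diff is applied. It assumes a PyTorch build recent enough to ship torch.accelerator (2.6+) and, for the npu branch, Huawei's out-of-tree torch_npu backend; the printed values depend on the machine it runs on.

    import torch
    from torchchat.utils.build_utils import get_device_str, is_supported_device

    # "fast" now resolves through select_device(): the current accelerator's
    # type when one is available ("cuda", "xpu", "npu", or "mps" when usable),
    # falling back to "cpu" otherwise.
    print(get_device_str("fast"))

    # is_supported_device() replaces is_cuda_or_cpu_or_xpu_device(); its
    # substring check also accepts indexed forms such as "cuda:0" or "npu:1",
    # while "mps" still fails the check (AOTI artifacts cannot be loaded there).
    for dev in ("cpu", "cuda:0", "xpu", "npu:1", "mps"):
        print(dev, is_supported_device(dev))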