
Commit e379c89

fix fla crash on plugin
Signed-off-by: Hank <[email protected]>
Parent: 1bf43ae

File tree

1 file changed: +3 −2
  • vllm/model_executor/layers/fla/ops/utils.py

vllm/model_executor/layers/fla/ops/utils.py

Lines changed: 3 additions & 2 deletions
@@ -17,6 +17,7 @@
 
 import torch
 
+from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 
 logger = logging.getLogger(__name__)
@@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
 # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
 # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
 # Therefore, we need to check the triton backend to determine the actual GPU vendor.
-device = get_available_device() if get_available_device() != "hip" else "cuda"
-device_torch_lib = getattr(torch, device)
+device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
+device_torch_lib = getattr(torch, device, None)
 device_platform = _check_platform()
 
 is_amd = device_platform == "amd"
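Why this fixes the crash: the old code only special-cased the "hip" backend, so any other device name reported by an out-of-tree platform plugin was passed straight to getattr(torch, ...), which raises AttributeError when torch has no attribute of that name. The new code maps CUDA-alike platforms to the "cuda" name via current_platform.is_cuda_alike() (per the comments above, torch exposes both NVIDIA and AMD GPUs as torch.cuda), and the getattr default of None turns any remaining unknown name into a soft miss instead of a crash. Below is a minimal, self-contained sketch of that failure mode, not vLLM's actual code: get_available_device is stubbed out here and "my_accel" is a hypothetical plugin device name.

import torch

def get_available_device() -> str:
    # Hypothetical stub: an out-of-tree platform plugin may report a
    # backend name that is not an attribute of the torch module.
    return "my_accel"

# Old lookup: only "hip" was remapped, so "my_accel" reached getattr directly.
device = get_available_device() if get_available_device() != "hip" else "cuda"
try:
    lib = getattr(torch, device)  # raises AttributeError for "my_accel"
except AttributeError as exc:
    print(exc)  # module 'torch' has no attribute 'my_accel'

# New lookup: the default argument makes the miss non-fatal.
lib = getattr(torch, device, None)
print(lib)  # None instead of a crash

The implicit trade-off is that code reading device_torch_lib must now tolerate a None value on platforms without a matching torch module.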
