2 changes: 2 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -75,6 +75,8 @@ def __init__(self, vllm_config):
"recompute_scheduler_enable", False)
self.lmhead_tensor_parallel_size = additional_config.get(
"lmhead_tensor_parallel_size", None)
self.fault_tolerance_level = additional_config.get(
"fault_tolerance_lecel",0)
if self.lmhead_tensor_parallel_size is not None:
logger.info(
f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
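For context, a minimal sketch of how this new option could be supplied through additional_config when constructing an engine. The model name and the exact call site below are placeholders, not part of this diff; only the key name comes from the change above, and the level meanings are documented in vllm_ascend/worker/common.py below.

from vllm import LLM

# Hedged usage sketch: enable basic fault tolerance (level 1) via additional_config.
# 0 disables fault tolerance, 1 enables the basic level, 2 enables the full level.
llm = LLM(
    model="some/model",  # placeholder model name
    additional_config={"fault_tolerance_level": 1},
)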
48 changes: 48 additions & 0 deletions vllm_ascend/worker/common.py
@@ -0,0 +1,48 @@
import torch
from enum import Enum

class FaultToleranceLevel(Enum):
    """
    Fault tolerance level
    level 0: disable fault tolerance
    level 1: enable basic fault tolerance for weight UCE / activation UCE / network errors
    level 2: enable all fault tolerance for weight UCE / activation UCE / KV cache UCE / network errors
    """
    OFF = 0
    BASIC = 1
    FULL = 2

class FaultStatus(Enum):
    """
    Fault status which fault_tolerance puts into fault_queue
    """
    ACTIVE = torch.tensor([0])
    UCE_ERR = torch.tensor([1])
    FORCE_STOP = torch.tensor([2])
    NETWORK_ERR = torch.tensor([3])

class FaultCommand:
    """
    Fault command which rank 0 broadcasts in fault_aware
    """
    INIT_CMD = torch.tensor([0])
    SILENCE_CMD = torch.tensor([1])
    STOP_DEVICE_CMD = torch.tensor([2])

class UCEType(Enum):
    """
    Specific UCE type for HBM UCE
    """
    WEIGHTS_UCE = "WEIGHTS_UCE"
    KVCACHE_UCE = "KVCACHE_UCE"
    ACTIVATION_UCE = "ACTIVATION_UCE"
    UNKNOWN_UCE = "UNKNOWN_UCE"

class RecoveryStatus:
    SUCCESS = torch.tensor([0])
    FAILED = torch.tensor([1])

class FaultAction:
    RAISE_EXCEPTION = torch.tensor([0])
    RETURN = torch.tensor([1])
    RECOMPUTE = torch.tensor([2])
Comment on lines +15 to +48
Contributor
high

The classes FaultStatus, FaultCommand, RecoveryStatus, and FaultAction use torch.tensor objects as values for constants. This is unconventional and can lead to subtle bugs related to device placement or tensor properties. It is also inconsistent, as FaultStatus is an Enum while the others are plain classes. A better practice is to use simple integer values in enums (e.g., using enum.IntEnum) and create tensors from them only when needed for distributed communication. This would improve readability, maintainability, and avoid potential pitfalls with using tensors as constants.

from enum import Enum, IntEnum

class FaultStatus(IntEnum):
    """
    Fault status which fault_tolerance put into fault_queue
    """
    ACTIVE = 0
    UCE_ERR = 1
    FORCE_STOP = 2
    NETWORK_ERR = 3

class FaultCommand(IntEnum):
    """
    Fault command which rank 0 broadcast in fault_aware
    """
    INIT_CMD = 0
    SILENCE_CMD = 1
    STOP_DEVICE_CMD = 2

class UCEType(Enum):
    """
    Specific uce type for HBM UCE
    """
    WEIGHTS_UCE = "WEIGHTS_UCE"
    KVCACHE_UCE = "KVCACHE_UCE"
    ACTIVATION_UCE = "ACTIVATION_UCE"
    UNKNOWN_UCE = "UNKNOWN_UCE"

class RecoveryStatus(IntEnum):
    SUCCESS = 0
    FAILED = 1

class FaultAction(IntEnum):
    RAISE_EXCEPTION = 0
    RETURN = 1
    RECOMPUTE = 2

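To make the suggestion above concrete, here is a short sketch of keeping the constants as plain integers and creating tensors only at the communication boundary. The helper names are illustrative, not part of this PR.

import torch
from enum import IntEnum

class FaultStatus(IntEnum):
    ACTIVE = 0
    UCE_ERR = 1
    FORCE_STOP = 2
    NETWORK_ERR = 3

def status_to_tensor(status: FaultStatus) -> torch.Tensor:
    # Materialize the status as a CPU int64 tensor right before gather/broadcast,
    # so the constants themselves stay plain integers.
    return torch.tensor([int(status)], dtype=torch.int64)

def tensor_to_status(t: torch.Tensor) -> FaultStatus:
    # Convert back after communication; raises ValueError for unknown codes.
    return FaultStatus(int(t.item()))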
135 changes: 135 additions & 0 deletions vllm_ascend/worker/fault_aware.py
@@ -0,0 +1,135 @@
import time
import threading
import torch
import queue
import torch.distributed
import torch_npu

from datetime import timedelta
from vllm.logger import logger
from vllm_ascend.worker.common import FaultStatus, FaultCommand

class FaultAware:
    _fault_aware_group = None

    def __init__(self, rank: int, world_size: int, fault_queue: queue.Queue, interval_s=1,
                 aware_event: threading.Event = None):
        self.rank = rank
        self.world_size = world_size
        self.npu_id = torch.npu.current_device()
        self.fault_queue = fault_queue
        self.interval_s = interval_s

Check failure on line 22 in vllm_ascend/worker/fault_aware.py
GitHub Actions / lint / pre-commit: Incompatible default for argument "aware_event" (default has type "None", argument has type "Event") [assignment]

        self._fault_aware_thread = None
        self.aware_event = aware_event
        self._stop_event = threading.Event()

    def init_fault_aware_group(self):
        """
        Initialize the torch process group for fault awareness.
        Rank 0 is the coordinator rank; the other ranks are normal ranks,
        which send their status to rank 0.

        Rank 0 collects the status from all other ranks and broadcasts the
        stop_device command to them through `_fault_aware_group`.
        """
        assert(
            torch.distributed.is_initialized()
        ),"Default torch process group must be initialized"

        assert(
            torch.distributed.is_gloo_available()
        ),"Gloo process group must be available"
Comment on lines +36 to +42
Contributor
critical

The assert statements on lines 36 and 40 use tuple syntax assert(condition, message), which is a common pitfall in Python. A non-empty tuple always evaluates to True, so these assertions will never fail, even if the conditions torch.distributed.is_initialized() or torch.distributed.is_gloo_available() are false. This is a critical issue as it bypasses important validations, potentially allowing the program to proceed in an invalid state and causing hard-to-debug errors later.

Suggested change
-        assert(
-            torch.distributed.is_initialized()
-        ),"Default torch process group must be initialized"
-        assert(
-            torch.distributed.is_gloo_available()
-        ),"Gloo process group must be available"
+        assert torch.distributed.is_initialized(), "Default torch process group must be initialized"
+        assert torch.distributed.is_gloo_available(), "Gloo process group must be available"

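For reference, a minimal standalone illustration of the pitfall this comment describes, independent of this PR's exact formatting:

# A parenthesized tuple is always truthy, so this assertion can never fire
# (CPython 3.8+ emits "SyntaxWarning: assertion is always true" for it):
assert (1 == 2, "this message is never shown")

# Asserting the condition itself, with a separate message, behaves as intended:
assert 1 == 2, "this raises AssertionError"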

        rank = self.rank
        logger.info(
            f"init fault aware process group: "
            f"rank={rank}, world_size={self.world_size}, backend=gloo"
        )
        FaultAware._fault_aware_group = torch.distributed.new_group(
            ranks=None,
            timeout=timedelta(minutes=5),
            backend="gloo"
        )
        assert self._fault_aware_group is not None

    def start(self):
        """Start the fault-aware thread."""
        self.init_fault_aware_group()
        logger.info("Start fault aware thread")
        try:

Check failure on line 60 in vllm_ascend/worker/fault_aware.py
GitHub Actions / lint / pre-commit: Incompatible types in assignment (expression has type "Thread", variable has type "None") [assignment]

            self._fault_aware_thread = threading.Thread(
                target=self._handler,
                daemon=True,
            )

Check failure on line 64 in vllm_ascend/worker/fault_aware.py
GitHub Actions / lint / pre-commit: "None" has no attribute "start" [attr-defined]

            self._fault_aware_thread.start()
            logger.info("Succeeded to start fault aware thread")
        except Exception as e:
            logger.error(f"Failed to start fault aware thread: {e}")

    def _handler(self):
        torch.npu.set_device(self.npu_id)
        status = FaultStatus.ACTIVE.value
        status_list = (
            [torch.zeros([1], dtype=torch.int64) for _ in range(self.world_size)]
            if self.rank == 0
            else None
        )
        while True:
            try:
                msg = self.fault_queue.get_nowait()
                if msg:
                    logger.info(f"Got abnormal status {msg.name}, updating status")
                    status = msg.value
            except queue.Empty:
                if not threading.main_thread().is_alive():
                    return
            try:
                torch.distributed.gather(
                    tensor=status,
                    gather_list=status_list,
                    dst=0,
                    group=self._fault_aware_group,
                )
                fault_cmd = FaultCommand.INIT_CMD
                if self.rank == 0:
                    if all(torch.equal(t, FaultStatus.ACTIVE.value) for t in status_list):
                        fault_cmd = FaultCommand.SILENCE_CMD

Check failure on line 97 in vllm_ascend/worker/fault_aware.py
GitHub Actions / lint / pre-commit: Item "None" of "Optional[list[Any]]" has no attribute "__iter__" (not iterable) [union-attr]

                    else:
                        fault_cmd = FaultCommand.STOP_DEVICE_CMD

                torch.distributed.broadcast(
                    tensor=fault_cmd,
                    src=0,
                    group=self._fault_aware_group,
                )

                if torch.equal(fault_cmd, FaultCommand.SILENCE_CMD):
                    time.sleep(self.interval_s)
                elif torch.equal(fault_cmd, FaultCommand.STOP_DEVICE_CMD):
                    logger.info("Error in group, execute stop_device")
                    torch_npu.npu.stop_device(self.npu_id)
                    # Wait for fault_tolerance to wake this thread up
                    self.aware_event.wait()
                    self.aware_event.clear()
                    # Assume recovery succeeded
                    status = FaultStatus.ACTIVE.value
                else:
                    raise RuntimeError(f"Unknown fault command: {fault_cmd}")
            except Exception as e:
                time.sleep(self.interval_s)
                logger.error(f"Fault aware handler exception: {e}")
                if not threading.main_thread().is_alive():
                    return

    def destroy_fault_aware_group(self):
        """Destroy the torch process group for fault aware."""
        if self._fault_aware_group is None:
            return
        logger.info("Destroy fault aware process group")
        try:
            torch.distributed.destroy_process_group(self._fault_aware_group)
            self._fault_aware_group = None
            logger.info("Succeeded to destroy fault aware process group")
        except Exception as e:
            logger.error(f"Failed to destroy fault aware process group: {e}")