-
Notifications
You must be signed in to change notification settings - Fork 624
[WIP]V0.11.0 dev-Token level re-inference #4508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v0.11.0-dev
Are you sure you want to change the base?
Changes from all commits
8962a63
1985a9e
43afdb6
ceaa764
ae0f201
9beb917
837b8d6
d695acd
ea2de98
0e1993b
e40fad6
6ba9b4a
81dc4a1
5e59de5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import torch | ||
| from enum import Enum | ||
|
|
||
| class FaultToleranceLevel(Enum): | ||
| """ | ||
| Fault tolerance level | ||
| level 0: disable fault tolerance | ||
| level 1: enable base fault tolerance for weight UCE/Activation UCE/Network Error | ||
| level 2: enable all fault tolerance for weight UCE/Activation UCE/KVCache UCE/Network Error | ||
| """ | ||
| OFF = 0 | ||
| BASIC = 1 | ||
| FULL = 2 | ||
|
|
||
| class FaultStatus(Enum): | ||
| """ | ||
| Fault status which fault_tolerance put into fault_queue | ||
| """ | ||
| ACTIVE = torch.tensor([0]) | ||
| UCE_ERR = torch.tensor([1]) | ||
| FORCE_STOP = torch.tensor([2]) | ||
| NETWORK_ERR = torch.tensor([3]) | ||
|
|
||
| class FaultCommand: | ||
| """ | ||
| Fault command which rank 0 broadcast in fault_aware | ||
| """ | ||
| INIT_CMD = torch.tensor([0]) | ||
| SILENCE_CMD = torch.tensor([1]) | ||
| STOP_DEVICE_CMD = torch.tensor([2]) | ||
|
|
||
| class UCEType(Enum): | ||
| """ | ||
| Specific uce type for HBM UCE | ||
| """ | ||
| WEIGHTS_UCE = "WEIGHTS_UCE" | ||
| KVCACHE_UCE = "KVCACHE_UCE" | ||
| ACTIVATION_UCE = "ACTIVATION_UCE" | ||
| UNKNOWN_UCE = "UNKNOWN_UCE" | ||
|
|
||
| class RecoveryStatus: | ||
| SUCCESS = torch.tensor([0]) | ||
| FAILED = torch.tensor([1]) | ||
|
|
||
| class FaultAction: | ||
| RAISE_EXCEPTION = torch.tensor([0]) | ||
| RETURN = torch.tensor([1]) | ||
| RECOMPUTE = torch.tensor([2]) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,135 @@ | ||||||||||||||||||||||
| import time | ||||||||||||||||||||||
| import threading | ||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
| import queue | ||||||||||||||||||||||
| import torch.distributed | ||||||||||||||||||||||
| import torch_npu | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from datetime import timedelta | ||||||||||||||||||||||
| from vllm.logger import logger | ||||||||||||||||||||||
| from vllm_ascend.worker.common import FaultStatus,FaultCommand | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class FaultAware: | ||||||||||||||||||||||
| _fault_aware_group = None | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __init__(self,rank:int,world_size:int,fault_queue:queue.Queue,interval_s=1, | ||||||||||||||||||||||
| aware_event:threading.Event=None): | ||||||||||||||||||||||
| self.rank = rank | ||||||||||||||||||||||
| self.world_size = world_size | ||||||||||||||||||||||
| self.npu_id = torch.npu.current_device() | ||||||||||||||||||||||
| self.fault_queue = fault_queue | ||||||||||||||||||||||
| self.interval_s = interval_s | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| self._fault_aware_thread = None | ||||||||||||||||||||||
| self.aware_event = aware_event | ||||||||||||||||||||||
| self._stop_event = threading.Event() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def init_fault_aware_group(self): | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Initialize the Torch process group for fault aware. | ||||||||||||||||||||||
| Rank 0 is the coordinator rank, | ||||||||||||||||||||||
| the other ranks are the normal rank,which is used for sending status to rank 0. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Rank 0 will collect the status from all the other ranks and broadcast stop_device | ||||||||||||||||||||||
| command to all the other ranks through `_fault_aware_group` | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| assert( | ||||||||||||||||||||||
| torch.distributed.is_initialized() | ||||||||||||||||||||||
| ),"Default torch process group must be initialized" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| assert( | ||||||||||||||||||||||
| torch.distributed.is_gloo_available() | ||||||||||||||||||||||
| ),"Gloo process group must be available" | ||||||||||||||||||||||
|
Comment on lines
+36
to
+42
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| rank = self.rank | ||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||
| f"init fault aware process group: " | ||||||||||||||||||||||
| f"rank={rank},world_size={self.world_size},backend=gloo" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| FaultAware._fault_aware_group = torch.distributed.new_group( | ||||||||||||||||||||||
| ranks=None, | ||||||||||||||||||||||
| timeout=timedelta(minutes=5), | ||||||||||||||||||||||
| backend="gloo" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| assert self._fault_aware_group is not None | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def start(self): | ||||||||||||||||||||||
| """Start the fault aware""" | ||||||||||||||||||||||
| self.init_fault_aware_group() | ||||||||||||||||||||||
| logger.info("Start fault aware thread") | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| self._fault_aware_thread = threading.Thread( | ||||||||||||||||||||||
| target=self._handler, | ||||||||||||||||||||||
| daemon=True, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| self._fault_aware_thread.start() | ||||||||||||||||||||||
| logger.info("Succeeded to start fault aware thread") | ||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||
| logger.error(f"Failed to start fault aware thread:{e}") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _handler(self): | ||||||||||||||||||||||
| torch.npu.set_device(self.npu_id) | ||||||||||||||||||||||
| status = FaultStatus.ACTIVE.value | ||||||||||||||||||||||
| status_list = ( | ||||||||||||||||||||||
| [torch.zeros([1],dtype=torch.int64) for _ in range(self.world_size)] | ||||||||||||||||||||||
| if self.rank == 0 | ||||||||||||||||||||||
| else None | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| while True: | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| msg = self.fault_queue.get_nowait() | ||||||||||||||||||||||
| if msg: | ||||||||||||||||||||||
| logger.info(f"Get abnormal status,update status {msg.name},update status") | ||||||||||||||||||||||
| status = msg.value | ||||||||||||||||||||||
| except queue.Empty: | ||||||||||||||||||||||
| if not threading.main_thread().is_alive(): | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| torch.distributed.gather( | ||||||||||||||||||||||
| tensor=status, | ||||||||||||||||||||||
| gather_list=status_list, | ||||||||||||||||||||||
| dst=0, | ||||||||||||||||||||||
| group = self._fault_aware_group, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| fault_cmd = FaultCommand.INIT_CMD | ||||||||||||||||||||||
| if self.rank == 0: | ||||||||||||||||||||||
| if all(torch.equal(t,FaultStatus.ACTIVE.value) for t in status_list): | ||||||||||||||||||||||
| fault_cmd = FaultCommand.SILENCE_CMD | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| fault_cmd = FaultCommand.STOP_DEVICE_CMD | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| torch.distributed.broadcast( | ||||||||||||||||||||||
| tensor=fault_cmd, | ||||||||||||||||||||||
| src=0, | ||||||||||||||||||||||
| group=self._fault_aware_group, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if torch.equal(fault_cmd,FaultCommand.SILENCE_CMD): | ||||||||||||||||||||||
| time.sleep(self.interval_s) | ||||||||||||||||||||||
| elif torch.equal(fault_cmd,FaultCommand.STOP_DEVICE_CMD): | ||||||||||||||||||||||
| logger.info(f"Error in group,execute stop_device") | ||||||||||||||||||||||
| torch_npu.npu.stop_device(self.npu_id) | ||||||||||||||||||||||
| # Wait for fault_tolerance to wake me up | ||||||||||||||||||||||
| self.aware_event.wait() | ||||||||||||||||||||||
| self.aware_event.clear() | ||||||||||||||||||||||
| # Assume recover successfully | ||||||||||||||||||||||
| status = FaultStatus.ACTIVE.value | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| raise RuntimeError(f"Unknown fault command:{fault_cmd}") | ||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||
| time.sleep(self.interval_s) | ||||||||||||||||||||||
| logger.error(f"Fault aware handler exception:{e}") | ||||||||||||||||||||||
| if not threading.main_thread().is_alive(): | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def destroy_fault_aware_group(self): | ||||||||||||||||||||||
| """Destroy the Torch process group for fault aware""" | ||||||||||||||||||||||
| if self._fault_aware_group is None: | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
| logger.info("Destroy fault aware process group") | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| torch.distributed.destroy_process_group(self._fault_aware_group) | ||||||||||||||||||||||
| self._fault_aware_group = None | ||||||||||||||||||||||
| logger.info("Succeeded to destroy fault aware process group") | ||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||
| logger.error(f"Failed to destroy fault aware process group:{e}") | ||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The classes
FaultStatus,FaultCommand,RecoveryStatus, andFaultActionusetorch.tensorobjects as values for constants. This is unconventional and can lead to subtle bugs related to device placement or tensor properties. It is also inconsistent, asFaultStatusis anEnumwhile the others are plain classes. A better practice is to use simple integer values in enums (e.g., usingenum.IntEnum) and create tensors from them only when needed for distributed communication. This would improve readability, maintainability, and avoid potential pitfalls with using tensors as constants.