99from typing import Callable ,List
1010from vllm .config import VllmConfig
1111from vllm .logger import logger
12- from vllm .distributed .parallel_state import get_pp_group ,get_tp_group
12+ from vllm .distributed .parallel_state import get_pp_group ,get_tp_group , get_dp_group
1313from vllm_ascend .worker .memory_block_info import MemoryBlockInfo
1414from vllm_ascend .worker .fault_aware import FaultAware
1515from vllm_ascend .worker .common import FaultAction ,FaultToleranceLevel ,RecoveryStatus
1818
1919class FaultTolerance :
2020 _recovery_group = None
21- def __init__ (self ,vllm_config :VllmConfig ,model , level : FaultToleranceLevel = FaultToleranceLevel . OFF ):
21+ def __init__ (self ,vllm_config :VllmConfig ,model ):
2222 self .model = model
2323 self .vllm_config = vllm_config
24- self .level = level
24+ #TODO: 需要确认当前启动参数里有没有additional_config
25+ self .level = vllm_config .additional_config .get ("fault_tolerance_level" ,0 )
2526 self .fault_queue = queue .Queue ()
2627 self .memory_info = MemoryBlockInfo (self .model )
2728 self .recovery_chain = self ._build_recovery_chain ()
2829
29- self .world_size = dist .get_world_size () if dist .is_initialized () else 1
30- self .rank = dist .get_rank () if dist .is_initialized () else 0
30+ # TODO:这里需要用每个dp组下的rank0做汇总,需要确认一下参数是否正确
31+ self .world_size = get_dp_group ().world_size
32+ self .rank = get_dp_group ().rank_in_group
3133
3234 self ._init_recovery_group ()
3335 self .memory_info .initialize ()
3436
3537 self .aware_event = threading .Event ()
36- if self .level != FaultToleranceLevel .OFF :
38+ if self .level != FaultToleranceLevel .OFF . value :
3739 FaultAware (
3840 self .rank ,self .world_size ,self .fault_queue ,aware_event = self .aware_event
3941 ).start ()
@@ -46,7 +48,8 @@ def _init_recovery_group(self):
4648 return
4749
4850 FaultTolerance ._recovery_group = dist .new_group (
49- ranks = None ,
51+ #TODO:确认这个dp_group.ranks是否是我需要的
52+ ranks = get_dp_group ().ranks ,
5053 timeout = timedelta (minutes = 5 ),
5154 backend = "gloo" ,
5255 )
@@ -69,7 +72,8 @@ def fault_tolerance_decorator(self, func: Callable) -> Callable:
6972 def wrapper (* args , ** kwargs ):
7073 # Level 0:disable fault tolerance
7174 if self .level == FaultToleranceLevel .OFF .value :
72- return func (* args , ** kwargs )
75+ output = func (* args ,** kwargs )
76+ return output
7377 # Enable fault tolerance
7478 while True :
7579 try :
0 commit comments