Commit 5e59de5

luwenbo committed: add decorator at execute_dummy_run
(1 parent: 81dc4a1)

File tree: 2 files changed (+14, −8 lines)

vllm_ascend/ascend_config.py (2 additions, 0 deletions)

```diff
@@ -75,6 +75,8 @@ def __init__(self, vllm_config):
             "recompute_scheduler_enable", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
+        self.fault_tolerance_level = additional_config.get(
+            "fault_tolerance_level", 0)
         if self.lmhead_tensor_parallel_size is not None:
             logger.info(
                 f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
```

vllm_ascend/worker/fault_tolerance.py (12 additions, 8 deletions)

```diff
@@ -9,7 +9,7 @@
 from typing import Callable, List
 from vllm.config import VllmConfig
 from vllm.logger import logger
-from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
 from vllm_ascend.worker.memory_block_info import MemoryBlockInfo
 from vllm_ascend.worker.fault_aware import FaultAware
 from vllm_ascend.worker.common import FaultAction, FaultToleranceLevel, RecoveryStatus
```
```diff
@@ -18,22 +18,24 @@

 class FaultTolerance:
     _recovery_group = None
-    def __init__(self, vllm_config: VllmConfig, model, level: FaultToleranceLevel = FaultToleranceLevel.OFF):
+    def __init__(self, vllm_config: VllmConfig, model):
         self.model = model
         self.vllm_config = vllm_config
-        self.level = level
+        # TODO: confirm that the current launch arguments actually carry additional_config
+        self.level = vllm_config.additional_config.get("fault_tolerance_level", 0)
         self.fault_queue = queue.Queue()
         self.memory_info = MemoryBlockInfo(self.model)
         self.recovery_chain = self._build_recovery_chain()

-        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
-        self.rank = dist.get_rank() if dist.is_initialized() else 0
+        # TODO: aggregation should use rank 0 of each DP group; confirm these parameters are correct
+        self.world_size = get_dp_group().world_size
+        self.rank = get_dp_group().rank_in_group

         self._init_recovery_group()
         self.memory_info.initialize()

         self.aware_event = threading.Event()
-        if self.level != FaultToleranceLevel.OFF:
+        if self.level != FaultToleranceLevel.OFF.value:
             FaultAware(
                 self.rank, self.world_size, self.fault_queue, aware_event=self.aware_event
             ).start()
```
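Because `self.level` is now a plain int read from `additional_config` rather than a `FaultToleranceLevel` argument, the guard has to compare against `.value`: an int never compares equal to an `Enum` member. A quick illustration (the member names are assumptions; the diff only implies `OFF = 0`):

```python
from enum import Enum


class FaultToleranceLevel(Enum):  # assumed members; only OFF = 0 is implied
    OFF = 0
    BASIC = 1


level = 0  # raw int, as read from additional_config
print(level != FaultToleranceLevel.OFF)        # True: an int never equals an Enum member
print(level != FaultToleranceLevel.OFF.value)  # False: the intended "is it off?" check
```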
```diff
@@ -46,7 +48,8 @@ def _init_recovery_group(self):
             return

         FaultTolerance._recovery_group = dist.new_group(
-            ranks=None,
+            # TODO: confirm that dp_group.ranks is the rank list needed here
+            ranks=get_dp_group().ranks,
             timeout=timedelta(minutes=5),
             backend="gloo",
         )
```
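The recovery group is created with the `gloo` backend, presumably so fault coordination runs over CPU sockets and stays reachable even if the device-side collective backend is wedged. A runnable single-process sketch of building such a group; the hard-coded rank list stands in for `get_dp_group().ranks`:

```python
import os
from datetime import timedelta

import torch.distributed as dist

# Single-process stand-in for a real DP group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# CPU-side (gloo) side channel with a generous timeout, mirroring the diff.
recovery_group = dist.new_group(
    ranks=[0],  # stand-in for get_dp_group().ranks
    timeout=timedelta(minutes=5),
    backend="gloo",
)

dist.destroy_process_group()
```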
```diff
@@ -69,7 +72,8 @@ def fault_tolerance_decorator(self, func: Callable) -> Callable:
         def wrapper(*args, **kwargs):
             # Level 0: disable fault tolerance
             if self.level == FaultToleranceLevel.OFF.value:
-                return func(*args, **kwargs)
+                output = func(*args, **kwargs)
+                return output
             # Enable fault tolerance
             while True:
                 try:
```
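Per the commit title, this decorator is applied at `execute_dummy_run`. A self-contained sketch of the wrapping pattern under stated assumptions: the enum member beyond `OFF`, the `FaultToleranceSketch` class, and the bare retry loop are illustrative stand-ins (the real code runs a recovery chain between attempts, which this hunk does not show):

```python
import queue
from enum import Enum
from typing import Any, Callable


class FaultToleranceLevel(Enum):  # assumed shape; the diff only implies OFF
    OFF = 0
    RETRY = 1


class FaultToleranceSketch:
    """Stand-in for FaultTolerance, showing only the decorator path."""

    def __init__(self, level: int) -> None:
        self.level = level  # raw int from additional_config, as in the diff
        self.fault_queue: queue.Queue = queue.Queue()

    def fault_tolerance_decorator(self, func: Callable) -> Callable:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Level 0: run the wrapped function untouched.
            if self.level == FaultToleranceLevel.OFF.value:
                return func(*args, **kwargs)
            # Otherwise retry on failure; the real code would run its
            # recovery chain between attempts instead of retrying blindly.
            while True:
                try:
                    return func(*args, **kwargs)
                except RuntimeError as exc:
                    self.fault_queue.put(exc)  # hand the fault to the aware thread

        return wrapper


ft = FaultToleranceSketch(level=0)


@ft.fault_tolerance_decorator
def execute_dummy_run() -> str:  # hypothetical stand-in for the worker method
    return "ok"


print(execute_dummy_run())  # "ok": passes straight through at level 0
```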
