diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index b0973b15c29..2420bfdb4ad 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -75,6 +75,8 @@ def __init__(self, vllm_config):
             "recompute_scheduler_enable", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
+        self.fault_tolerance_level = additional_config.get(
+            "fault_tolerance_level", 0)
         if self.lmhead_tensor_parallel_size is not None:
             logger.info(
                 f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
diff --git a/vllm_ascend/worker/common.py b/vllm_ascend/worker/common.py
new file mode 100644
index 00000000000..81acb13f22e
--- /dev/null
+++ b/vllm_ascend/worker/common.py
@@ -0,0 +1,48 @@
+import torch
+from enum import Enum
+
+class FaultToleranceLevel(Enum):
+    """
+    Fault tolerance level.
+    Level 0: fault tolerance disabled.
+    Level 1: basic fault tolerance for weight UCE, activation UCE and network errors.
+    Level 2: full fault tolerance for weight UCE, activation UCE, KV cache UCE and network errors.
+    """
+    OFF = 0
+    BASIC = 1
+    FULL = 2
+
+class FaultStatus(Enum):
+    """
+    Fault status that fault_tolerance puts into the fault_queue.
+    """
+    ACTIVE = torch.tensor([0])
+    UCE_ERR = torch.tensor([1])
+    FORCE_STOP = torch.tensor([2])
+    NETWORK_ERR = torch.tensor([3])
+
+class FaultCommand:
+    """
+    Fault command that rank 0 broadcasts in fault_aware.
+    """
+    INIT_CMD = torch.tensor([0])
+    SILENCE_CMD = torch.tensor([1])
+    STOP_DEVICE_CMD = torch.tensor([2])
+
+class UCEType(Enum):
+    """
+    Specific UCE type for an HBM UCE error.
+    """
+    WEIGHTS_UCE = "WEIGHTS_UCE"
+    KVCACHE_UCE = "KVCACHE_UCE"
+    ACTIVATION_UCE = "ACTIVATION_UCE"
+    UNKNOWN_UCE = "UNKNOWN_UCE"
+
+class RecoveryStatus:
+    SUCCESS = torch.tensor([0])
+    FAILED = torch.tensor([1])
+
+class FaultAction:
+    RAISE_EXCEPTION = torch.tensor([0])
+    RETURN = torch.tensor([1])
+    RECOMPUTE = torch.tensor([2])
diff --git a/vllm_ascend/worker/fault_aware.py b/vllm_ascend/worker/fault_aware.py
new file mode 100644
index 00000000000..37ebd04544d
--- /dev/null
+++ b/vllm_ascend/worker/fault_aware.py
@@ -0,0 +1,135 @@
+import time
+import threading
+import torch
+import queue
+import torch.distributed
+import torch_npu
+
+from datetime import timedelta
+from vllm.logger import logger
+from vllm_ascend.worker.common import FaultStatus, FaultCommand
+
+class FaultAware:
+    _fault_aware_group = None
+
+    def __init__(self, rank: int, world_size: int, fault_queue: queue.Queue, interval_s=1,
+                 aware_event: threading.Event = None):
+        self.rank = rank
+        self.world_size = world_size
+        self.npu_id = torch.npu.current_device()
+        self.fault_queue = fault_queue
+        self.interval_s = interval_s
+
+        self._fault_aware_thread = None
+        self.aware_event = aware_event
+        self._stop_event = threading.Event()
+
+    def init_fault_aware_group(self):
+        """
+        Initialize the Torch process group for fault awareness.
+        Rank 0 is the coordinator rank; the other ranks are normal ranks
+        that send their status to rank 0.
+
+        Rank 0 collects the status from all the other ranks and broadcasts the stop_device
+        command to all the other ranks through `_fault_aware_group`.
+        """
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+
+        assert (
+            torch.distributed.is_gloo_available()
+        ), "Gloo process group must be available"
+
+        rank = self.rank
+        logger.info(
+            f"init fault aware process group: "
+            f"rank={rank}, world_size={self.world_size}, backend=gloo"
+        )
+        FaultAware._fault_aware_group = torch.distributed.new_group(
+            ranks=None,
+            timeout=timedelta(minutes=5),
+            backend="gloo"
+        )
+        assert FaultAware._fault_aware_group is not None
+
+    def start(self):
+        """Start the fault aware thread."""
+        self.init_fault_aware_group()
+        logger.info("Start fault aware thread")
+        try:
+            self._fault_aware_thread = threading.Thread(
+                target=self._handler,
+                daemon=True,
+            )
+            self._fault_aware_thread.start()
+            logger.info("Succeeded to start fault aware thread")
+        except Exception as e:
+            logger.error(f"Failed to start fault aware thread: {e}")
+
+    def _handler(self):
+        torch.npu.set_device(self.npu_id)
+        status = FaultStatus.ACTIVE.value
+        status_list = (
+            [torch.zeros([1], dtype=torch.int64) for _ in range(self.world_size)]
+            if self.rank == 0
+            else None
+        )
+        while True:
+            try:
+                msg = self.fault_queue.get_nowait()
+                if msg:
+                    logger.info(f"Got abnormal status {msg.name}, updating local status")
+                    status = msg.value
+            except queue.Empty:
+                if not threading.main_thread().is_alive():
+                    return
+            try:
+                torch.distributed.gather(
+                    tensor=status,
+                    gather_list=status_list,
+                    dst=0,
+                    group=self._fault_aware_group,
+                )
+                # Clone so the broadcast below does not overwrite the shared INIT_CMD tensor in place.
+                fault_cmd = FaultCommand.INIT_CMD.clone()
+                if self.rank == 0:
+                    if all(torch.equal(t, FaultStatus.ACTIVE.value) for t in status_list):
+                        fault_cmd = FaultCommand.SILENCE_CMD
+                    else:
+                        fault_cmd = FaultCommand.STOP_DEVICE_CMD
+
+                torch.distributed.broadcast(
+                    tensor=fault_cmd,
+                    src=0,
+                    group=self._fault_aware_group,
+                )
+
+                if torch.equal(fault_cmd, FaultCommand.SILENCE_CMD):
+                    time.sleep(self.interval_s)
+                elif torch.equal(fault_cmd, FaultCommand.STOP_DEVICE_CMD):
+                    logger.info("Error in group, executing stop_device")
+                    torch_npu.npu.stop_device(self.npu_id)
+                    # Wait for fault_tolerance to wake me up
+                    self.aware_event.wait()
+                    self.aware_event.clear()
+                    # Assume recovery succeeded
+                    status = FaultStatus.ACTIVE.value
+                else:
+                    raise RuntimeError(f"Unknown fault command: {fault_cmd}")
+            except Exception as e:
+                time.sleep(self.interval_s)
+                logger.error(f"Fault aware handler exception: {e}")
+                if not threading.main_thread().is_alive():
+                    return
+
+    def destroy_fault_aware_group(self):
+        """Destroy the Torch process group for fault awareness."""
+        if FaultAware._fault_aware_group is None:
+            return
+        logger.info("Destroy fault aware process group")
+        try:
+            torch.distributed.destroy_process_group(FaultAware._fault_aware_group)
+            FaultAware._fault_aware_group = None
+            logger.info("Succeeded to destroy fault aware process group")
+        except Exception as e:
+            logger.error(f"Failed to destroy fault aware process group: {e}")
diff --git a/vllm_ascend/worker/fault_tolerance.py b/vllm_ascend/worker/fault_tolerance.py
new file mode 100644
index 00000000000..e0bc5828408
--- /dev/null
+++ b/vllm_ascend/worker/fault_tolerance.py
@@ -0,0 +1,277 @@
+import torch
+import functools
+import queue
+import threading
+import torch_npu
+import torch.distributed as dist
+
+from datetime import timedelta
+from typing import Callable, List
+from vllm.config import VllmConfig
+from vllm.logger import logger
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
+from vllm_ascend.worker.memory_block_info import MemoryBlockInfo
+from vllm_ascend.worker.fault_aware import FaultAware
+from vllm_ascend.worker.common import FaultAction, FaultToleranceLevel, RecoveryStatus
+from vllm_ascend.worker.recovery_chain import UCEHandler, RecoveryHandler, ForceStopHandler, NetworkHandler
+from vllm_ascend.worker.recovery_context import RecoveryContext
+
+class FaultTolerance:
+    _recovery_group = None
+
+    def __init__(self, vllm_config: VllmConfig, model):
+        self.model = model
+        self.vllm_config = vllm_config
+        # TODO: confirm that additional_config is available in the current launch arguments
+        self.level = vllm_config.additional_config.get("fault_tolerance_level", 0)
+        self.fault_queue = queue.Queue()
+        self.memory_info = MemoryBlockInfo(self.model)
+        self.recovery_chain = self._build_recovery_chain()
+
+        # TODO: rank 0 of each DP group should do the aggregation here; confirm these parameters are correct
+        self.world_size = get_dp_group().world_size
+        self.rank = get_dp_group().rank_in_group
+
+        self._init_recovery_group()
+        self.memory_info.initialize()
+
+        self.aware_event = threading.Event()
+        if self.level != FaultToleranceLevel.OFF.value:
+            FaultAware(
+                self.rank, self.world_size, self.fault_queue, aware_event=self.aware_event
+            ).start()
+
+    def _init_recovery_group(self):
+        """
+        Initialize the global communication group used to report abnormal status to fault_aware.
+        """
+        if not dist.is_initialized() or self.world_size == 1:
+            return
+
+        FaultTolerance._recovery_group = dist.new_group(
+            # TODO: confirm that dp_group.ranks is the rank list needed here
+            ranks=get_dp_group().ranks,
+            timeout=timedelta(minutes=5),
+            backend="gloo",
+        )
+
+        logger.info(f"Recovery group initialization successful for rank {self.rank}")
+
+    def _build_recovery_chain(self) -> RecoveryHandler:
+        """Initialize the recovery chain."""
+        force_stop_handler = ForceStopHandler()
+        network_handler = NetworkHandler()
+        uce_handler = UCEHandler()
+
+        force_stop_handler.set_next(network_handler).set_next(uce_handler)
+
+        return force_stop_handler
+
+    def fault_tolerance_decorator(self, func: Callable) -> Callable:
+        """Fault tolerance decorator that wraps execute_model with exception handling."""
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            # Level 0: fault tolerance disabled
+            if self.level == FaultToleranceLevel.OFF.value:
+                output = func(*args, **kwargs)
+                return output
+            # Fault tolerance enabled
+            while True:
+                try:
+                    output = func(*args, **kwargs)
+                    return output
+                except Exception as e:
+                    # Encapsulate the context information required for fault recovery.
+                    recovery_context = RecoveryContext(
+                        model=self.model,
+                        level=self.level,
+                        exception=e,
+                        rank=self.rank,
+                        model_or_path=self.vllm_config.model_config.model,
+                        memory_block_info=self.memory_info,
+                        fault_queue=self.fault_queue
+                    )
+                    ft_action = self._handle_exception(recovery_context)
+                    if torch.equal(ft_action, FaultAction.RECOMPUTE):
+                        self.aware_event.set()
+                        logger.info(f"Begin token re-inference at rank {self.rank}")
+                        continue
+                    elif torch.equal(ft_action, FaultAction.RAISE_EXCEPTION):
+                        logger.info(f"Raise exception at rank {self.rank}")
+                        # TODO: Need to clear the cache for the current batch and destroy all groups
+                        raise e
+                    elif torch.equal(ft_action, FaultAction.RETURN):
+                        logger.info(f"Abort current batch at rank {self.rank}")
+                        # TODO: Need to clear the cache for the current batch and destroy all groups
+                        return None
+                    else:
+                        # TODO: Need to clear the cache for the current batch and destroy all groups
+                        logger.error(f"Unknown fault action found at rank {self.rank}")
+                        raise e
+
+        return wrapper
+
+    def _handle_exception(self, ctx: RecoveryContext) -> torch.Tensor:
+        """
+        Handle the exception in the recovery chain and get the fault action for the current batch.
+        """
+        try:
+            # 1. Handle the exception in the recovery chain and get the recovery status
+            local_recovery_status = self.recovery_chain.handle(ctx)
+            # 2. Report the recovery status and get the fault action
+            ft_action = self._coordinate_recovery(ctx, local_recovery_status)
+            return ft_action
+        except Exception as inner_e:
+            logger.error(f"Handle exception failed at rank {self.rank}, got exception {inner_e}")
+            return FaultAction.RAISE_EXCEPTION
+
+    def _coordinate_recovery(self, ctx: RecoveryContext, local_recovery_status: torch.Tensor) -> torch.Tensor:
+        """
+        Rank 0 gathers recovery statuses and determines the fault action for each rank.
+        Recovery status is categorized into restart recovery and fault recovery.
+        Failure at any recovery stage will cause re-inference to fail, so re-inference is
+        executed only when both restart recovery and fault recovery succeed.
+        """
+
+        # Determine the fault action for the single-rank case
+        if not dist.is_initialized() or self.world_size == 1:
+            reinit_status = self._restart_and_reinit(ctx)
+            if torch.equal(reinit_status, RecoveryStatus.SUCCESS):
+                return self._single_node_decision(local_recovery_status)
+            else:
+                return FaultAction.RAISE_EXCEPTION
+        # TODO: the code below should be refactored
+        """
+        dummy_forward = {"hidden_states": torch.tensor([0])}
+        if self.vllm_config.parallel_config.distributed_executor_backend != (
+                "external_launcher") and not get_pp_group().is_last_rank:
+            get_pp_group().send_tensor_dict(dummy_forward, all_gather_group=get_tp_group())
+        """
+        all_recovery_status = self._gather_statuses(local_recovery_status)
+        reinit_status = self._restart_and_reinit(ctx)
+        all_reinit_status = self._gather_statuses(reinit_status)
+        if self.rank == 0:
+            has_failed = any(torch.equal(status, RecoveryStatus.FAILED) for status in all_reinit_status)
+            if has_failed:
+                reinit_actions = self._analyze_global_status(all_reinit_status)
+                return self._scatter_ft_actions(reinit_actions)
+            else:
+                ft_actions = self._analyze_global_status(all_recovery_status)
+                return self._scatter_ft_actions(ft_actions)
+        else:
+            return self._receive_ft_actions()
+
+    def _single_node_decision(self, local_status: torch.Tensor) -> torch.Tensor:
+        """
+        Single-rank case: determine the fault action based on the local status.
+        """
+        if torch.equal(local_status, RecoveryStatus.SUCCESS):
+            return FaultAction.RECOMPUTE
+        else:
+            return FaultAction.RAISE_EXCEPTION
+
+    def _gather_statuses(self, local_status: torch.Tensor) -> List[torch.Tensor]:
+        """Rank 0 gathers the status from each rank."""
+        try:
+            if self.rank == 0:
+                gather_list = [torch.zeros_like(local_status) for _ in range(self.world_size)]
+                dist.gather(
+                    local_status,
+                    gather_list=gather_list,
+                    dst=0,
+                    group=FaultTolerance._recovery_group
+                )
+                return gather_list
+            else:
+                # The other ranks only send; they do not receive
+                dist.gather(local_status, gather_list=None, dst=0, group=FaultTolerance._recovery_group)
+                return []  # Non-rank-0 returns an empty list
+        except Exception as inner_e:
+            logger.error(f"Gather status failed, got exception: {inner_e}")
+            if self.rank == 0:
+                return [RecoveryStatus.FAILED for _ in range(self.world_size)]
+            return []
+
+    def _analyze_global_status(self, all_recovery_statuses: List[torch.Tensor]) -> List[torch.Tensor]:
+        """
+        Analyze the statuses and generate decisions.
+        """
+        success_ranks = []
+        failure_ranks = []
+
+        for rank, recovery_status in enumerate(all_recovery_statuses):
+            if torch.equal(recovery_status, RecoveryStatus.SUCCESS):
+                success_ranks.append(rank)
+            elif torch.equal(recovery_status, RecoveryStatus.FAILED):
+                failure_ranks.append(rank)
+            else:
+                logger.warning(f"Unknown status tensor from rank {rank}: {recovery_status}")
+                failure_ranks.append(rank)
+
+        logger.info(f"Global recovery: {len(success_ranks)} success, {len(failure_ranks)} failure")
+
+        decisions = []
+        if not failure_ranks:
+            logger.info("All ranks recovered, determining RECOMPUTE for all ranks")
+            decisions = [FaultAction.RECOMPUTE] * self.world_size
+        elif not success_ranks:
+            logger.warning("All ranks failed, determining RAISE_EXCEPTION for all ranks")
+            decisions = [FaultAction.RAISE_EXCEPTION] * self.world_size
+        else:
+            logger.warning(f"Partial recovery - success ranks: {success_ranks}")
+            for rank in range(self.world_size):
+                if rank in success_ranks:
+                    decisions.append(FaultAction.RETURN)
+                else:
+                    decisions.append(FaultAction.RAISE_EXCEPTION)
+
+        return decisions
+
+    def _scatter_ft_actions(self, ft_actions: List[torch.Tensor]) -> torch.Tensor:
+        """Rank 0 distributes the fault action to each rank."""
+        recv_ft_action = torch.tensor([0])
+        dist.scatter(
+            recv_ft_action,
+            scatter_list=ft_actions,
+            src=0,
+            group=FaultTolerance._recovery_group
+        )
+        return recv_ft_action
+
+    def _receive_ft_actions(self) -> torch.Tensor:
+        """Ranks 1..N receive the fault action."""
+        recv_ft_action = torch.tensor([0])
+        dist.scatter(
+            recv_ft_action,
+            scatter_list=None,
+            src=0,
+            group=FaultTolerance._recovery_group
+        )
+        return recv_ft_action
+
+    def _restart_and_reinit(self, ctx: RecoveryContext) -> torch.Tensor:
+        """
+        Restart the device and reinitialize the process group.
+        """
+        try:
+            torch_npu.npu.restart_device(torch.npu.current_device())
+            torch.distributed.reinit_process_group(group=None, rebuild_link=False)
+            reinit_status = RecoveryStatus.SUCCESS
+        except Exception as inner_e:
+            logger.error(f"Failed to restart and reinit process group for rank {self.rank}, got exception: {inner_e}")
+            ctx.exception = inner_e
+            reinit_status = RecoveryStatus.FAILED
+        return reinit_status
+
+    def destroy_recovery_group(self):
+        """Destroy the recovery process group and fault_aware."""
+        if FaultTolerance._recovery_group is None:
+            return
+
+        logger.info("Destroying recovery process group")
+        try:
+            dist.destroy_process_group(FaultTolerance._recovery_group)
+            FaultTolerance._recovery_group = None
+            logger.info("Successfully destroyed recovery process group")
+        except Exception as e:
+            logger.error(f"Failed to destroy recovery process group: {e}")
diff --git a/vllm_ascend/worker/memory_block_info.py b/vllm_ascend/worker/memory_block_info.py
new file mode 100644
index 00000000000..e154027cd24
--- /dev/null
+++ b/vllm_ascend/worker/memory_block_info.py
@@ -0,0 +1,54 @@
+from typing import List, Optional, Tuple
+
+from vllm_ascend.worker.common import UCEType
+
+class MemoryBlockInfo:
+    def __init__(self, model):
+        self.model = model
+        self.weight_blocks = {}
+        self.kvcache_blocks = {}
+        self.initialized = False
+
+    def initialize(self):
+        self._get_weight_memory_info()
+        self._get_kv_memory_info()
+        self.initialized = True
+
+    def _get_weight_memory_info(self):
+        weight_blocks = {}
+        state_dict = self.model.state_dict()
+        for name, param in state_dict.items():
+            start_address = param.data_ptr()
+            size_bytes = param.numel() * param.element_size()
+            end_address = start_address + max(0, size_bytes - 1)
+            weight_blocks[name] = {
+                'name': name,
+                'start_address': start_address,
+                'end_address': end_address,
+            }
+        self.weight_blocks = weight_blocks
+
+    def _get_kv_memory_info(self):
+        # TODO: KV cache address tracking is not implemented yet
+        pass
+
+    def category_address(self, ptr) -> Tuple[Optional[UCEType], List[str]]:
+        weight_type, weight_layer = self.is_weight_uce(ptr)
+        if weight_type is not None:
+            return weight_type, weight_layer
+
+        kv_type, kv_layer = self.is_kv_uce(ptr)
+        if kv_type is not None:
+            return kv_type, kv_layer
+
+        return UCEType.ACTIVATION_UCE, []
+
+    def is_weight_uce(self, ptr) -> Tuple[Optional[UCEType], List[str]]:
+        error_layer = []
+        for name in self.weight_blocks:
+            start_address = self.weight_blocks[name]['start_address']
+            end_address = self.weight_blocks[name]['end_address']
+            if start_address <= int(ptr) <= end_address:
+                error_layer.append(name)
+        if len(error_layer) > 0:
+            return UCEType.WEIGHTS_UCE, error_layer
+        return None, []
+
+    def is_kv_uce(self, ptr) -> Tuple[Optional[UCEType], List[str]]:
+        # TODO: KV cache blocks are not tracked yet, so nothing is classified as a KV cache UCE
+        return None, []
diff --git a/vllm_ascend/worker/recovery_chain.py b/vllm_ascend/worker/recovery_chain.py
new file mode 100644
index 00000000000..0d155f94157
--- /dev/null
+++ b/vllm_ascend/worker/recovery_chain.py
@@ -0,0 +1,222 @@
+import json
+import os.path
+import torch
+import yaml
+import torch_npu
+
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Dict, Any
+from vllm.logger import logger
+from vllm_ascend.worker.common import RecoveryStatus, FaultStatus, UCEType, FaultToleranceLevel
+from vllm_ascend.worker.recovery_context import RecoveryContext
+from torch_npu.npu.utils import _get_uce_addr
+from collections.abc import Generator
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
+from safetensors.torch import safe_open
+
+uce_error = [
+    "uce error",
+    "hbm multi bit ecc error"
+]
+force_stop_error = ["force stop"]
+network_error = [
+    "suspect remote error",
+    "hccl op retry failed"
+]
+
+
+class RecoveryHandler(ABC):
+
+    def __init__(self):
+        self.next_handler = None
+
+    def set_next(self, handler: 'RecoveryHandler') -> 'RecoveryHandler':
+        """Set the next handler."""
+        self.next_handler = handler
+        return handler
+
+    @abstractmethod
+    def can_handle(self, ctx: RecoveryContext) -> bool:
+        pass
+
+    @abstractmethod
+    def recover(self, ctx: RecoveryContext) -> torch.Tensor:
+        """Handler-specific recovery function."""
+        pass
+
+    def handle(self, ctx: RecoveryContext) -> torch.Tensor:
+        """Entry point for RecoveryHandler."""
+        if self.can_handle(ctx):
+            return self.recover(ctx)
+        elif self.next_handler:
+            return self.next_handler.handle(ctx)
+        else:
+            logger.warning("No handler can process the exception")
+            # TODO: Need to enqueue the status or add other handling here; returning FAILED alone
+            # will hang the remaining ranks since no decorator handles it, or simply raise an
+            # Exception (non-target fault scenario).
+            return RecoveryStatus.FAILED
+
+class ForceStopHandler(RecoveryHandler):
+
+    def can_handle(self, ctx: RecoveryContext) -> bool:
+        error_str = str(ctx.exception).lower()
+        for error in force_stop_error:
+            if error in error_str:
+                return True
+        return False
+
+    def recover(self, ctx: RecoveryContext) -> torch.Tensor:
+        """Force stop needs no extra recovery."""
+        return RecoveryStatus.SUCCESS
+
+class NetworkHandler(RecoveryHandler):
+
+    def can_handle(self, ctx: RecoveryContext) -> bool:
+        error_str = str(ctx.exception).lower()
+        for error in network_error:
+            if error in error_str:
+                ctx.fault_queue.put_nowait(FaultStatus.NETWORK_ERR)
+                return True
+        return False
+
+    def recover(self, ctx: RecoveryContext) -> torch.Tensor:
+        """Network errors need no extra recovery."""
+        return RecoveryStatus.SUCCESS
+
+class UCEHandler(RecoveryHandler):
+    """Handler that deals with all UCE exceptions."""
+
+    def can_handle(self, ctx: RecoveryContext) -> bool:
+        """Check whether the exception is a UCE error and enqueue the status if so."""
+        error_str = str(ctx.exception).lower()
+        for error in uce_error:
+            if error in error_str:
+                ctx.fault_queue.put_nowait(FaultStatus.UCE_ERR)
+                return True
+        return False
+
+    def recover(self, ctx: RecoveryContext) -> torch.Tensor:
+        """Handle a UCE error: classify the concrete type and run the matching recovery."""
+        # 1. Classify the UCE type
+        uce_result = self.classify_uce_type(ctx)
+        recovery_statuses = []
+        # 2. Run the recovery strategy for each detected type
+        for uce_type, layer_names in uce_result:
+            if uce_type == UCEType.KVCACHE_UCE:
+                recovery_statuses.append(self._recover_kv_cache_uce(ctx, layer_names))
+            elif uce_type == UCEType.WEIGHTS_UCE:
+                recovery_statuses.append(self._recover_weight_uce(ctx, layer_names))
+            elif uce_type == UCEType.ACTIVATION_UCE:
+                recovery_statuses.append(self._recover_activation_uce(ctx))
+            else:
+                logger.error(f"UCEHandler: Unknown UCE type: {uce_type}")
+                recovery_statuses.append(RecoveryStatus.FAILED)
+        if any(torch.equal(status, RecoveryStatus.FAILED) for status in recovery_statuses):
+            return RecoveryStatus.FAILED
+        return RecoveryStatus.SUCCESS
+
+    def classify_uce_type(self, ctx: RecoveryContext) -> List[Tuple[UCEType, List[str]]]:
+        try:
+            memory_block_info = ctx.memory_block_info
+            if not memory_block_info.initialized:
+                memory_block_info.initialize()
+            uce_ptrs = _get_uce_addr()
+            if not uce_ptrs:
+                logger.error("UCEHandler: No UCE addr found")
+                return [(UCEType.UNKNOWN_UCE, [])]
+            uce_results = []
+            for uce_ptr in uce_ptrs:
+                uce_type, layer_names = ctx.memory_block_info.category_address(uce_ptr)
+                uce_results.append((uce_type, layer_names))
+            return uce_results
+        except Exception as e:
+            logger.error(f"UCEHandler: Failed to classify UCE type, {e}")
+            raise RuntimeError("Failed to classify UCE type")
+
+    def _recover_weight_uce(self, ctx: RecoveryContext, layer_names: List[str]) -> torch.Tensor:
+        if not layer_names:
+            logger.error("UCEHandler: layer_names is empty")
+            return RecoveryStatus.FAILED
+
+        logger.info(f"UCEHandler: Recovering weight UCE for layer: {layer_names}")
+        # Incrementally reload the affected weights
+        original_weights_file_name = []
+        for layer_name in layer_names:
+            original_weights_file_name.extend(self.map_to_original_param(layer_name))
+        try:
+            weight_iterator = self.get_weight_iterator(ctx, original_weights_file_name)
+            loaded_weights = ctx.model.load_weights(weight_iterator)
+            # TODO: may need to check whether every required weight was actually reloaded
+            return RecoveryStatus.SUCCESS
+        except Exception as e:
+            logger.error(f"UCEHandler: Weight reload failed: {e}")
+            return RecoveryStatus.FAILED
+
+    def _recover_kv_cache_uce(self, ctx: RecoveryContext, layer_names: List[str]) -> torch.Tensor:
+        """Recover from a KV cache UCE error."""
+        level = ctx.level
+
+        if level == FaultToleranceLevel.BASIC.value:
+            logger.warning("UCEHandler: KV Cache UCE in BASIC level, aborting recovery")
+            return RecoveryStatus.FAILED
+
+        if level == FaultToleranceLevel.FULL.value:
+            try:
+                # TODO: KV cache recovery for FULL level is not implemented yet
+                pass
+            except Exception as e:
+                logger.error(f"UCEHandler: KV Cache recovery failed: {e}")
+                return RecoveryStatus.FAILED
+
+        logger.warning(f"UCEHandler: Unsupported fault tolerance level: {level}")
+        return RecoveryStatus.FAILED
+
+    def _recover_activation_uce(self, ctx: RecoveryContext) -> torch.Tensor:
+        """Recover from an activation UCE error."""
+        logger.info("UCEHandler: Activation UCE detected, no special recovery needed")
+        # Activation UCE needs no special recovery; return success directly
+        return RecoveryStatus.SUCCESS
+
+    def _load_mapping_config(self, config_path: str) -> Dict[str, List[Tuple[str, Any]]]:
+        with open(config_path, 'r', encoding='utf-8') as f:
+            if config_path.endswith('.yaml') or config_path.endswith('.yml'):
+                return yaml.safe_load(f.read())
+            elif config_path.endswith('.json'):
+                return json.load(f)
+            else:
+                raise ValueError("Unsupported mapping config file format")
+
+    def map_to_original_param(self, merged_name: str, mapping_config: Dict[str, List[Tuple[str, Any]]] = None) -> List[str]:
+        default_mapping = {
+            "qkv_proj": [
+                ("q_proj", "q"),
+                ("k_proj", "k"),
+                ("v_proj", "v"),
+            ],
+            "gate_up_proj": [
+                ("gate_proj", 0),
+                ("up_proj", 1)
+            ]
+        }
+        mapping = mapping_config if mapping_config is not None else default_mapping
+        original_names = []
+        for merged_param_name, mappings in mapping.items():
+            if merged_param_name in merged_name:
+                for original_param_name, _ in mappings:
+                    original_name = merged_name.replace(merged_param_name, original_param_name)
+                    original_names.append(original_name)
+                break
+        if not original_names:
+            return [merged_name]
+        return original_names
+
+    def get_weight_iterator(self, ctx: RecoveryContext, original_names: List[str]) -> Generator[tuple[str, torch.Tensor], None, None]:
+        index_file_name = os.path.join(ctx.model_or_path, SAFE_WEIGHTS_INDEX_NAME)
+        with open(index_file_name) as f:
+            weight_map = json.load(f)["weight_map"]
+        weight_files_in_index = set()
+        for original_name in original_names:
+            weight_files_in_index.add(os.path.join(ctx.model_or_path, weight_map[original_name]))
+        for st_file in weight_files_in_index:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():
+                    param = f.get_tensor(name)
+                    yield name, param
diff --git a/vllm_ascend/worker/recovery_context.py b/vllm_ascend/worker/recovery_context.py
new file mode 100644
index 00000000000..c68bb1bca26
--- /dev/null
+++ b/vllm_ascend/worker/recovery_context.py
@@ -0,0 +1,14 @@
+from queue import Queue
+from vllm_ascend.worker.memory_block_info import MemoryBlockInfo
+
+class RecoveryContext:
+    def __init__(self, model, level: int, exception: Exception, rank: int, model_or_path: str,
+                 memory_block_info: MemoryBlockInfo, fault_queue: Queue):
+        self.model = model
+        self.level = level
+        self.exception = exception
+        self.rank = rank
+        self.model_or_path = model_or_path
+        self.memory_block_info = memory_block_info
+        self.fault_queue = fault_queue
+
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index b6b6008182f..2870e7cc0c5 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -51,6 +51,7 @@
 from vllm_ascend.utils import (init_ascend_soc_version, is_enable_nz,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib)
+from vllm_ascend.worker.fault_tolerance import FaultTolerance
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 torch._dynamo.trace_rules.clear_lru_cache()  # noqa: E402
@@ -311,6 +312,12 @@ def load_model(self) -> None:
             context = nullcontext()  # type: ignore
         with context:
             self.model_runner.load_model()
+        # The fault tolerance level is taken from additional_config["fault_tolerance_level"].
+        self.fault_tolerance = FaultTolerance(
+            vllm_config=self.vllm_config,
+            model=self.model_runner.model,
+        )
+        self.execute_model = self.fault_tolerance.fault_tolerance_decorator(self.execute_model)
+        self.execute_dummy_batch = self.fault_tolerance.fault_tolerance_decorator(self.execute_dummy_batch)
 
     def compile_or_warm_up_model(self) -> None:
         # Note: need to adapt for graph mode.
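
Usage note (illustrative, not part of the patch): a minimal sketch of how the new knob would be switched on through vLLM's additional_config engine argument, which this patch reads in AscendConfig and FaultTolerance. The model name is a placeholder, and the exact engine-argument plumbing is assumed from the hunks above rather than shown by them.

    # minimal sketch: offline inference on an Ascend NPU node with fault tolerance enabled
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model path
        additional_config={
            # 0 = OFF, 1 = BASIC (weight/activation UCE + network errors),
            # 2 = FULL (additionally attempts KV cache UCE recovery)
            "fault_tolerance_level": 1,
        },
    )
    print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))

For serving, the same value would presumably be passed as JSON via the --additional-config flag, e.g. --additional-config '{"fault_tolerance_level": 1}'.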