
Conversation

@Peter-Lu-22

Provides token-level re-inference capability for vLLM on the Ascend platform. Currently, only token recomputation is supported, and only in network-link-failure scenarios under TP parallelism.

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant fault tolerance and token-level re-inference mechanism for vLLM on the Ascend platform. The changes include new components for fault detection, a recovery handler chain, and distributed coordination for error recovery. My review has identified several critical and high-severity issues that should be addressed. These include critical bugs such as incorrect assertion syntax which bypasses important safety checks, and missing resource cleanup logic on failure paths which will lead to resource leaks. Additionally, there are high-severity design concerns, including the use of internal, non-public APIs, hardcoded model-specific configurations that limit the feature's applicability, and a hardcoded fault-tolerance level that restricts functionality. Addressing these points will improve the robustness, maintainability, and generality of this new fault tolerance system.

Comment on lines +36 to +42
assert(
    torch.distributed.is_initialized()
),"Default torch process group must be initialized"

assert(
    torch.distributed.is_gloo_available()
),"Gloo process group must be available"

critical

The assert statements on lines 36 and 40 use tuple syntax assert(condition, message), which is a common pitfall in Python. A non-empty tuple always evaluates to True, so these assertions will never fail, even if the conditions torch.distributed.is_initialized() or torch.distributed.is_gloo_available() are false. This is a critical issue as it bypasses important validations, potentially allowing the program to proceed in an invalid state and causing hard-to-debug errors later.

Suggested change
assert(
torch.distributed.is_initialized()
),"Default torch process group must be initialized"
assert(
torch.distributed.is_gloo_available()
),"Gloo process group must be available"
assert torch.distributed.is_initialized(), "Default torch process group must be initialized"
assert torch.distributed.is_gloo_available(), "Gloo process group must be available"

Comment on lines +92 to +103
elif torch.equal(ft_action,FaultAction.RAISE_EXCEPTION):
    logger.info(f"Raise exception at rank {self.rank}")
    # TODO: Need to clear cache for current batch and destroy all group
    raise e
elif torch.equal(ft_action,FaultAction.RETURN):
    logger.info(f"Abort current batch at rank {self.rank}")
    # TODO: Need to clear cache for current batch and destroy all group
    return None
else:
    # TODO: Need to clear cache for current batch and destroy all group
    logger.info(f"Unknown fault action found at rank {self.rank} ")
    raise e

critical

The fault_tolerance_decorator catches exceptions but does not clean up resources for the failed batch when the recovery action is RAISE_EXCEPTION or RETURN. The TODO comments indicate this is a known missing piece. Without freeing the KV cache blocks and other resources associated with the aborted requests, this will lead to resource leaks, eventually causing the system to hang or crash when it runs out of memory or cache blocks. This is a critical issue that must be addressed.
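
A minimal sketch of one possible shape for the abort path, assuming a hypothetical _release_batch_resources hook (neither the hook nor the AbortHandlerSketch class below is part of this PR; the stand-in FaultAction mirrors the constants quoted later in this review):

import logging

import torch

logger = logging.getLogger(__name__)

class FaultAction:  # stand-in mirroring the PR's constants, for illustration only
    RAISE_EXCEPTION = torch.tensor([0])
    RETURN = torch.tensor([1])
    RECOMPUTE = torch.tensor([2])

class AbortHandlerSketch:
    """Illustrative only; not the FaultTolerance class from this PR."""

    def __init__(self, rank: int):
        self.rank = rank

    def _release_batch_resources(self) -> None:
        # Hypothetical hook: free the KV cache blocks held by the aborted
        # batch and destroy the process groups created for recovery, so an
        # aborted batch cannot leak memory or communicator state.
        pass

    def handle_abort(self, ft_action: torch.Tensor, exc: Exception):
        # Run cleanup on every non-recompute exit path before leaving.
        self._release_batch_resources()
        if torch.equal(ft_action, FaultAction.RETURN):
            logger.info(f"Abort current batch at rank {self.rank}")
            return None
        logger.info(f"Raise exception at rank {self.rank}")
        raise exc

Routing both RAISE_EXCEPTION and RETURN through the same cleanup call keeps the two abort paths consistent, leaving only RECOMPUTE holding on to the batch's resources.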

Comment on lines +15 to +48
class FaultStatus(Enum):
    """
    Fault status which fault_tolerance put into fault_queue
    """
    ACTIVE = torch.tensor([0])
    UCE_ERR = torch.tensor([1])
    FORCE_STOP = torch.tensor([2])
    NETWORK_ERR = torch.tensor([3])

class FaultCommand:
    """
    Fault command which rank 0 broadcast in fault_aware
    """
    INIT_CMD = torch.tensor([0])
    SILENCE_CMD = torch.tensor([1])
    STOP_DEVICE_CMD = torch.tensor([2])

class UCEType(Enum):
    """
    Specific uce type for HBM UCE
    """
    WEIGHTS_UCE = "WEIGHTS_UCE"
    KVCACHE_UCE = "KVCACHE_UCE"
    ACTIVATION_UCE = "ACTIVATION_UCE"
    UNKNOWN_UCE = "UNKNOWN_UCE"

class RecoveryStatus:
    SUCCESS = torch.tensor([0])
    FAILED = torch.tensor([1])

class FaultAction:
    RAISE_EXCEPTION = torch.tensor([0])
    RETURN = torch.tensor([1])
    RECOMPUTE = torch.tensor([2])

high

The classes FaultStatus, FaultCommand, RecoveryStatus, and FaultAction use torch.tensor objects as values for constants. This is unconventional and can lead to subtle bugs related to device placement or tensor properties. It is also inconsistent, as FaultStatus is an Enum while the others are plain classes. A better practice is to use simple integer values in enums (e.g., using enum.IntEnum) and create tensors from them only when needed for distributed communication. This would improve readability, maintainability, and avoid potential pitfalls with using tensors as constants.

from enum import IntEnum

class FaultStatus(IntEnum):
    """
    Fault status which fault_tolerance put into fault_queue
    """
    ACTIVE = 0
    UCE_ERR = 1
    FORCE_STOP = 2
    NETWORK_ERR = 3

class FaultCommand(IntEnum):
    """
    Fault command which rank 0 broadcast in fault_aware
    """
    INIT_CMD = 0
    SILENCE_CMD = 1
    STOP_DEVICE_CMD = 2

class UCEType(Enum):
    """
    Specific uce type for HBM UCE
    """
    WEIGHTS_UCE = "WEIGHTS_UCE"
    KVCACHE_UCE = "KVCACHE_UCE"
    ACTIVATION_UCE = "ACTIVATION_UCE"
    UNKNOWN_UCE = "UNKNOWN_UCE"

class RecoveryStatus(IntEnum):
    SUCCESS = 0
    FAILED = 1

class FaultAction(IntEnum):
    RAISE_EXCEPTION = 0
    RETURN = 1
    RECOMPUTE = 2
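
With integer enums, a tensor would then be created only at the communication boundary. For example, reusing the FaultStatus enum above (and assuming a gloo process group is already initialized, as the PR's own asserts require):

import torch
import torch.distributed as dist

# Materialize the enum value as a tensor only when it has to cross the wire.
status_tensor = torch.tensor([int(FaultStatus.UCE_ERR)], dtype=torch.int64)
dist.broadcast(status_tensor, src=0)
received = FaultStatus(status_tensor.item())  # back to a plain enum member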

memory_block_info = ctx.memory_block_info
if not memory_block_info.initialized:
    memory_block_info.initialize()
uce_ptrs = _get_uce_addr()

high

The code uses _get_uce_addr (imported on line 12), which is an internal API from torch_npu.npu.utils as indicated by the leading underscore. Relying on internal APIs is risky because they are not part of the public contract and can be changed or removed without notice in future versions of torch_npu. This could break the fault tolerance feature unexpectedly upon a library update.
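
One way to contain that risk is a guarded import with an explicit fallback. A sketch, assuming only that _get_uce_addr currently lives in torch_npu.npu.utils and may move or disappear in a later torch_npu release:

# Guard the internal torch_npu helper so a future library update degrades
# gracefully instead of failing at import time.
try:
    from torch_npu.npu.utils import _get_uce_addr
except ImportError:  # private symbol; may change between torch_npu releases
    _get_uce_addr = None

def get_uce_addresses():
    """Return the UCE fault addresses reported by torch_npu, or an empty
    list when the internal helper is unavailable (UCE recovery disabled)."""
    if _get_uce_addr is None:
        return []
    return _get_uce_addr()

A longer-term fix would be to request a public torch_npu API for this lookup, or to pin the supported torch_npu versions explicitly.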

Comment on lines +187 to +209
def map_to_original_param(self,merged_name:str,mapping_config:Dict[str,List[Tuple[str,Any]]] = None) -> List[str]:
default_mapping={
"qkv_proj":[
("q_proj","q"),
("k_proj","k"),
("v_proj","v"),
],
"gate_up_proj":[
("gate_proj",0),
("up_proj",1)
]
}
mapping = mapping_config if mapping_config is not None else default_mapping
original_names = []
for merged_param_name,mappings in mapping.items():
if merged_param_name in merged_name:
for original_param_name,_ in mappings:
original_name = merged_name.replace(merged_param_name,original_param_name)
original_names.append(original_name)
break
if not original_names:
return [merged_name]
return original_names

high

The map_to_original_param method contains a hardcoded default_mapping for converting merged parameter names (like qkv_proj) back to their original names. This mapping is specific to certain model architectures (e.g., Llama-style models) and will cause the weight recovery feature to fail for models that use different naming conventions. The presence of the unused _load_mapping_config function suggests this was intended to be configurable. To make this feature more general and robust, the mapping should be loaded from a model-specific configuration.
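
One possible direction, sketched below, is to load the mapping from a per-model JSON file and fall back to the current Llama-style defaults when no file is provided. The file name weight_mapping.json and the loader are hypothetical, not part of this PR:

import json
import os
from typing import Any, Dict, List, Tuple

# Llama-style defaults, mirroring the table currently hardcoded in the PR.
_DEFAULT_MAPPING: Dict[str, List[Tuple[str, Any]]] = {
    "qkv_proj": [("q_proj", "q"), ("k_proj", "k"), ("v_proj", "v")],
    "gate_up_proj": [("gate_proj", 0), ("up_proj", 1)],
}

def load_mapping_config(model_path: str) -> Dict[str, List[Tuple[str, Any]]]:
    """Load a merged-to-original parameter mapping for a model directory,
    falling back to the Llama-style defaults when no mapping file exists."""
    config_file = os.path.join(model_path, "weight_mapping.json")  # hypothetical file name
    if not os.path.isfile(config_file):
        return _DEFAULT_MAPPING
    with open(config_file, encoding="utf-8") as f:
        raw = json.load(f)
    # JSON stores each (original_name, shard_id) pair as a list; convert back.
    return {merged: [tuple(pair) for pair in pairs] for merged, pairs in raw.items()}

map_to_original_param could then receive this dictionary via its mapping_config argument instead of relying on its built-in default, which would also give the currently unused _load_mapping_config a purpose.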

self.fault_tolerance = FaultTolerance(
    vllm_config=self.vllm_config,
    model=self.model_runner.model,
    level=FaultToleranceLevel.BASIC

high

The fault tolerance level is hardcoded to FaultToleranceLevel.BASIC. This prevents users from enabling FULL fault tolerance, which is defined to include features like KV cache UCE recovery. To allow users to leverage all fault tolerance capabilities, this level should be made configurable, for example, through VllmConfig or AscendConfig.

Suggested change
level=FaultToleranceLevel.BASIC
level=self.vllm_config.ascend_config.fault_tolerance_level

@wangxiyuan added the hold-on (The PR should be hold-on but no need to release) label on Dec 2, 2025

Labels

hold-on (The PR should be hold-on but no need to release), module:core
