22 changes: 22 additions & 0 deletions lightllm/common/fused_moe/grouped_fused_moe_ep.py
@@ -5,6 +5,7 @@
import triton.language as tl
from typing import Any, Callable, Dict, Optional, Tuple
import torch.distributed as dist
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
@@ -142,6 +143,15 @@ def fused_experts_impl(

# scatter
all_tokens = sum(num_recv_tokens_per_expert_list)  # calculate the total number of padded tokens

if get_env_start_args().enable_ep_fake_balance:
    rank = dist.get_rank()
    if rank == 0:
        logger.info(
            f"prefill, [{rank}], all_tokens = {all_tokens}, "
            f"num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}"
        )

# gather_out shape [receive_num_tokens, hidden]
gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
if all_tokens > 0:
@@ -219,6 +229,18 @@ def fused_experts_impl(
async_finish=False,
return_recv_hook=False,
)

# NOTE: when the decode CUDA graph is enabled, the logger cannot be called, so this
# logging is only available with --disable_cudagraph.
args = get_env_start_args()
if args.enable_ep_fake_balance and args.disable_cudagraph:
    rank = dist.get_rank()
    all_tokens = sum(masked_m)
    if rank == 0:
        logger.info(
            f"decode, [{rank}], all_tokens = {all_tokens}, "
            f"expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}"
        )

# deepgemm
gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
# low latency combine
9 changes: 9 additions & 0 deletions lightllm/common/fused_moe/topk_select.py
@@ -21,6 +21,8 @@
import torch
from lightllm.utils.sgl_utils import sgl_ops
from lightllm.utils.light_utils import light_ops
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.balance_utils import BalancedTensor
from typing import Callable, List, Optional, Tuple
from lightllm.common.fused_moe.softmax_topk import softmax_topk

@@ -227,4 +229,11 @@ def select_experts(
hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
)

# Enable EP fake balance
if get_env_start_args().enable_ep_fake_balance:
    num_tokens, num_experts = router_logits.shape
    balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
    balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(num_tokens)
    topk_ids.copy_(balance_topk_ids)

return topk_weights, topk_ids
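
As an aside, here is a minimal illustrative sketch (not part of the PR; assumes a CUDA device, since BalancedTensor allocates on "cuda") of what the override above amounts to: the router-computed topk_weights are left untouched, and only the expert ids are overwritten in place.

import torch
from lightllm.utils.balance_utils import BalancedTensor

num_tokens, num_experts, top_k = 4, 16, 2
router_logits = torch.randn(num_tokens, num_experts, device="cuda")
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)

# Fake balance: keep topk_weights as computed, replace the expert ids in place
# (Tensor.copy_ casts the integer dtype as needed).
balanced = BalancedTensor(num_experts=num_experts, num_selected=top_k)
topk_ids.copy_(balanced.get_balance_topk_ids(num_tokens))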
3 changes: 3 additions & 0 deletions lightllm/server/api_cli.py
@@ -333,6 +333,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
)

parser.add_argument("--enable_ep_fake_balance", action="store_true", help="Enable the fake balance of the EP mode")

parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")

parser.add_argument(
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -76,6 +76,7 @@ class StartArgs:
visual_dp: int = field(default=1)
visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
enable_monitor_auth: bool = field(default=False)
enable_ep_fake_balance: bool = field(default=False)
disable_cudagraph: bool = field(default=False)
graph_max_batch_size: int = field(default=256)
graph_split_batch_size: int = field(default=32)
71 changes: 71 additions & 0 deletions lightllm/utils/balance_utils.py
@@ -0,0 +1,71 @@
import torch
import os

import threading

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def singleton_threadsafe(cls):
    instances = {}
    lock = threading.Lock()

    def get_instance(*args, **kwargs):
        # A key that includes the arguments is needed for parameter-dependent singletons.
        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
        key = (cls, args, frozenset(kwargs.items()))
        with lock:
            if key not in instances:
                instances[key] = cls(*args, **kwargs)
            return instances[key]

    return get_instance


@singleton_threadsafe
class BalancedTensor:
    def __init__(self, num_experts=256, num_selected=8):
        self.balanced_tensors = {}
        self.num_experts = num_experts
        self.num_selected = num_selected

    def generate_balanced_tensor(self, num_tokens):
        # Evenly distribute num_tokens across num_selected experts out of num_experts.
        # The num_selected experts chosen for a single token must be distinct.
        # Performance is not critical here, as this path is only used in special scenarios.
        tensor = torch.zeros((num_tokens, self.num_selected), dtype=torch.int, device="cuda")
        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")

        for i in range(num_tokens):
            available_experts = torch.arange(self.num_experts, device="cuda")
            selected = []
            for _ in range(self.num_selected):
                current_load = expert_load[available_experts]
                min_load_indices = torch.where(current_load == current_load.min())[0]
                if len(min_load_indices) > 1:
                    # If there are multiple least-loaded experts, select one randomly
                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device="cuda").item()
                    chosen_expert_index = min_load_indices[chosen_index]
                else:
                    chosen_expert_index = min_load_indices[0]
                chosen_expert = available_experts[chosen_expert_index]
                selected.append(chosen_expert)
                # Remove the selected expert from the list of available experts
                available_experts = torch.cat(
                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1 :]]
                )
                expert_load[chosen_expert] += 1

            tensor[i] = torch.tensor(selected, dtype=torch.int, device="cuda")

Review comment (severity: high):

The current implementation of generate_balanced_tensor is inefficient due to the use of torch.cat inside a loop. This creates a new tensor and copies data in every iteration, which can be slow for large num_tokens or num_experts. A more performant approach would be to use a boolean mask to keep track of selected experts, avoiding the expensive torch.cat operation. This can significantly reduce the overhead.

        tensor = torch.empty((num_tokens, self.num_selected), dtype=torch.int, device="cuda")
        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")

        for i in range(num_tokens):
            selected_mask = torch.zeros(self.num_experts, dtype=torch.bool, device="cuda")
            for j in range(self.num_selected):
                # Use a large value for already selected experts to exclude them
                load_view = torch.where(selected_mask, torch.iinfo(expert_load.dtype).max, expert_load)

                min_load_indices = torch.where(load_view == load_view.min())[0]

                if len(min_load_indices) > 1:
                    # If there are multiple least-loaded experts, select one randomly
                    rand_idx = torch.randint(0, len(min_load_indices), (1,), device="cuda").item()
                    chosen_expert = min_load_indices[rand_idx]
                else:
                    chosen_expert = min_load_indices[0]

                tensor[i, j] = chosen_expert
                expert_load[chosen_expert] += 1
                selected_mask[chosen_expert] = True


        return tensor

    def get_balance_topk_ids(self, num_tokens):
        if num_tokens in self.balanced_tensors:
            return self.balanced_tensors[num_tokens]

        tensor = self.generate_balanced_tensor(num_tokens)
        self.balanced_tensors[num_tokens] = tensor
        return tensor
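
For reference, a small standalone sketch (not part of the PR; assumes a CUDA device and illustrative parameter values) showing how the new utility behaves, including the per-num_tokens caching in get_balance_topk_ids and the argument-keyed singleton:

import torch
from lightllm.utils.balance_utils import BalancedTensor

bt = BalancedTensor(num_experts=8, num_selected=2)

ids = bt.get_balance_topk_ids(num_tokens=16)  # shape (16, 2), torch.int, on CUDA
assert ids.shape == (16, 2)

# Each row holds distinct expert ids, and the greedy assignment keeps the
# per-expert load even (here 16 tokens * 2 selections / 8 experts = 4 per expert).
print(torch.bincount(ids.flatten().long(), minlength=8))

# Results are cached by num_tokens, so a repeated call returns the same tensor.
assert bt.get_balance_topk_ids(16) is ids

# singleton_threadsafe keys instances by constructor arguments, so constructing
# with the same arguments returns the existing instance.
assert BalancedTensor(num_experts=8, num_selected=2) is bt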