ep/include/ep_config.hpp (4 changes: 3 additions & 1 deletion)
@@ -195,7 +195,9 @@ struct LowLatencyLayout {

// Send buffer
size_t dispatch_send_buffer_bytes =
- num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+ (static_cast<size_t>(num_experts) + 1) *
+ static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
+ num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes = num_experts *
num_max_dispatch_tokens_per_rank *
num_bytes_per_combine_msg;
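The send-buffer term grows from one staging slot per local token to an additional per-(expert, slot) chunk-staging region. A rough sizing sketch with illustrative values (the expert count, token count, and message size below are assumptions for this example, not values from this PR):

// Sizing sketch only; all three constants are assumed example values.
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t const num_experts = 256;                       // assumed
  std::size_t const num_max_dispatch_tokens_per_rank = 128;  // assumed
  std::size_t const num_bytes_per_dispatch_msg = 7200;       // assumed

  std::size_t const old_bytes =
      num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
  std::size_t const new_bytes = (num_experts + 1) *
                                num_max_dispatch_tokens_per_rank *
                                num_bytes_per_dispatch_msg;

  // Roughly 0.9 MiB before vs 226 MiB after with these example values.
  std::printf("old: %zu bytes, new: %zu bytes\n", old_bytes, new_bytes);
  return 0;
}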
ep/src/internode_ll.cu (263 changes: 201 additions & 62 deletions)
@@ -18,6 +18,8 @@ constexpr int kNumMaxWarpGroups = 16;
constexpr int kNumMaxWarpGroups = 32;
#endif

+ constexpr int kDispatchChunkSize = 8;

template <int kNumThreads>
__launch_bounds__(kNumThreads, 1) __global__
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
@@ -53,9 +55,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x,
int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
- int* next_clean, int* next_clean_second, int num_next_clean_int,
- int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk,
- int num_experts, int rank, int num_ranks, int num_warp_groups,
+ int* chunk_fill_counters, int* next_clean, int* next_clean_second,
+ int num_next_clean_int, int num_tokens,
+ int num_max_dispatch_tokens_per_rank, int num_chunks_per_expert,
+ int num_topk, int num_experts, int rank, int num_ranks, int num_warp_groups,
int num_warps_per_group, bool round_scale, int phases,
uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
int max_nvl_peers, int low_latency_buffer_idx,
Expand Down Expand Up @@ -94,6 +97,12 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
size_t const num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

+ auto const rdma_x_uint8 = static_cast<uint8_t*>(rdma_x);
+ auto const rdma_x_chunk_uint8 =
+ rdma_x_uint8 +
+ static_cast<size_t>(num_max_dispatch_tokens_per_rank) * num_bytes_per_msg;
+ auto const rdma_x_chunk_int4 = reinterpret_cast<int4*>(rdma_x_chunk_uint8);

// Expert counts
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
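For orientation, the pointer math above implies a two-region layout for rdma_x: the first num_max_dispatch_tokens_per_rank message slots stage individual tokens (a 16-byte int4 header holding the source token index, then hidden_bytes of payload, then the float scales), and the region behind them holds one chunk-staging slot per (expert, dispatch slot) pair. A minimal hypothetical helper that reproduces the same offset computation, mirroring rdma_x_chunk_uint8 above and chunk_slot_linear further down (not part of the kernel):

#include <cstddef>
#include <cstdint>

// Hypothetical sketch only; reproduces the chunk-staging offset math.
__device__ inline uint8_t* chunk_msg_ptr(uint8_t* rdma_x_uint8,
                                         size_t num_bytes_per_msg,
                                         int num_max_dispatch_tokens_per_rank,
                                         int dst_expert_idx, int slot_idx) {
  // Skip the per-token staging slots at the front of rdma_x.
  auto* chunk_base = rdma_x_uint8 +
                     static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
                         num_bytes_per_msg;
  // One chunk-staging slot per (expert, dispatch slot) pair.
  auto const linear =
      static_cast<size_t>(dst_expert_idx) *
          static_cast<size_t>(num_max_dispatch_tokens_per_rank) +
      static_cast<size_t>(slot_idx);
  return chunk_base + linear * num_bytes_per_msg;
}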

@@ -121,12 +130,13 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
auto const x_int4 =
static_cast<int4 const*>(x) + token_idx * hidden_bf16_int4;
- auto const rdma_x_src_idx = reinterpret_cast<int*>(
- static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
- auto const rdma_x_vec = reinterpret_cast<vec_t*>(
- reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
+ auto const rdma_x_token_uint8 =
+ rdma_x_uint8 + token_idx * num_bytes_per_msg;
+ auto const rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_token_uint8);
+ auto const rdma_x_vec =
+ reinterpret_cast<vec_t*>(rdma_x_token_uint8 + sizeof(int4));
auto const rdma_x_scales = reinterpret_cast<float*>(
- reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
+ rdma_x_token_uint8 + sizeof(int4) + hidden_bytes);

// Overlap top-k index read and source token index writes
auto dst_expert_idx =
@@ -188,7 +198,7 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
}
sync_barrier_1(num_threads);

- // Issue IBGDA sends
+ // Issue IBGDA sends in chunks
if (dst_expert_idx >= 0) {
int slot_idx =
lane_id == 0
@@ -197,40 +207,85 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
slot_idx = __shfl_sync(WARP_MASK, slot_idx, 0);
auto const dst_rank = dst_expert_idx / num_local_experts;
auto const dst_expert_local_idx = dst_expert_idx % num_local_experts;
- auto const src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
- auto const dst_ptr =
- reinterpret_cast<uint64_t>(rdma_recv_x) +
- dst_expert_local_idx * num_ranks *
- num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
- rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
- slot_idx * num_bytes_per_msg;
- auto const dst_p2p_ptr =
- ipc_rdma_base_ptrs
- ? uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank,
- dst_rank, max_nvl_peers, 0)
- : 0;
- if (dst_p2p_ptr == 0) {
- __threadfence_system();
- uccl::nvshmemi_ibgda_put_nbi_warp(
- dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
- src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
- num_bytes_per_msg, dst_rank,
- /*warp_id=*/dst_expert_local_idx, // NOTE(Yang): for selecting
- // rb.
- lane_id, slot_idx, d2h_channel_addrs, num_d2h_channel_addrs,
- false, low_latency_buffer_idx);
- } else {
- // Intra-node: use direct memory copy via IPC
- auto const* src_int4_ptr = reinterpret_cast<int4 const*>(src_ptr);
- auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
- UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr,
- src_int4_ptr, ld_nc_global, st_na_global);
- }
- // Increase counter after finishing
+ auto const token_msg_int4 =
+ reinterpret_cast<int4 const*>(rdma_x_src_idx);
+ auto const chunk_slot_linear =
+ static_cast<size_t>(dst_expert_idx) *
+ static_cast<size_t>(num_max_dispatch_tokens_per_rank) +
+ static_cast<size_t>(slot_idx);
+ auto* chunk_msg_int4 =
+ rdma_x_chunk_int4 + chunk_slot_linear * num_int4_per_msg;
+ UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, chunk_msg_int4,
+ token_msg_int4, ld_nc_global, st_na_global);
__syncwarp();
- lane_id == 0 ? atomic_add_release_global(
- atomic_finish_counter_per_expert + dst_expert_idx, 1)
- : 0;
+ __threadfence_system();
+
+ int chunk_id = slot_idx / kDispatchChunkSize;
+ int chunk_index = dst_expert_idx * num_chunks_per_expert + chunk_id;
+ int prev_fill =
+ lane_id == 0 ? atomicAdd(chunk_fill_counters + chunk_index, 1) : 0;
+ prev_fill = __shfl_sync(WARP_MASK, prev_fill, 0);
+ bool chunk_ready = (prev_fill + 1) == kDispatchChunkSize;
+
+ if (chunk_ready) {
+ int const chunk_base_slot = chunk_id * kDispatchChunkSize;
+ size_t const chunk_bytes =
+ static_cast<size_t>(kDispatchChunkSize) * num_bytes_per_msg;
+ auto const chunk_src_ptr =
+ reinterpret_cast<uint64_t>(rdma_x_chunk_uint8) +
+ (static_cast<uint64_t>(dst_expert_idx) *
+ static_cast<uint64_t>(num_max_dispatch_tokens_per_rank) +
+ static_cast<uint64_t>(chunk_base_slot)) *
+ num_bytes_per_msg;
+ auto const chunk_dst_ptr =
+ reinterpret_cast<uint64_t>(rdma_recv_x) +
+ static_cast<uint64_t>(dst_expert_local_idx) * num_ranks *
+ static_cast<uint64_t>(num_max_dispatch_tokens_per_rank) *
+ num_bytes_per_msg +
+ static_cast<uint64_t>(rank) *
+ static_cast<uint64_t>(num_max_dispatch_tokens_per_rank) *
+ num_bytes_per_msg +
+ static_cast<uint64_t>(chunk_base_slot) * num_bytes_per_msg;
+
+ uint64_t chunk_dst_p2p_ptr = 0;
+ if (ipc_rdma_base_ptrs && lane_id == 0)
+ chunk_dst_p2p_ptr =
+ uccl::get_ipc_p2p_ptr(chunk_dst_ptr, ipc_rdma_base_ptrs, rank,
+ dst_rank, max_nvl_peers, 0);
+ auto chunk_dst_p2p_lo = static_cast<uint32_t>(chunk_dst_p2p_ptr);
+ auto chunk_dst_p2p_hi =
+ static_cast<uint32_t>(chunk_dst_p2p_ptr >> 32);
+ chunk_dst_p2p_lo = __shfl_sync(WARP_MASK, chunk_dst_p2p_lo, 0);
+ chunk_dst_p2p_hi = __shfl_sync(WARP_MASK, chunk_dst_p2p_hi, 0);
+ chunk_dst_p2p_ptr = (static_cast<uint64_t>(chunk_dst_p2p_hi) << 32) |
+ chunk_dst_p2p_lo;
+
+ if (chunk_dst_p2p_ptr == 0) {
+ __threadfence_system();
+ uccl::nvshmemi_ibgda_put_nbi_warp(
+ chunk_dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+ chunk_src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+ chunk_bytes, dst_rank,
+ /*warp_id=*/dst_expert_local_idx, lane_id, chunk_base_slot,
+ d2h_channel_addrs, num_d2h_channel_addrs, false,
+ low_latency_buffer_idx);
+ } else {
+ auto const* chunk_src_int4 =
+ reinterpret_cast<int4 const*>(chunk_src_ptr);
+ auto* chunk_dst_int4 = reinterpret_cast<int4*>(chunk_dst_p2p_ptr);
+ UNROLLED_WARP_COPY(
+ 8, lane_id, num_int4_per_msg * kDispatchChunkSize,
+ chunk_dst_int4, chunk_src_int4, ld_nc_global, st_na_global);
+ }
+
+ __syncwarp();
+ if (lane_id == 0) {
+ st_release_sys_global(chunk_fill_counters + chunk_index, 0);
+ atomic_add_release_global(
+ atomic_finish_counter_per_expert + dst_expert_idx,
+ kDispatchChunkSize);
+ }
+ }
}
}
} else if (warp_id == num_warps - 1) {
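The send path above batches outgoing tokens per destination expert: each warp copies its token into the chunk-staging slot, bumps the expert's chunk fill counter, and the warp whose increment completes a chunk of kDispatchChunkSize tokens issues one RDMA put (or one IPC copy) for the whole chunk, then resets the counter and credits the finish counter. A minimal sketch of the claim step under those assumptions (WARP_MASK and kDispatchChunkSize as used above):

// Sketch only: returns true for the one warp that should send the chunk.
__device__ inline bool claim_chunk_send(int* chunk_fill_counters,
                                        int chunk_index, int lane_id) {
  int prev_fill = 0;
  if (lane_id == 0)
    prev_fill = atomicAdd(chunk_fill_counters + chunk_index, 1);
  prev_fill = __shfl_sync(WARP_MASK, prev_fill, 0);
  // The warp contributing the kDispatchChunkSize-th token owns the send.
  return (prev_fill + 1) == kDispatchChunkSize;
}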
@@ -269,11 +324,87 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
#pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
+ sum = __shfl_sync(WARP_MASK, sum, 0);
if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
atomic_add_release_global(atomic_finish_counter_per_expert + i,
FINISHED_SUM_TAG - sum);
}

+ if (sum > 0) {
+ int remainder = sum % kDispatchChunkSize;
+ if (remainder != 0) {
+ int const chunk_id = sum / kDispatchChunkSize;
+ int const chunk_index = i * num_chunks_per_expert + chunk_id;
+ if (lane_id == 0) {
+ while (ld_acquire_global(atomic_counter_per_expert + i) < sum)
+ ;
+ while (ld_acquire_global(chunk_fill_counters + chunk_index) <
+ remainder)
+ ;
+ }
+
+ auto const chunk_base_slot = chunk_id * kDispatchChunkSize;
+ size_t const chunk_bytes =
+ static_cast<size_t>(remainder) * num_bytes_per_msg;
+ auto const chunk_src_ptr =
+ reinterpret_cast<uint64_t>(rdma_x_chunk_uint8) +
+ (static_cast<uint64_t>(i) *
+ static_cast<uint64_t>(num_max_dispatch_tokens_per_rank) +
+ static_cast<uint64_t>(chunk_base_slot)) *
+ num_bytes_per_msg;
+ auto const dst_rank = i / num_local_experts;
+ auto const dst_expert_local_idx = i % num_local_experts;
+ auto const chunk_dst_ptr =
+ reinterpret_cast<uint64_t>(rdma_recv_x) +
+ static_cast<uint64_t>(dst_expert_local_idx) *
+ static_cast<uint64_t>(num_ranks) *
+ static_cast<uint64_t>(num_max_dispatch_tokens_per_rank) *
+ num_bytes_per_msg +
+ static_cast<uint64_t>(rank) *
+ static_cast<uint64_t>(num_max_dispatch_tokens_per_rank) *
+ num_bytes_per_msg +
+ static_cast<uint64_t>(chunk_base_slot) * num_bytes_per_msg;
+
+ uint64_t chunk_dst_p2p_ptr = 0;
+ if (ipc_rdma_base_ptrs && lane_id == 0)
+ chunk_dst_p2p_ptr =
+ uccl::get_ipc_p2p_ptr(chunk_dst_ptr, ipc_rdma_base_ptrs, rank,
+ dst_rank, max_nvl_peers, 0);
+ auto chunk_dst_p2p_lo = static_cast<uint32_t>(chunk_dst_p2p_ptr);
+ auto chunk_dst_p2p_hi =
+ static_cast<uint32_t>(chunk_dst_p2p_ptr >> 32);
+ chunk_dst_p2p_lo = __shfl_sync(WARP_MASK, chunk_dst_p2p_lo, 0);
+ chunk_dst_p2p_hi = __shfl_sync(WARP_MASK, chunk_dst_p2p_hi, 0);
+ chunk_dst_p2p_ptr = (static_cast<uint64_t>(chunk_dst_p2p_hi) << 32) |
+ chunk_dst_p2p_lo;
+
+ if (chunk_dst_p2p_ptr == 0) {
+ __threadfence_system();
+ uccl::nvshmemi_ibgda_put_nbi_warp(
+ chunk_dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+ chunk_src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+ chunk_bytes, dst_rank,
+ /*warp_id=*/dst_expert_local_idx, lane_id, chunk_base_slot,
+ d2h_channel_addrs, num_d2h_channel_addrs, false,
+ low_latency_buffer_idx);
+ } else {
+ auto const* chunk_src_int4 =
+ reinterpret_cast<int4 const*>(chunk_src_ptr);
+ auto* chunk_dst_int4 = reinterpret_cast<int4*>(chunk_dst_p2p_ptr);
+ UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg * remainder,
+ chunk_dst_int4, chunk_src_int4, ld_nc_global,
+ st_na_global);
+ }
+
+ __syncwarp();
+ if (lane_id == 0) {
+ st_release_sys_global(chunk_fill_counters + chunk_index, 0);
+ atomic_add_release_global(atomic_finish_counter_per_expert + i,
+ remainder);
+ }
+ }
+ }
}
}
__syncthreads();
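The flush path above handles experts whose token count is not a multiple of kDispatchChunkSize: the counting warp waits until every token has claimed a slot and the tail chunk holds exactly the remainder, then sends only those messages. Worked numbers, illustrative only:

// Illustrative remainder arithmetic: 21 tokens for one expert, chunk size 8.
int const sum = 21;                               // assumed per-expert count
int const full_chunks = sum / kDispatchChunkSize; // 2, sent by the token path
int const remainder = sum % kDispatchChunkSize;   // 5, sent by this flush path
int const tail_chunk_id = full_chunks;            // the flushed chunk index, 2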
@@ -521,31 +652,39 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto atomic_counter_per_expert = static_cast<int*>(workspace);
auto atomic_finish_counter_per_expert =
atomic_counter_per_expert + num_experts;
- EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
+ int const num_chunks_per_expert =
+ ceil_div(num_max_dispatch_tokens_per_rank, kDispatchChunkSize);
+ auto chunk_fill_counters = atomic_finish_counter_per_expert + num_experts;
+ auto const required_workspace_ints =
+ static_cast<size_t>(num_experts) *
+ static_cast<size_t>(2 + num_chunks_per_expert);
+ EP_HOST_ASSERT(required_workspace_ints * sizeof(int) <=
+ static_cast<size_t>(NUM_WORKSPACE_BYTES));

// FP8 checks
if (use_ue8m0)
EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");

- #define DISPATCH_LAUNCH_CASE(hidden) \
- { \
- auto dispatch_func = dispatch<false, false, hidden>; \
- if (use_fp8 and not use_ue8m0) \
- dispatch_func = dispatch<true, false, hidden>; \
- if (use_fp8 and use_ue8m0) dispatch_func = dispatch<true, true, hidden>; \
- LAUNCH_KERNEL( \
- &cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, \
- packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
- cumulative_local_expert_recv_stats, dispatch_wait_recv_cost_stats, \
- rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
- atomic_counter_per_expert, atomic_finish_counter_per_expert, \
- next_clean, next_clean_second, num_next_clean_int, num_tokens, \
- num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, \
- num_ranks, num_warp_groups, num_warps_per_group, round_scale, phases, \
- d2h_channel_addrs, num_d2h_channel_addrs, max_nvl_peers, \
- low_latency_buffer_idx, ipc_rdma_base_ptrs, rdma_buffer_ptr, \
- atomic_buffer_ptr, rdma_recv_count_internode); \
- } \
+ #define DISPATCH_LAUNCH_CASE(hidden) \
+ { \
+ auto dispatch_func = dispatch<false, false, hidden>; \
+ if (use_fp8 and not use_ue8m0) \
+ dispatch_func = dispatch<true, false, hidden>; \
+ if (use_fp8 and use_ue8m0) dispatch_func = dispatch<true, true, hidden>; \
+ LAUNCH_KERNEL( \
+ &cfg, dispatch_func, packed_recv_x, packed_recv_x_scales, \
+ packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
+ cumulative_local_expert_recv_stats, dispatch_wait_recv_cost_stats, \
+ rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
+ atomic_counter_per_expert, atomic_finish_counter_per_expert, \
+ chunk_fill_counters, next_clean, next_clean_second, \
+ num_next_clean_int, num_tokens, num_max_dispatch_tokens_per_rank, \
+ num_chunks_per_expert, num_topk, num_experts, rank, num_ranks, \
+ num_warp_groups, num_warps_per_group, round_scale, phases, \
+ d2h_channel_addrs, num_d2h_channel_addrs, max_nvl_peers, \
+ low_latency_buffer_idx, ipc_rdma_base_ptrs, rdma_buffer_ptr, \
+ atomic_buffer_ptr, rdma_recv_count_internode); \
+ } \
break
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
EP_HOST_ASSERT(num_warps * WARP_SIZE <= MAX_NTHREADS);
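For orientation, the host-side pointer arithmetic above carves the int workspace into three regions; the figures below are a worked example with assumed sizes, not values taken from this PR:

// Workspace layout (in ints), with E = num_experts and
// C = num_chunks_per_expert = ceil_div(num_max_dispatch_tokens_per_rank, 8):
//   [0, E)             atomic_counter_per_expert
//   [E, 2E)            atomic_finish_counter_per_expert
//   [2E, (2 + C) * E)  chunk_fill_counters
// Worked example (assumed): E = 256, num_max_dispatch_tokens_per_rank = 128
//   C = ceil(128 / 8) = 16
//   required ints = 256 * (2 + 16) = 4608  ->  18432 bytes, which must fit
//   within NUM_WORKSPACE_BYTES.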