Merged

Commits (showing changes from 22 of 32 commits)
b7c1fb1
EP: port internode_dispatch to amd.
zhenhuang12 Oct 31, 2025
05e7df8
EP: port internode_combine to amd.
zhenhuang12 Nov 3, 2025
04c6a14
ziming fix amd port
MaoZiming Nov 4, 2025
946db83
fixing wr bug
MaoZiming Nov 4, 2025
fbda03b
debugging
MaoZiming Nov 4, 2025
d6f3355
clean
MaoZiming Nov 4, 2025
1903181
Merge branch 'zm-amd-port' of https://github.com/uccl-project/uccl in…
MaoZiming Nov 4, 2025
7324fcf
checkpt
MaoZiming Nov 5, 2025
b136aaa
merge main
MaoZiming Nov 5, 2025
4d63369
adding wr_id_to_wr_ids emplace for normal mode atomics
MaoZiming Nov 5, 2025
de4abc4
EP: fix ep internode
zhenhuang12 Nov 6, 2025
49c0bab
EP: fix RDMAAndNVLForwarder copy data
zhenhuang12 Nov 11, 2025
e5b2ef3
merge main
MaoZiming Nov 11, 2025
a4d9d4b
Merge branch 'main' of https://github.com/uccl-project/uccl into zm-a…
YangZhou1997 Nov 14, 2025
026a836
Merge branch 'main' of https://github.com/uccl-project/uccl into zm-a…
YangZhou1997 Nov 14, 2025
b64412d
debugging
YangZhou1997 Nov 15, 2025
9cb051f
run on nebius
YangZhou1997 Nov 15, 2025
f19fd62
merge with main
YangZhou1997 Nov 24, 2025
3a7d093
fixing setup.py on amd
YangZhou1997 Nov 24, 2025
8c86136
add printf to internode
YangZhou1997 Nov 24, 2025
fa1d720
Merge branch 'main' into yang-amd-normal
zhenhuang12 Nov 26, 2025
b6451b4
debug internode on amd-gpu
zhenhuang12 Nov 26, 2025
a142d9e
trying to debug dispatch kernel hang issues, but fails
YangZhou1997 Nov 28, 2025
a2d9f26
merge
YangZhou1997 Nov 28, 2025
5d38ecb
EP: fix normal dispatch bug
zhenhuang12 Dec 1, 2025
33e9bb1
Merge branch 'main' into yang-amd-normal
zhenhuang12 Dec 1, 2025
97d905e
EP: format code
zhenhuang12 Dec 1, 2025
4a44245
EP: restore internode_ll.cu
zhenhuang12 Dec 1, 2025
61eba6f
EP: format code
zhenhuang12 Dec 1, 2025
ce30d31
Merge branch 'main' into yang-amd-normal
YangZhou1997 Dec 2, 2025
6687cb6
EP: modify as suggested.
zhenhuang12 Dec 2, 2025
29f331f
Merge branch 'main' into yang-amd-normal
YangZhou1997 Dec 2, 2025
2 changes: 1 addition & 1 deletion ep/bench/run_ep.sh
@@ -15,7 +15,7 @@ if [ "$MODE" = "ll" ]; then
else
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--log-dir=./logs --redirect 3 test_internode.py --num-tokens=1024 \
Member: What is --redirect 3?

--hidden=7168 --num-topk=8 --num-experts=288 --test-ll-compatibility
fi
# --log-dir=logs --redirect=3
4 changes: 2 additions & 2 deletions ep/include/common.hpp
@@ -39,7 +39,7 @@ extern bool use_ll_sl;
#define kMaxInflightNormal 8
#define kBatchSize 32
#define kIterations 40000
#define kNumProxyThs 4
#define kNumProxyThs 1
#define kTestNumGpuThPerBlock 1
#define kObjectSize 7168 // 7 KB
// #define kObjectSize 10752 // 10.5 KB
@@ -48,7 +48,7 @@
#define kMaxOutstandingRecvs 2048
#define kSenderAckQueueDepth 2048
#define kWarmupOps 10000
#define kChannelPerProxy 8
#define kChannelPerProxy 1
// TODO(MaoZiming): I tried to fit more bits, but this eats into offset and
// values.
#define kReorderingBufferSize 16 // Right now only 4 bits.
8 changes: 8 additions & 0 deletions ep/include/ep_configs.cuh
@@ -12,11 +12,19 @@
// #define ENABLE_FAST_DEBUG
#ifndef ENABLE_FAST_DEBUG
#define NUM_CPU_TIMEOUT_SECS 100
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#define NUM_TIMEOUT_CYCLES 20000000000ull
#else
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#endif
#else
#define NUM_CPU_TIMEOUT_SECS 10
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#define NUM_TIMEOUT_CYCLES 2000000000ull
#else
#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s
#endif
#endif

#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
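A note on the split above: the AMD budget is 10x fewer cycles for the same NUM_CPU_TIMEOUT_SECS, which the constants themselves imply is a clock-rate difference (200G cycles / 100 s ~= 2 GHz on NVIDIA vs 20G / 100 s ~= 200 MHz here). A minimal sketch of how such a cycle budget is consumed by the spin loops this PR touches in internode.cu; the helper name is illustrative, not the repo's exact code:

// Illustrative sketch only: a guarded spin-wait in the style of the
// dispatch/combine receivers, bounded by NUM_TIMEOUT_CYCLES.
__device__ bool spin_until_nonzero(int const volatile* flag) {
  auto start_time = clock64();
  while (*flag == 0) {
    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
      printf("spin-wait timeout\n");
      return false;  // the real kernels print diagnostics and trap()
    }
  }
  return true;
}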
26 changes: 25 additions & 1 deletion ep/include/ep_utils.cuh
@@ -347,6 +347,29 @@ __forceinline__ __device__ float fast_pow2(int x) {
return *reinterpret_cast<float*>(&bits_x);
}

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale,
float& scale_inv,
bool round_scale) {
if (!isfinite(amax) || amax <= 0.0f) {
scale = 1.0f;
scale_inv = 1.0f;
return;
}
float t = amax * kFinfoAmaxInvE4M3;
if (round_scale) {
int e;
frexpf(t, &e);
scale_inv = ldexpf(1.0f, e);
scale = ldexpf(1.0f, -e);
} else {
scale_inv = t;
scale = kFinfoAmaxE4M3 / amax;
}
if (!isfinite(scale) || scale <= 0.0f) scale = 1.0f;
if (!isfinite(scale_inv) || scale_inv <= 0.0f) scale_inv = 1.0f;
}
#else
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale,
float& scale_inv,
bool round_scale) {
@@ -359,6 +382,7 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale,
scale = kFinfoAmaxE4M3 / amax;
}
}
#endif
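On the round_scale branch above: frexpf decomposes t as m * 2^e with m in [0.5, 1), so scale_inv = ldexpf(1.0f, e) = 2^e lands in (t, 2t] and scale = 2^-e is its exact reciprocal. Snapping both to powers of two keeps the pair exactly inverse in FP32 and guarantees amax * scale stays below the E4M3 maximum. A host-side sketch of the same math, assuming kFinfoAmaxE4M3 = 448.0f (the FP8 E4M3 maximum) and kFinfoAmaxInvE4M3 its reciprocal:

#include <cmath>
#include <cstdio>

int main() {
  float amax = 3.7f;
  float t = amax / 448.0f;  // amax * kFinfoAmaxInvE4M3 (assumed 1/448)
  int e;
  std::frexp(t, &e);                      // t = m * 2^e, m in [0.5, 1)
  float scale_inv = std::ldexp(1.0f, e);  // power of two in (t, 2t]
  float scale = std::ldexp(1.0f, -e);     // exact reciprocal
  // amax * scale < 448, so quantized values cannot overflow E4M3.
  std::printf("t=%g scale_inv=%g scale=%g amax*scale=%g\n", t, scale_inv,
              scale, amax * scale);
  return 0;
}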

// `ld.global.nc.L1::no_allocate` will be translated into
// `LDG.E.NA.[width].CONSTANT` in SASS
@@ -901,7 +925,7 @@ __device__ __forceinline__ void st_relaxed_sys_global(int const* ptr, int val) {
__device__ __forceinline__ int ld_acquire_cta(int const* ptr) {
int ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
HIP_ATOMIC_LOAD(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
ret = HIP_ATOMIC_LOAD(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
#else
asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
#endif
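The one-line fix above deserves a highlight: the HIP branch performed the atomic load but discarded the result, so ret was returned uninitialized. The repo's HIP_ATOMIC_LOAD presumably wraps the __hip_atomic_load builtin; a minimal sketch of the acquire/release pairing it is meant to provide at workgroup scope (illustrative names, same builtins):

__device__ int g_data;
__device__ int g_flag;

// Producer: publish g_data, then release-store the flag.
__device__ void producer() {
  g_data = 42;
  __hip_atomic_store(&g_flag, 1, __ATOMIC_RELEASE,
                     __HIP_MEMORY_SCOPE_WORKGROUP);
}

// Consumer: acquire-load the flag; the returned value must actually be
// kept, which is exactly what the fix above restores.
__device__ int consumer() {
  int seen;
  do {
    seen = __hip_atomic_load(&g_flag, __ATOMIC_ACQUIRE,
                             __HIP_MEMORY_SCOPE_WORKGROUP);
  } while (seen == 0);
  return g_data;  // ordered after the flag load by acquire semantics
}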
22 changes: 16 additions & 6 deletions ep/setup.py
@@ -2,12 +2,20 @@
import subprocess
import setuptools
from glob import glob
import torch
import shutil
import site

from pathlib import Path

# Set ROCm architecture early, before importing torch
# This prevents PyTorch from compiling for all architectures
if not os.getenv("TORCH_CUDA_ARCH_LIST"):
    # Default to gfx942, but can be overridden by environment variable
    default_rocm_arch = os.getenv("PYTORCH_ROCM_ARCH", "gfx942")
    os.environ["PYTORCH_ROCM_ARCH"] = default_rocm_arch
    os.environ["TORCH_CUDA_ARCH_LIST"] = default_rocm_arch

import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from setuptools.command.install import install

@@ -156,10 +164,9 @@ def run(self):
if float(default_arch) >= 9.0:
    nvcc_flags.extend(["--ptxas-options=--register-usage-level=10"])

os.environ["TORCH_CUDA_ARCH_LIST"] = os.getenv(
    "TORCH_CUDA_ARCH_LIST", default_arch
)
device_arch = os.environ["TORCH_CUDA_ARCH_LIST"]
# Set architecture environment variable before creating CUDAExtension
device_arch = os.getenv("TORCH_CUDA_ARCH_LIST", default_arch)
os.environ["TORCH_CUDA_ARCH_LIST"] = device_arch
else:
    # Disable SM90 features on AMD
    cxx_flags.append("-DDISABLE_SM90_FEATURES")
@@ -168,8 +175,11 @@
cxx_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC")
nvcc_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC")

cxx_flags.append("-DENABLE_FAST_DEBUG")
nvcc_flags.append("-DENABLE_FAST_DEBUG")

# Get device architecture (already set at top of file)
device_arch = os.getenv("TORCH_CUDA_ARCH_LIST", "gfx942")
os.environ["PYTORCH_ROCM_ARCH"] = device_arch

# Disable LD/ST tricks, as some CUDA version does not support `.L1::no_allocate`
# Only enable aggressive PTX instructions for SM 9.0+ (H100/H800/B200)
74 changes: 50 additions & 24 deletions ep/src/internode.cu
@@ -273,8 +273,7 @@ __global__ void notify_dispatch(
i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
recv_rdma_rank_prefix_sum[i] = sum;
}
while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
;
while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1);
*moe_recv_rdma_counter_mapped = sum;
}

@@ -303,8 +302,7 @@ __global__ void notify_dispatch(
sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
recv_gbl_rank_prefix_sum[i] = sum;
}
while (ld_volatile_global(moe_recv_counter_mapped) != -1)
;
while (ld_volatile_global(moe_recv_counter_mapped) != -1);
*moe_recv_counter_mapped = sum;
}
if (thread_id < num_nvl_experts) {
@@ -314,8 +312,7 @@
sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) !=
-1)
;
-1);
moe_recv_expert_counter_mapped[thread_id] = sum;
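One small idiom in this hunk worth spelling out: sum is rounded up to a multiple of expert_alignment before being published to the expert counter, via the usual (x + a - 1) / a * a trick. A quick worked instance:

// Round x up to the next multiple of a (a > 0); with a = 8: 13 -> 16, 16 -> 16.
int round_up(int x, int a) { return (x + a - 1) / a * a; }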
}

Expand Down Expand Up @@ -706,9 +703,7 @@ __global__ void __launch_bounds__(
__syncwarp();

// Skip the token which does not belong to this warp
if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps !=
warp_id)
continue;
if ((token_idx - token_start_idx) % 2 != warp_id) continue;
zhenhuang12 (Collaborator) commented on Nov 26, 2025:

I set kNumDispatchRDMASenderWarps = 2 here and it works, but I'm still working on it. There are two possible reasons:

  • an RDMA transaction porting error
  • the RDMA send warp count of 2 limiting the rate of RDMA command commits enough to avoid atomic errors

YangZhou1997 (Member, Author) replied on Nov 28, 2025:

@zhenhuang12 indeed, setting it to 2 makes it work! Nice!

cc @MaoZiming, @CalebZ9909

Collaborator replied:

@YangZhou1997 The change doesn't explain the cause of the error, but it gives us directions for troubleshooting; I'm planning a more in-depth investigation.
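Context for the thread: the changed line round-robins tokens over the RDMA sender warps, so hardcoding % 2 pins the split to two warps and halves the number of warps concurrently committing RDMA commands. A hypothetical standalone rendering of the assignment:

constexpr int kNumDispatchRDMASenderWarps = 2;  // value under test here

// Warp w owns every kNumDispatchRDMASenderWarps-th token in its channel
// range; with 2 warps, warp 0 takes even offsets and warp 1 odd ones.
__device__ bool warp_owns_token(int token_idx, int token_start_idx,
                                int warp_id) {
  return (token_idx - token_start_idx) % kNumDispatchRDMASenderWarps ==
         warp_id;
}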

auto rdma_tail_idx =
is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1;

@@ -819,7 +814,7 @@ __global__ void __launch_bounds__(
acquire_lock(rdma_send_channel_lock + lane_id);
auto latest_tail = rdma_send_channel_tail[lane_id];
auto offset = rdma_tail_idx - latest_tail;
while (offset >= WARP_SIZE) {
while (offset >= 32) {
release_lock(rdma_send_channel_lock + lane_id);
acquire_lock(rdma_send_channel_lock + lane_id);
latest_tail = rdma_send_channel_tail[lane_id];
@@ -830,8 +825,7 @@
// Add the bit and move the ones if possible
auto window = rdma_send_channel_window[lane_id] | (1u << offset);
if (offset == 0) {
auto num_empty_slots =
(~window) == 0 ? WARP_SIZE : __ffs(~window) - 1;
auto num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1;
st_release_cta(rdma_send_channel_tail + lane_id,
latest_tail + num_empty_slots);
window >>= num_empty_slots;
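The WARP_SIZE-to-32 replacements in this hunk are not cosmetic on AMD: window is built with 1u << offset, i.e. a 32-bit mask, so valid offsets are 0..31 regardless of wavefront width, and WARP_SIZE is 64 on the gfx942 wavefronts this PR targets. A sketch of the window bookkeeping under that reading (illustrative names):

// 32-slot completion window: bit i marks slot (tail + i) as written.
// Capacity is the mask width (32), deliberately not WARP_SIZE (64 on AMD).
__device__ void mark_slot_done(unsigned& window, int& tail, int offset) {
  window |= 1u << offset;  // requires 0 <= offset < 32
  if (offset == 0) {
    // Pop the contiguous run of finished slots at the head.
    int num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1;
    tail += num_empty_slots;
    window >>= num_empty_slots;
  }
}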
@@ -1114,9 +1108,10 @@ __global__ void __launch_bounds__(

// Copy data
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
UNROLLED_WARP_COPY(
5, lane_id, hidden_int4, reinterpret_cast<int4*>(dst_shifted),
reinterpret_cast<int4*>(shifted), ld_nc_global, st_na_global);
UNROLLED_WARP_COPY(5, lane_id, num_bytes_per_token / sizeof(int4),
reinterpret_cast<int4*>(dst_shifted),
reinterpret_cast<int4*>(shifted), ld_nc_global,
st_na_global);
Copilot AI commented on lines +1205 to +1208 (Dec 1, 2025):

[nitpick] The magic number '5' in the UNROLLED_WARP_COPY macro is unclear. Consider replacing it with a named constant (e.g., 'UNROLL_FACTOR' or 'NUM_ITERATIONS') to improve code readability and maintainability.
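On the comment above, and on the actual fix in this hunk: the first macro argument is the unroll depth of a lane-strided copy loop, and the functional change is the element count, num_bytes_per_token / sizeof(int4) instead of hidden_int4, which presumably also carries the scale/meta bytes packed after the hidden section. A simplified sketch of the copy pattern (not the repo's macro, which threads load/store functors such as ld_nc_global / st_na_global through):

// Simplified lane-strided warp copy with a compile-time unroll depth.
// Each lane handles up to UNROLL elements spaced WARP_SIZE apart per trip.
template <int UNROLL>
__device__ void warp_copy(int4* dst, int4 const* src, int n, int lane_id) {
  for (int i = lane_id; i < n; i += WARP_SIZE * UNROLL) {
#pragma unroll
    for (int j = 0; j < UNROLL; ++j) {
      int idx = i + j * WARP_SIZE;
      if (idx < n) dst[idx] = src[idx];
    }
  }
}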
#else
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token,
@@ -1148,9 +1143,17 @@

// Move tail index
__syncwarp();
if (lane_id == 0)
if (lane_id == 0) {
// /*******************************************************************/
// printf(
// "DeepEP dispatch NVL forwarder, channel: %d, RDMA: %d, "
// "src NVL: %d, dst NVL: %d, head: %d, tail: %d\n",
// channel_id, rdma_rank, nvl_rank, target_rank,
// cached_nvl_channel_head, cached_nvl_channel_tail);
st_release_sys_global(nvl_channel_tail.buffer(),
cached_nvl_channel_tail);
// /*******************************************************************/
}
}
// Retired
__syncwarp();
@@ -1240,6 +1243,7 @@ __global__ void __launch_bounds__(
}
}
num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

auto num_tokens_to_recv_original = num_tokens_to_recv;
// Save for combine usage
if (lane_id < kNumRDMARanks and not kCachedMode)
@@ -1260,17 +1264,31 @@

cached_channel_tail_idx = __shfl_sync(
WARP_MASK, ld_acquire_sys_global(nvl_channel_tail.buffer()), 0);
// if (lane_id == 0) {
/*******************************************************************/
// printf(
// "DeepEP dispatch NVL receiver check, channel: %d, RDMA: %d, src
// " "NVL: %d, dst NVL: %d, head: %d, tail: %d, "
// "num_tokens_to_recv_original: %d, "
// "num_tokens_to_recv: %d\n",
// channel_id, rdma_rank, target_rank, nvl_rank,
// ld_acquire_sys_global(nvl_channel_head.buffer()),
// ld_acquire_sys_global(nvl_channel_tail.buffer()),
// num_tokens_to_recv_original, num_tokens_to_recv);
/*******************************************************************/
// }
// Timeout check
if (lane_id == 0 and clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf(
"DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, "
"nvl: %d, src NVL: %d, head: %d, tail: %d, "
"num_tokens_to_recv_original: %d, last_recv_token_idx: %lld, "
"num_tokens_to_recv_original: %d, num_tokens_to_recv: %d, "
"last_recv_token_idx: %lld, "
"next_expected_token_idx: %lld\n",
channel_id, rdma_rank, nvl_rank, src_nvl_rank,
cached_channel_head_idx, cached_channel_tail_idx,
num_tokens_to_recv_original, last_recv_token_idx,
(long long)(last_recv_token_idx + 1));
num_tokens_to_recv_original, num_tokens_to_recv,
last_recv_token_idx, (long long)(last_recv_token_idx + 1));
trap();
}
}
@@ -1298,6 +1316,11 @@ __global__ void __launch_bounds__(
5, lane_id, hidden_int4,
reinterpret_cast<int4*>(recv_x + recv_token_idx * hidden_int4),
reinterpret_cast<int4*>(shifted), ld_nc_global, st_na_global);
if (scale_aligned)
UNROLLED_WARP_COPY(1, lane_id, num_scales,
recv_x_scales + recv_token_idx * num_scales,
reinterpret_cast<float*>(shifted + hidden_bytes),
ld_nc_global, st_na_global);
#else
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);
@@ -1691,6 +1714,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx,
auto cached_notify_func = low_latency_mode
? cached_notify<true, kNumTMABytesPerWarp>
: cached_notify<false, kNumTMABytesPerWarp>;

SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
SET_SHARED_MEMORY_FOR_TMA(cached_notify_func);
LAUNCH_KERNEL(&cfg, cached_notify_func, rdma_clean_meta.first,
@@ -1882,7 +1906,7 @@ template <
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0)
? kNumCombineForwarderWarps / kNumRDMARanks
: 1,
int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS>
__global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
combine(int4* combined_x, float* combined_topk_weights,
@@ -2051,7 +2075,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id,
ld_volatile_global(nvl_channel_head.buffer() + lane_id),
cached_channel_tail_idx, token_start_idx, token_end_idx);
trap();
// trap();
}
}

@@ -2303,7 +2327,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
channel_id, rdma_rank, nvl_rank, dst_rdma_rank,
ld_acquire_sys_global(rdma_channel_head.buffer(dst_rdma_rank)),
token_start_idx, num_chunked_tokens);
trap();
// trap();
}
}
sync_large_warp();
@@ -2340,7 +2364,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank,
cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine,
sub_warp_id, kNumWarpsPerForwarder, expected_head);
trap();
// trap();
}
}

@@ -2484,7 +2508,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
"nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id,
cached_channel_tail_idx, token_idx, expected_head);
trap();
// trap();
}
}
__syncwarp();
@@ -2660,6 +2684,8 @@ void combine(cudaDataType_t type, void* combined_x,
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder);
EP_HOST_ASSERT(type == CUDA_R_16BF);

printf("combine num_sms = %d, num_threads = %d", num_channels * 2,
(num_forwarder_warps + 1) * WARP_SIZE);
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
EP_HOST_ASSERT((num_forwarder_warps + 1) * WARP_SIZE <= MAX_NTHREADS);
#endif
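This AMD-only assert exists because a HIP wavefront is 64 lanes, so the same warp count consumes twice the threads per block as on NVIDIA; MAX_NTHREADS is assumed here to be the usual 1024-thread block limit. A small worked instance:

// With 16 forwarder warps plus the coordinator warp:
//   NVIDIA: (16 + 1) * 32 =  544 threads -> fits in a 1024-thread block
//   AMD:    (16 + 1) * 64 = 1088 threads -> exceeds it, hence the assert
constexpr int kForwarders = 16;  // hypothetical count
static_assert((kForwarders + 1) * 32 == 544, "fits on NVIDIA");
static_assert((kForwarders + 1) * 64 == 1088, "overflows 1024 on AMD");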