Changes from all commits
29 commits
b7c1fb1
EP: port internode_dispatch to amd.
zhenhuang12 Oct 31, 2025
05e7df8
EP: port internode_combine to amd.
zhenhuang12 Nov 3, 2025
04c6a14
ziming fix amd port
MaoZiming Nov 4, 2025
946db83
fixing wr bug
MaoZiming Nov 4, 2025
fbda03b
debugging
MaoZiming Nov 4, 2025
d6f3355
clean
MaoZiming Nov 4, 2025
1903181
Merge branch 'zm-amd-port' of https://github.com/uccl-project/uccl in…
MaoZiming Nov 4, 2025
7324fcf
checkpt
MaoZiming Nov 5, 2025
b136aaa
merge main
MaoZiming Nov 5, 2025
4d63369
adding wr_id_to_wr_ids emplace for normal mode atomics
MaoZiming Nov 5, 2025
de4abc4
EP: fix ep internode
zhenhuang12 Nov 6, 2025
49c0bab
EP: fix RDMAAndNVLForwarder copy data
zhenhuang12 Nov 11, 2025
e5b2ef3
merge main
MaoZiming Nov 11, 2025
a4d9d4b
Merge branch 'main' of https://github.com/uccl-project/uccl into zm-a…
YangZhou1997 Nov 14, 2025
026a836
Merge branch 'main' of https://github.com/uccl-project/uccl into zm-a…
YangZhou1997 Nov 14, 2025
b64412d
debugging
YangZhou1997 Nov 15, 2025
9cb051f
run on nebius
YangZhou1997 Nov 15, 2025
f19fd62
merge with main
YangZhou1997 Nov 24, 2025
3a7d093
fixing setup.py on amd
YangZhou1997 Nov 24, 2025
8c86136
add printf to internode
YangZhou1997 Nov 24, 2025
fa1d720
Merge branch 'main' into yang-amd-normal
zhenhuang12 Nov 26, 2025
b6451b4
debug internode on amd-gpu
zhenhuang12 Nov 26, 2025
a142d9e
trying to debug dispatch kernel hang issues, but fails
YangZhou1997 Nov 28, 2025
a2d9f26
merge
YangZhou1997 Nov 28, 2025
5d38ecb
EP: fix normal dispatch bug
zhenhuang12 Dec 1, 2025
33e9bb1
Merge branch 'main' into yang-amd-normal
zhenhuang12 Dec 1, 2025
97d905e
EP: format code
zhenhuang12 Dec 1, 2025
4a44245
EP: restore internode_ll.cu
zhenhuang12 Dec 1, 2025
61eba6f
EP: format code
zhenhuang12 Dec 1, 2025
2 changes: 1 addition & 1 deletion ep/bench/run_ep.sh
@@ -15,7 +15,7 @@ if [ "$MODE" = "ll" ]; then
else
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--log-dir=./logs --redirect 3 test_internode.py --num-tokens=1024 \
--hidden=7168 --num-topk=8 --num-experts=288 --test-ll-compatibility
fi
# --log-dir=logs --redirect=3
8 changes: 8 additions & 0 deletions ep/include/ep_configs.cuh
@@ -12,11 +12,19 @@
// #define ENABLE_FAST_DEBUG
#ifndef ENABLE_FAST_DEBUG
#define NUM_CPU_TIMEOUT_SECS 100
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#define NUM_TIMEOUT_CYCLES 20000000000ull
#else
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#endif
#else
#define NUM_CPU_TIMEOUT_SECS 10
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#define NUM_TIMEOUT_CYCLES 2000000000ull
#else
#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s
#endif
#endif

#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
16 changes: 0 additions & 16 deletions ep/include/ep_launch.cuh
@@ -1,21 +1,6 @@
#pragma once
#include "exception.cuh"

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#ifndef SETUP_LAUNCH_CONFIG
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
hipLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
hipLaunchAttribute attr[1]; \
attr[0].id = hipLaunchAttributeCooperative; \
attr[0].val.cooperative = 1; \
cfg.attrs = attr; \
cfg.numAttrs = 1
#endif
#ifndef LAUNCH_KERNEL
#define LAUNCH_KERNEL(config, kernel, ...) \
CUDA_CHECK(hipLaunchKernelEx(config, kernel, ##__VA_ARGS__))
#endif
#else
#ifndef SETUP_LAUNCH_CONFIG
#ifndef DISABLE_SM90_FEATURES
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
@@ -55,7 +40,6 @@
} while (0)
#endif
#endif
#endif

#ifndef SET_SHARED_MEMORY_FOR_TMA
#ifndef DISABLE_SM90_FEATURES
3 changes: 1 addition & 2 deletions ep/include/ep_utils.cuh
@@ -5,7 +5,6 @@

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#include "amd_nanosleep.cuh"
#define __syncwarp() __builtin_amdgcn_wave_barrier()
#ifndef clock64
#define clock64 wall_clock64
#endif
@@ -906,7 +905,7 @@ __device__ __forceinline__ void st_relaxed_sys_global(int const* ptr, int val) {
__device__ __forceinline__ int ld_acquire_cta(int const* ptr) {
int ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
HIP_ATOMIC_LOAD(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
ret = HIP_ATOMIC_LOAD(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
#else
asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
#endif
45 changes: 37 additions & 8 deletions ep/setup.py
@@ -2,12 +2,12 @@
import subprocess
import setuptools
from glob import glob
import torch
import shutil
import site

import re
from pathlib import Path

import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from setuptools.command.install import install

@@ -156,11 +156,37 @@ def run(self):
if float(default_arch) >= 9.0:
nvcc_flags.extend(["--ptxas-options=--register-usage-level=10"])

os.environ["TORCH_CUDA_ARCH_LIST"] = os.getenv(
"TORCH_CUDA_ARCH_LIST", default_arch
)
device_arch = os.environ["TORCH_CUDA_ARCH_LIST"]
# Set architecture environment variable before creating CUDAExtension
device_arch = os.getenv("TORCH_CUDA_ARCH_LIST", default_arch)
os.environ["TORCH_CUDA_ARCH_LIST"] = device_arch
else:
gpu_archs = os.getenv("TORCH_CUDA_ARCH_LIST", None)
if gpu_archs is None or gpu_archs.strip() == "":
# Detect GPU architecture on AMD
GPU_ARCH_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
try:
result = subprocess.run(
["rocminfo"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True,
)
except Exception as e:
raise RuntimeError(f"rocminfo failed: {e}")

matches = set(GPU_ARCH_PATTERN.findall(result.stdout))

if not matches:
raise RuntimeError("No gfx architecture found in rocminfo output.")
arch_list = list(matches)

else:
gpu_archs = gpu_archs.split(",")
Copilot AI commented on Dec 1, 2025:
The variable 'arch_list' is only defined inside the 'if not matches' block (line 182) but is used here unconditionally. If 'gpu_archs' is provided via environment variable, 'arch_list' will be undefined, causing a NameError. The logic should use 'gpu_archs' when set, or define 'arch_list' in both branches.
Suggested change:
- gpu_archs = gpu_archs.split(",")
+ arch_list = gpu_archs.split(",")
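For reference, a minimal sketch of the fix this comment describes, where both branches populate the same variable so the env-var path no longer raises a NameError. The helper name detect_rocm_arch_list and the usage line at the end are illustrative assumptions, not code from this PR.

import os
import re
import subprocess

def detect_rocm_arch_list():
    # Hypothetical helper: both branches assign the same `arch_list` name.
    gpu_archs = os.getenv("TORCH_CUDA_ARCH_LIST", "")
    if gpu_archs.strip():
        # Environment override takes priority; split the comma-separated list.
        arch_list = [a.strip() for a in gpu_archs.split(",") if a.strip()]
    else:
        # Otherwise probe rocminfo for gfx targets, mirroring the diff above.
        result = subprocess.run(
            ["rocminfo"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            text=True, check=True,
        )
        arch_list = sorted(set(re.findall(r"Name:\s*(gfx\d+\w*)", result.stdout)))
        if not arch_list:
            raise RuntimeError("No gfx architecture found in rocminfo output.")
    return arch_list

# Usage sketch: nvcc_flags += [f"--offload-arch={a.lower()}" for a in detect_rocm_arch_list()]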

for arch in arch_list:
nvcc_flags.append(f"--offload-arch={arch.lower()}")

# Disable SM90 features on AMD
cxx_flags.append("-DDISABLE_SM90_FEATURES")
nvcc_flags.append("-DDISABLE_SM90_FEATURES")
@@ -169,8 +195,11 @@ def run(self):
cxx_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC")
nvcc_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC")

device_arch = os.getenv("TORCH_CUDA_ARCH_LIST", "gfx942")
os.environ["PYTORCH_ROCM_ARCH"] = device_arch
cxx_flags.append("-DUSE_GRACE_HOPPER")
nvcc_flags.append("-DUSE_GRACE_HOPPER")
Comment on lines +198 to +199
Copilot AI commented on Dec 1, 2025:
The USE_GRACE_HOPPER flag is being added in the AMD/ROCm code path, but Grace Hopper is an NVIDIA architecture. This appears to be a copy-paste error and should likely be removed from the AMD-specific section.
Suggested change:
- cxx_flags.append("-DUSE_GRACE_HOPPER")
- nvcc_flags.append("-DUSE_GRACE_HOPPER")
+ # Removed erroneous Grace Hopper flag for AMD/ROCm
+ # (No action needed)
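A short sketch, assuming the fix is simply to drop the NVIDIA-only define from the ROCm branch; flag names other than those visible in the diff are not asserted here.

cxx_flags, nvcc_flags = [], []

# ROCm path keeps the workaround defines shown in the diff above,
# but omits USE_GRACE_HOPPER, which targets an NVIDIA architecture.
for flags in (cxx_flags, nvcc_flags):
    flags.append("-DDISABLE_SM90_FEATURES")
    flags.append("-DDISABLE_AGGRESSIVE_ATOMIC")
    # "-DUSE_GRACE_HOPPER" intentionally not appended here.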

# Get device architecture (already set at top of file)
device_arch = os.getenv("PYTORCH_ROCM_ARCH", "gfx942")

# Disable LD/ST tricks, as some CUDA version does not support `.L1::no_allocate`
# Only enable aggressive PTX instructions for SM 9.0+ (H100/H800/B200)
106 changes: 103 additions & 3 deletions ep/src/internode.cu
@@ -582,7 +582,12 @@ __global__ void __launch_bounds__(
// RDMA sender warp synchronization
// NOTES: `rdma_send_channel_tail` means the latest released tail
// NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
__shared__ int volatile rdma_send_next_token_idx;
__shared__ int volatile rdma_send_channel_next_tail[kNumRDMARanks];
#else
__shared__ int rdma_send_channel_lock[kNumRDMARanks];
#endif
__shared__ int rdma_send_channel_tail[kNumRDMARanks];
__shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];

@@ -629,6 +634,12 @@ __global__ void __launch_bounds__(
get_channel_task_range(num_tokens, num_channels, channel_id,
token_start_idx, token_end_idx);

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
(warp_id == 0 and lane_id == 0)
? (rdma_send_next_token_idx = token_start_idx)
: 0;
#endif

// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= WARP_SIZE,
"Invalid number of NVL peers");
@@ -694,13 +705,67 @@ __global__ void __launch_bounds__(
auto send_buffer = lane_id == rdma_rank
? rdma_channel_data.recv_buffer(lane_id)
: rdma_channel_data.send_buffer(lane_id);

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
// NOTE: sequential lock works for amd.
int last_rdma_tail_idx = -1;
for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx;
token_idx += kNumDispatchRDMASenderWarps) {
// Read RDMA rank existence
uint64_t is_token_in_rank_uint64 = 0;
if (lane_id < kNumRDMARanks) {
is_token_in_rank_uint64 = __ldg(reinterpret_cast<uint64_t const*>(
is_token_in_rank + token_idx * num_ranks +
lane_id * NUM_MAX_NVL_PEERS));
}

// Acquire sequential lock
while (lane_id == 0 and rdma_send_next_token_idx != token_idx)
;
__syncwarp();

// Acquire next tail
int rdma_tail_idx = -1;
auto start_time = clock64();
if (is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
// Wait the remote buffer to be released
while (rdma_tail_idx - cached_rdma_channel_head >=
num_max_rdma_chunked_recv_tokens) {
cached_rdma_channel_head = static_cast<int>(
ld_acquire_sys_global(rdma_channel_head.buffer(lane_id)));

// Timeout check
if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
printf(
"DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, "
"nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id,
cached_rdma_channel_head, rdma_tail_idx);
trap();
}
}
}
__syncwarp();

// Update last token tail
if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<int const*>(rdma_send_channel_tail + lane_id),
last_rdma_tail_idx + 1);
last_rdma_tail_idx = rdma_tail_idx;

// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;

#else
for (token_idx = token_start_idx; token_idx < token_end_idx; ++token_idx) {
// Read RDMA rank existence
uint64_t is_token_in_rank_uint64 = 0;
if (lane_id < kNumRDMARanks) {
is_token_in_rank_uint64 = __ldg(reinterpret_cast<uint64_t const*>(
is_token_in_rank + token_idx * num_ranks +
lane_id * NUM_MAX_NVL_PEERS));

global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
}
__syncwarp();
@@ -730,6 +795,7 @@ __global__ void __launch_bounds__(
trap();
}
}
#endif
__syncwarp();

// Store RDMA head for combine
@@ -813,6 +879,7 @@ __global__ void __launch_bounds__(
}
__syncwarp();

#if defined(__NVCC__)
// Release the transaction in the window
if (is_token_in_rank_uint64 != 0) {
// Acquire lock first
@@ -841,8 +908,25 @@ __global__ void __launch_bounds__(
// Release lock
release_lock(rdma_send_channel_lock + lane_id);
}
#endif
__syncwarp();
}

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
// Epilogue
// Acquire sequential lock
while (lane_id == 0 and rdma_send_next_token_idx != token_idx)
;
__syncwarp();

// Update last token tail
if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<int const*>(rdma_send_channel_tail + lane_id),
last_rdma_tail_idx + 1);

// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
#endif
} else if (warp_role == WarpRole::kRDMASenderCoordinator) {
// NOTES: in case of splitting, the issued put at the end of the buffer
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens %
@@ -852,7 +936,11 @@ __global__ void __launch_bounds__(
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= WARP_SIZE,
"Invalid number of RDMA ranks");
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
(lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
#else
(lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
#endif
(lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;

@@ -1114,9 +1202,10 @@ __global__ void __launch_bounds__(

// Copy data
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
UNROLLED_WARP_COPY(
5, lane_id, hidden_int4, reinterpret_cast<int4*>(dst_shifted),
reinterpret_cast<int4*>(shifted), ld_nc_global, st_na_global);
UNROLLED_WARP_COPY(5, lane_id, num_bytes_per_token / sizeof(int4),
reinterpret_cast<int4*>(dst_shifted),
reinterpret_cast<int4*>(shifted), ld_nc_global,
st_na_global);
Comment on lines +1205 to +1208
Copilot AI commented on Dec 1, 2025:
[nitpick] The magic number '5' in the UNROLLED_WARP_COPY macro is unclear. Consider replacing it with a named constant (e.g., 'UNROLL_FACTOR' or 'NUM_ITERATIONS') to improve code readability and maintainability.
#else
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token,
@@ -1298,6 +1387,12 @@ __global__ void __launch_bounds__(
5, lane_id, hidden_int4,
reinterpret_cast<int4*>(recv_x + recv_token_idx * hidden_int4),
reinterpret_cast<int4*>(shifted), ld_nc_global, st_na_global);
if (scale_aligned)
UNROLLED_WARP_COPY(1, lane_id, num_scales,
recv_x_scales + recv_token_idx * num_scales,
reinterpret_cast<float*>(shifted + hidden_bytes),
ld_nc_global, st_na_global);

#else
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);
@@ -1660,7 +1755,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx,
bool is_cached_dispatch, bool low_latency_mode,
uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
void* atomic_buffer_ptr) {
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
int const num_threads =
std::max(128, WARP_SIZE * (is_cached_dispatch ? 2 : num_channels));
#else
int const num_threads = std::max(128, WARP_SIZE * num_channels);
#endif
int const num_warps = num_threads / WARP_SIZE;
auto const num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
int const kNumTMABytesPerWarp = 8192;
2 changes: 1 addition & 1 deletion ep/src/internode_ll.cu
@@ -1125,4 +1125,4 @@ void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag,
}

} // namespace internode_ll
} // namespace uccl
} // namespace uccl
4 changes: 0 additions & 4 deletions ep/src/proxy.cpp
@@ -196,9 +196,6 @@ void Proxy::init_common() {
reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(cfg_.gpu_buffer) +
cfg_.total_size - atomic_buf_size);

// printf("[PROXY_INIT] Atomic buffer at %p, size %zu bytes\n",
// ctx_.atomic_old_values_buf, atomic_buf_size);

int num_ranks = ctxs_for_all_ranks_.size();
local_infos_.assign(num_ranks, RDMAConnectionInfo{});
remote_infos_.assign(num_ranks, RDMAConnectionInfo{});
@@ -846,7 +843,6 @@ void Proxy::post_gpu_commands_mixed(
0) {
return;
}

// Handle regular RDMA writes
if (!rdma_wrs.empty()) {
post_rdma_async_batched(ctx_, cfg_.gpu_buffer, rdma_wrs.size(), rdma_wrs,