-
Notifications
You must be signed in to change notification settings - Fork 104
[EP] debugging amd normal #548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 22 commits
b7c1fb1
05e7df8
04c6a14
946db83
fbda03b
d6f3355
1903181
7324fcf
b136aaa
4d63369
de4abc4
49c0bab
e5b2ef3
a4d9d4b
026a836
b64412d
9cb051f
f19fd62
3a7d093
8c86136
fa1d720
b6451b4
a142d9e
a2d9f26
5d38ecb
33e9bb1
97d905e
4a44245
61eba6f
ce30d31
6687cb6
29f331f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -273,8 +273,7 @@ __global__ void notify_dispatch( | |
| i)[NUM_MAX_NVL_PEERS + num_rdma_experts]; | ||
| recv_rdma_rank_prefix_sum[i] = sum; | ||
| } | ||
| while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1) | ||
| ; | ||
| while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1); | ||
| *moe_recv_rdma_counter_mapped = sum; | ||
| } | ||
|
|
||
|
|
@@ -303,8 +302,7 @@ __global__ void notify_dispatch( | |
| sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank]; | ||
| recv_gbl_rank_prefix_sum[i] = sum; | ||
| } | ||
| while (ld_volatile_global(moe_recv_counter_mapped) != -1) | ||
| ; | ||
| while (ld_volatile_global(moe_recv_counter_mapped) != -1); | ||
| *moe_recv_counter_mapped = sum; | ||
| } | ||
| if (thread_id < num_nvl_experts) { | ||
|
|
@@ -314,8 +312,7 @@ __global__ void notify_dispatch( | |
| sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; | ||
| sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; | ||
| while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != | ||
| -1) | ||
| ; | ||
| -1); | ||
| moe_recv_expert_counter_mapped[thread_id] = sum; | ||
| } | ||
|
|
||
|
|
@@ -706,9 +703,7 @@ __global__ void __launch_bounds__( | |
| __syncwarp(); | ||
|
|
||
| // Skip the token which does not belong to this warp | ||
| if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != | ||
| warp_id) | ||
| continue; | ||
| if ((token_idx - token_start_idx) % 2 != warp_id) continue; | ||
|
||
| auto rdma_tail_idx = | ||
| is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1; | ||
|
|
||
|
|
@@ -819,7 +814,7 @@ __global__ void __launch_bounds__( | |
| acquire_lock(rdma_send_channel_lock + lane_id); | ||
| auto latest_tail = rdma_send_channel_tail[lane_id]; | ||
| auto offset = rdma_tail_idx - latest_tail; | ||
| while (offset >= WARP_SIZE) { | ||
| while (offset >= 32) { | ||
| release_lock(rdma_send_channel_lock + lane_id); | ||
| acquire_lock(rdma_send_channel_lock + lane_id); | ||
| latest_tail = rdma_send_channel_tail[lane_id]; | ||
|
|
@@ -830,8 +825,7 @@ __global__ void __launch_bounds__( | |
| // Add the bit and move the ones if possible | ||
| auto window = rdma_send_channel_window[lane_id] | (1u << offset); | ||
| if (offset == 0) { | ||
| auto num_empty_slots = | ||
| (~window) == 0 ? WARP_SIZE : __ffs(~window) - 1; | ||
| auto num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1; | ||
| st_release_cta(rdma_send_channel_tail + lane_id, | ||
| latest_tail + num_empty_slots); | ||
| window >>= num_empty_slots; | ||
|
|
@@ -1114,9 +1108,10 @@ __global__ void __launch_bounds__( | |
|
|
||
| // Copy data | ||
| #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) | ||
| UNROLLED_WARP_COPY( | ||
| 5, lane_id, hidden_int4, reinterpret_cast<int4*>(dst_shifted), | ||
| reinterpret_cast<int4*>(shifted), ld_nc_global, st_na_global); | ||
| UNROLLED_WARP_COPY(5, lane_id, num_bytes_per_token / sizeof(int4), | ||
| reinterpret_cast<int4*>(dst_shifted), | ||
| reinterpret_cast<int4*>(shifted), ld_nc_global, | ||
| st_na_global); | ||
|
Comment on lines
+1205
to
+1208
|
||
| #else | ||
| if (lane_id == 0) { | ||
| tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, | ||
|
|
@@ -1148,9 +1143,17 @@ __global__ void __launch_bounds__( | |
|
|
||
| // Move tail index | ||
| __syncwarp(); | ||
| if (lane_id == 0) | ||
| if (lane_id == 0) { | ||
| // /*******************************************************************/ | ||
| // printf( | ||
| // "DeepEP dispatch NVL forwarder, channel: %d, RDMA: %d, " | ||
| // "src NVL: %d, dst NVL: %d, head: %d, tail: %d\n", | ||
| // channel_id, rdma_rank, nvl_rank, target_rank, | ||
| // cached_nvl_channel_head, cached_nvl_channel_tail); | ||
| st_release_sys_global(nvl_channel_tail.buffer(), | ||
| cached_nvl_channel_tail); | ||
| // /*******************************************************************/ | ||
| } | ||
| } | ||
| // Retired | ||
| __syncwarp(); | ||
|
|
@@ -1240,6 +1243,7 @@ __global__ void __launch_bounds__( | |
| } | ||
| } | ||
| num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); | ||
|
|
||
| auto num_tokens_to_recv_original = num_tokens_to_recv; | ||
| // Save for combine usage | ||
| if (lane_id < kNumRDMARanks and not kCachedMode) | ||
|
|
@@ -1260,17 +1264,31 @@ __global__ void __launch_bounds__( | |
|
|
||
| cached_channel_tail_idx = __shfl_sync( | ||
| WARP_MASK, ld_acquire_sys_global(nvl_channel_tail.buffer()), 0); | ||
| // if (lane_id == 0) { | ||
| /*******************************************************************/ | ||
| // printf( | ||
| // "DeepEP dispatch NVL receiver check, channel: %d, RDMA: %d, src | ||
| // " "NVL: %d, dst NVL: %d, head: %d, tail: %d, " | ||
| // "num_tokens_to_recv_original: %d, " | ||
| // "num_tokens_to_recv: %d\n", | ||
| // channel_id, rdma_rank, target_rank, nvl_rank, | ||
| // ld_acquire_sys_global(nvl_channel_head.buffer()), | ||
| // ld_acquire_sys_global(nvl_channel_tail.buffer()), | ||
| // num_tokens_to_recv_original, num_tokens_to_recv); | ||
| /*******************************************************************/ | ||
| // } | ||
| // Timeout check | ||
| if (lane_id == 0 and clock64() - start_time > NUM_TIMEOUT_CYCLES) { | ||
| printf( | ||
| "DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, " | ||
| "nvl: %d, src NVL: %d, head: %d, tail: %d, " | ||
| "num_tokens_to_recv_original: %d, last_recv_token_idx: %lld, " | ||
| "num_tokens_to_recv_original: %d, num_tokens_to_recv: %d, " | ||
| "last_recv_token_idx: %lld, " | ||
| "next_expected_token_idx: %lld\n", | ||
| channel_id, rdma_rank, nvl_rank, src_nvl_rank, | ||
| cached_channel_head_idx, cached_channel_tail_idx, | ||
| num_tokens_to_recv_original, last_recv_token_idx, | ||
| (long long)(last_recv_token_idx + 1)); | ||
| num_tokens_to_recv_original, num_tokens_to_recv, | ||
| last_recv_token_idx, (long long)(last_recv_token_idx + 1)); | ||
| trap(); | ||
| } | ||
| } | ||
|
|
@@ -1298,6 +1316,11 @@ __global__ void __launch_bounds__( | |
| 5, lane_id, hidden_int4, | ||
| reinterpret_cast<int4*>(recv_x + recv_token_idx * hidden_int4), | ||
| reinterpret_cast<int4*>(shifted), ld_nc_global, st_na_global); | ||
| if (scale_aligned) | ||
| UNROLLED_WARP_COPY(1, lane_id, num_scales, | ||
| recv_x_scales + recv_token_idx * num_scales, | ||
| reinterpret_cast<float*>(shifted + hidden_bytes), | ||
| ld_nc_global, st_na_global); | ||
| #else | ||
| if (lane_id == 0) { | ||
| tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes); | ||
|
|
@@ -1691,6 +1714,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, | |
| auto cached_notify_func = low_latency_mode | ||
| ? cached_notify<true, kNumTMABytesPerWarp> | ||
| : cached_notify<false, kNumTMABytesPerWarp>; | ||
|
|
||
| SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); | ||
| SET_SHARED_MEMORY_FOR_TMA(cached_notify_func); | ||
| LAUNCH_KERNEL(&cfg, cached_notify_func, rdma_clean_meta.first, | ||
|
|
@@ -1882,7 +1906,7 @@ template < | |
| int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) | ||
| ? kNumCombineForwarderWarps / kNumRDMARanks | ||
| : 1, | ||
| int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, | ||
| int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, | ||
| int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS> | ||
| __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) | ||
| combine(int4* combined_x, float* combined_topk_weights, | ||
|
|
@@ -2051,7 +2075,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) | |
| channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, | ||
| ld_volatile_global(nvl_channel_head.buffer() + lane_id), | ||
| cached_channel_tail_idx, token_start_idx, token_end_idx); | ||
| trap(); | ||
| // trap(); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -2303,7 +2327,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) | |
| channel_id, rdma_rank, nvl_rank, dst_rdma_rank, | ||
| ld_acquire_sys_global(rdma_channel_head.buffer(dst_rdma_rank)), | ||
| token_start_idx, num_chunked_tokens); | ||
| trap(); | ||
| // trap(); | ||
| } | ||
| } | ||
| sync_large_warp(); | ||
|
|
@@ -2340,7 +2364,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) | |
| channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, | ||
| cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, | ||
| sub_warp_id, kNumWarpsPerForwarder, expected_head); | ||
| trap(); | ||
| // trap(); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -2484,7 +2508,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) | |
| "nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n", | ||
| channel_id, rdma_rank, nvl_rank, lane_id, | ||
| cached_channel_tail_idx, token_idx, expected_head); | ||
| trap(); | ||
| // trap(); | ||
| } | ||
| } | ||
| __syncwarp(); | ||
|
|
@@ -2660,6 +2684,8 @@ void combine(cudaDataType_t type, void* combined_x, | |
| EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder); | ||
| EP_HOST_ASSERT(type == CUDA_R_16BF); | ||
|
|
||
| printf("combine num_sms = %d, num_threads = %d", num_channels * 2, | ||
| (num_forwarder_warps + 1) * WARP_SIZE); | ||
| #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) | ||
| EP_HOST_ASSERT((num_forwarder_warps + 1) * WARP_SIZE <= MAX_NTHREADS); | ||
| #endif | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is --redirect 3?