Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ep/include/proxy_ctx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ struct ProxyCtx {
uint32_t dst_ack_qpn;
struct ibv_ah* dst_ah = nullptr;

// Connectionless SRD support: multiple AHs and QPNs for different remote NICs
std::vector<struct ibv_ah*> dst_ah_per_nic;
std::vector<uint32_t> dst_qpn_per_nic;
std::vector<uint32_t> dst_ack_qpn_per_nic;
std::vector<uintptr_t> remote_addr_per_nic;
std::vector<uint32_t> remote_rkey_per_nic;
std::vector<uint64_t> remote_len_per_nic;

// Remote memory
uintptr_t remote_addr = 0; // Base address of remote rdma_buffer
uint64_t remote_len = 0;
Expand Down
11 changes: 9 additions & 2 deletions ep/include/rdma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct RDMAConnectionInfo {
uint32_t ack_qp_num;
uint32_t recv_ack_qp_num;
uint32_t ack_psn;
uint32_t rkey; // Memory region key
uint32_t rkey; // Memory region keyf
uintptr_t addr; // Buffer address
uint64_t len;
uint16_t lid; // Local ID
Expand All @@ -29,7 +29,11 @@ struct RDMAConnectionInfo {
// #ifdef EFA
uint32_t num_rings;
uint32_t data_qp_num[kChannelPerProxy];
// #endif

uint32_t num_nics;
uint8_t gid_per_nic[MAX_NUM_GPUS][16];
uint32_t qp_num_per_nic[MAX_NUM_GPUS];
uint32_t ack_qp_num_per_nic[MAX_NUM_GPUS];
};

struct PendingUpdate {
Expand Down Expand Up @@ -301,6 +305,9 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote,
void modify_qp_to_rts(ProxyCtx& S, RDMAConnectionInfo* local_info);

void modify_qp_to_init(ProxyCtx& S);

struct ibv_ah* create_ah(ProxyCtx& S, uint8_t* remote_gid);

void local_poll_completions(ProxyCtx& S,
std::unordered_set<uint64_t>& acked_wrs,
int thread_idx, std::vector<ProxyCtx*>& ctx_by_tag);
Expand Down
1 change: 1 addition & 0 deletions ep/src/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ void Proxy::init_common() {
}
}
usleep(50 * 1000);

if (cfg_.use_normal_mode) {
// if (cfg_.thread_idx != 0) {
// return;
Expand Down
56 changes: 47 additions & 9 deletions ep/src/rdma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,12 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size,
local_info->psn = 0;
local_info->ack_psn = 0;
fill_local_gid(S, local_info);

local_info->num_nics = 0;
memset(local_info->gid_per_nic, 0, sizeof(local_info->gid_per_nic));
memset(local_info->qp_num_per_nic, 0, sizeof(local_info->qp_num_per_nic));
memset(local_info->ack_qp_num_per_nic, 0,
sizeof(local_info->ack_qp_num_per_nic));
}

void modify_qp_to_init(ProxyCtx& S) {
Expand Down Expand Up @@ -495,6 +501,18 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote,
S.dst_qpn = remote->qp_num;
S.dst_ack_qpn = remote->recv_ack_qp_num;
S.dst_ah = create_ah(S, remote->gid);

if (!use_normal_mode && remote->num_nics > 0) {
S.dst_ah_per_nic.resize(remote->num_nics);
S.dst_qpn_per_nic.resize(remote->num_nics);
S.dst_ack_qpn_per_nic.resize(remote->num_nics);

for (uint32_t nic_idx = 0; nic_idx < remote->num_nics; ++nic_idx) {
S.dst_ah_per_nic[nic_idx] = create_ah(S, remote->gid_per_nic[nic_idx]);
S.dst_qpn_per_nic[nic_idx] = remote->qp_num_per_nic[nic_idx];
S.dst_ack_qpn_per_nic[nic_idx] = remote->ack_qp_num_per_nic[nic_idx];
}
}
#endif

if (use_normal_mode) {
Expand Down Expand Up @@ -1022,17 +1040,33 @@ static void post_rdma_async_batched_fast_mode(
qpx->comp_mask = 0;
qpx->wr_flags = IBV_SEND_SIGNALED;

struct ibv_ah* selected_ah = ctx->dst_ah;
uint32_t selected_qpn = ctx->dst_qpn;
uintptr_t selected_remote_addr = ctx->remote_addr;
uint32_t selected_remote_rkey = ctx->remote_rkey;
uint64_t selected_remote_len = ctx->remote_len;

if (!ctx->dst_ah_per_nic.empty()) {
size_t nic_idx = wrs_to_post[i] % ctx->dst_ah_per_nic.size();
selected_ah = ctx->dst_ah_per_nic[nic_idx];
selected_qpn = ctx->dst_qpn_per_nic[nic_idx];
selected_remote_addr = ctx->remote_addr_per_nic[nic_idx];
selected_remote_rkey = ctx->remote_rkey_per_nic[nic_idx];
selected_remote_len = ctx->remote_len_per_nic[nic_idx];
}

uint64_t remote_addr =
ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0);
uint64_t remote_end = ctx->remote_addr + ctx->remote_len;
selected_remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0);
uint64_t remote_end = selected_remote_addr + selected_remote_len;

if (remote_addr < ctx->remote_addr ||
if (remote_addr < selected_remote_addr ||
remote_addr + cmd.bytes > remote_end) {
fprintf(stderr,
"[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, "
"size=%zu), cmd.req_rptr: 0x%llx\n",
(unsigned long long)remote_addr, cmd.bytes,
(unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len,
(unsigned long long)selected_remote_addr,
(size_t)selected_remote_len,
(unsigned long long)cmd.req_rptr);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
Expand All @@ -1052,7 +1086,8 @@ static void post_rdma_async_batched_fast_mode(
get_low_latency(cmd.cmd_type),
cmd.expert_idx, 1, my_rank)
.GetImmData();
ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
ibv_wr_rdma_write_imm(qpx, selected_remote_rkey, remote_addr,
htonl(imm));
#else
if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) {
int v = static_cast<int>(cmd.atomic_val);
Expand All @@ -1064,20 +1099,23 @@ static void post_rdma_async_batched_fast_mode(
AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset,
get_low_latency(cmd.cmd_type))
.GetImmData();
ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
ibv_wr_rdma_write_imm(qpx, selected_remote_rkey, remote_addr,
htonl(imm));
} else if (j + 1 == k) {
uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type),
get_low_latency(cmd.cmd_type),
cmd.expert_idx, k, my_rank)
.GetImmData();
ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
ibv_wr_rdma_write_imm(qpx, selected_remote_rkey, remote_addr,
htonl(imm));
} else {
ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr);
ibv_wr_rdma_write(qpx, selected_remote_rkey, remote_addr);
}
#endif
uintptr_t laddr =
cmd.req_lptr + reinterpret_cast<uintptr_t>(ctx->mr->addr);
ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY);

ibv_wr_set_ud_addr(qpx, selected_ah, selected_qpn, QKEY);
ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr,
static_cast<uint32_t>(cmd.bytes));
}
Expand Down