@@ -162,8 +162,7 @@ class CommResource {
162162
163163 struct ConnectionResource {
164164 std::shared_ptr<mscclpp::Connection> connection;
165- std::vector<std::shared_ptr<mscclpp::SimpleProxyChannel>>
166- proxy_channels;
165+ std::vector<std::shared_ptr<mscclpp::ProxyChannel>> proxy_channels;
167166 std::vector<std::shared_ptr<mscclpp::SmChannel>> sm_channels;
168167 };
169168
@@ -312,11 +311,11 @@ void CommResource::connect(const PlanJson &plan_json,
312311 [&](std::shared_ptr<ConnectionResource> conn_resource) {
313312 if (!conn_resource) return ;
314313 conn_resource->proxy_channels .push_back (
315- std::make_shared<mscclpp::SimpleProxyChannel >(
314+ std::make_shared<mscclpp::ProxyChannel >(
316315 proxy_service_->proxyChannel (
317316 proxy_service_->buildAndAddSemaphore (
318- *comm_, conn_resource->connection )) ,
319- remote_regmem_id, regmem_id));
317+ *comm_, conn_resource->connection ),
318+ remote_regmem_id, regmem_id) ));
320319 };
321320 // NOTE: We can create multiple proxy channels here if we need in the
322321 // future
@@ -743,16 +742,15 @@ void PlanResource::init_kernel() {
743742 void *proxy_secondary_chan_addr =
744743 get_global_rt (" ARK_PROXY_SECONDARY_CHANS" );
745744 void *sm_chan_addr = get_global_rt (" ARK_SM_CHANS" );
746- std::vector<mscclpp::SimpleProxyChannel::DeviceHandle> proxy_handles (
745+ std::vector<mscclpp::ProxyChannel::DeviceHandle> proxy_handles (world_size_);
746+ std::vector<mscclpp::ProxyChannel::DeviceHandle> proxy_secondary_handles (
747747 world_size_);
748- std::vector<mscclpp::SimpleProxyChannel::DeviceHandle>
749- proxy_secondary_handles (world_size_);
750748 std::vector<mscclpp::SmChannel::DeviceHandle> sm_handles (world_size_);
751749 for (int i = 0 ; i < world_size_; i++) {
752750 if (i == rank_) continue ;
753751 auto resource = comm_resource_->resource (i);
754752 if (!resource) continue ;
755- std::vector<mscclpp::SimpleProxyChannel ::DeviceHandle> p_hdls;
753+ std::vector<mscclpp::ProxyChannel ::DeviceHandle> p_hdls;
756754 if (resource->ipc ) {
757755 sm_handles[i] = resource->ipc ->sm_channels [0 ]->deviceHandle ();
758756 p_hdls.push_back (resource->ipc ->proxy_channels [0 ]->deviceHandle ());
@@ -772,14 +770,14 @@ void PlanResource::init_kernel() {
772770 }
773771 auto tmp_stream = gpu_manager->create_stream ();
774772 GLOG (gpuSetDevice (device_id_));
775- GLOG (gpuMemcpyAsync (proxy_chan_addr, proxy_handles. data (),
776- proxy_handles.size () *
777- sizeof (mscclpp::SimpleProxyChannel ::DeviceHandle),
778- gpuMemcpyHostToDevice, tmp_stream->get ()));
773+ GLOG (gpuMemcpyAsync (
774+ proxy_chan_addr, proxy_handles.data (),
775+ proxy_handles. size () * sizeof (mscclpp::ProxyChannel ::DeviceHandle),
776+ gpuMemcpyHostToDevice, tmp_stream->get ()));
779777 GLOG (gpuMemcpyAsync (proxy_secondary_chan_addr,
780778 proxy_secondary_handles.data (),
781779 proxy_secondary_handles.size () *
782- sizeof (mscclpp::SimpleProxyChannel ::DeviceHandle),
780+ sizeof (mscclpp::ProxyChannel ::DeviceHandle),
783781 gpuMemcpyHostToDevice, tmp_stream->get ()));
784782 GLOG (gpuMemcpyAsync (
785783 sm_chan_addr, sm_handles.data (),
0 commit comments