Skip to content

Commit e87f96b

Browse files
committed
updates
1 parent 7d05da8 commit e87f96b

File tree

8 files changed

+81
-32
lines changed

8 files changed

+81
-32
lines changed

CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,6 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug" AND CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang
9191
target_link_options(coverage_config INTERFACE --coverage)
9292
endif()
9393

94-
# Find ibverbs
95-
include(FindIBVerbs)
96-
9794
# Find NUMA
9895
include(FindNUMA)
9996

ark/CMakeLists.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ if(ARK_USE_ROCM)
1111
set_source_files_properties(${CU_SOURCES} PROPERTIES LANGUAGE CXX)
1212
endif()
1313

14-
set(COMMON_LIBS ARK::numa ARK::ibverbs pthread rt)
14+
set(COMMON_LIBS ARK::numa pthread rt)
1515

1616
# ARK object
1717
target_include_directories(ark_obj PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
@@ -20,7 +20,6 @@ target_include_directories(ark_obj SYSTEM PRIVATE
2020
${DLPACK_INCLUDE_DIRS}
2121
${JSON_INCLUDE_DIRS}
2222
${MSCCLPP_INCLUDE_DIRS}
23-
${IBVERBS_INCLUDE_DIRS}
2423
${NUMA_INCLUDE_DIRS}
2524
)
2625

@@ -55,7 +54,6 @@ if(ARK_BUILD_TESTS)
5554
target_include_directories(${exe_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
5655
target_include_directories(${exe_name} SYSTEM PRIVATE
5756
${JSON_INCLUDE_DIRS}
58-
${IBVERBS_INCLUDE_DIRS}
5957
${NUMA_INCLUDE_DIRS}
6058
)
6159

ark/api/executor.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,7 @@ class CommResource {
162162

163163
struct ConnectionResource {
164164
std::shared_ptr<mscclpp::Connection> connection;
165-
std::vector<std::shared_ptr<mscclpp::SimpleProxyChannel>>
166-
proxy_channels;
165+
std::vector<std::shared_ptr<mscclpp::ProxyChannel>> proxy_channels;
167166
std::vector<std::shared_ptr<mscclpp::SmChannel>> sm_channels;
168167
};
169168

@@ -312,11 +311,11 @@ void CommResource::connect(const PlanJson &plan_json,
312311
[&](std::shared_ptr<ConnectionResource> conn_resource) {
313312
if (!conn_resource) return;
314313
conn_resource->proxy_channels.push_back(
315-
std::make_shared<mscclpp::SimpleProxyChannel>(
314+
std::make_shared<mscclpp::ProxyChannel>(
316315
proxy_service_->proxyChannel(
317316
proxy_service_->buildAndAddSemaphore(
318-
*comm_, conn_resource->connection)),
319-
remote_regmem_id, regmem_id));
317+
*comm_, conn_resource->connection),
318+
remote_regmem_id, regmem_id)));
320319
};
321320
// NOTE: We can create multiple proxy channels here if we need in the
322321
// future
@@ -743,16 +742,15 @@ void PlanResource::init_kernel() {
743742
void *proxy_secondary_chan_addr =
744743
get_global_rt("ARK_PROXY_SECONDARY_CHANS");
745744
void *sm_chan_addr = get_global_rt("ARK_SM_CHANS");
746-
std::vector<mscclpp::SimpleProxyChannel::DeviceHandle> proxy_handles(
745+
std::vector<mscclpp::ProxyChannel::DeviceHandle> proxy_handles(world_size_);
746+
std::vector<mscclpp::ProxyChannel::DeviceHandle> proxy_secondary_handles(
747747
world_size_);
748-
std::vector<mscclpp::SimpleProxyChannel::DeviceHandle>
749-
proxy_secondary_handles(world_size_);
750748
std::vector<mscclpp::SmChannel::DeviceHandle> sm_handles(world_size_);
751749
for (int i = 0; i < world_size_; i++) {
752750
if (i == rank_) continue;
753751
auto resource = comm_resource_->resource(i);
754752
if (!resource) continue;
755-
std::vector<mscclpp::SimpleProxyChannel::DeviceHandle> p_hdls;
753+
std::vector<mscclpp::ProxyChannel::DeviceHandle> p_hdls;
756754
if (resource->ipc) {
757755
sm_handles[i] = resource->ipc->sm_channels[0]->deviceHandle();
758756
p_hdls.push_back(resource->ipc->proxy_channels[0]->deviceHandle());
@@ -772,14 +770,14 @@ void PlanResource::init_kernel() {
772770
}
773771
auto tmp_stream = gpu_manager->create_stream();
774772
GLOG(gpuSetDevice(device_id_));
775-
GLOG(gpuMemcpyAsync(proxy_chan_addr, proxy_handles.data(),
776-
proxy_handles.size() *
777-
sizeof(mscclpp::SimpleProxyChannel::DeviceHandle),
778-
gpuMemcpyHostToDevice, tmp_stream->get()));
773+
GLOG(gpuMemcpyAsync(
774+
proxy_chan_addr, proxy_handles.data(),
775+
proxy_handles.size() * sizeof(mscclpp::ProxyChannel::DeviceHandle),
776+
gpuMemcpyHostToDevice, tmp_stream->get()));
779777
GLOG(gpuMemcpyAsync(proxy_secondary_chan_addr,
780778
proxy_secondary_handles.data(),
781779
proxy_secondary_handles.size() *
782-
sizeof(mscclpp::SimpleProxyChannel::DeviceHandle),
780+
sizeof(mscclpp::ProxyChannel::DeviceHandle),
783781
gpuMemcpyHostToDevice, tmp_stream->get()));
784782
GLOG(gpuMemcpyAsync(
785783
sm_chan_addr, sm_handles.data(),

ark/codegen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,9 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) {
354354

355355
std::string CodeGenerator::Impl::def_channels(int world_size) {
356356
std::stringstream ss;
357-
ss << "__constant__ mscclpp::SimpleProxyChannelDeviceHandle ";
357+
ss << "__constant__ mscclpp::ProxyChannelDeviceHandle ";
358358
ss << "ARK_PROXY_CHANS[" << world_size << "];\n";
359-
ss << "__constant__ mscclpp::SimpleProxyChannelDeviceHandle ";
359+
ss << "__constant__ mscclpp::ProxyChannelDeviceHandle ";
360360
ss << "ARK_PROXY_SECONDARY_CHANS[" << world_size << "];\n";
361361
ss << "__constant__ mscclpp::SmChannelDeviceHandle ";
362362
ss << "ARK_SM_CHANS[" << world_size << "];\n";

ark/include/kernels/comm.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
#include "common/unit_op.h"
1515
#include "reduce.h"
1616

17-
extern __constant__ mscclpp::SimpleProxyChannelDeviceHandle ARK_PROXY_CHANS[];
18-
extern __constant__ mscclpp::SimpleProxyChannelDeviceHandle
17+
extern __constant__ mscclpp::ProxyChannelDeviceHandle ARK_PROXY_CHANS[];
18+
extern __constant__ mscclpp::ProxyChannelDeviceHandle
1919
ARK_PROXY_SECONDARY_CHANS[];
2020
extern __constant__ mscclpp::SmChannelDeviceHandle ARK_SM_CHANS[];
2121

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
import numpy as np
5+
import ark
6+
7+
8+
def quickstart_tutorial():
9+
# Initialize the ARK environments
10+
ark.init()
11+
12+
M, N, K = 1024, 1024, 1024
13+
m0 = ark.tensor([M, K], ark.fp16)
14+
m1 = ark.tensor([N, K], ark.fp16)
15+
m2 = ark.tensor([M, K], ark.fp16)
16+
17+
# stage 1: matmul
18+
with ark.PlannerContext(processor_range=[0, 108]):
19+
# Use SMs 0~107 (all)
20+
t0 = ark.matmul(m0, m1, transpose_other=True)
21+
22+
# stage 2: parallel copy and matmul
23+
with ark.PlannerContext(processor_range=[0, 54]):
24+
# Use SMs 0~53
25+
t1 = ark.matmul(t0, m1)
26+
with ark.PlannerContext(processor_range=[54, 108]):
27+
# Use SMs 54~107
28+
t2 = ark.copy(input=t0, output=m2)
29+
30+
# Initialize the ARK runtime
31+
runtime = ark.Runtime()
32+
33+
# Launch the ARK runtime
34+
runtime.launch()
35+
36+
# Initialize
37+
m0_host = np.random.rand(M, K).astype(np.float16) * 0.01
38+
m0.from_numpy(m0_host)
39+
m1_host = np.random.rand(N, K).astype(np.float16) * 0.01
40+
m1.from_numpy(m1_host)
41+
42+
# Run the ARK program
43+
runtime.run()
44+
45+
# Check the matmul result
46+
res_host = np.matmul(np.matmul(m0_host, m1_host.T), m1_host)
47+
np.testing.assert_allclose(t1.to_numpy(), res_host, rtol=1e-3, atol=1e-3)
48+
49+
# Check the copy result
50+
np.testing.assert_equal(t2.to_numpy(), t0.to_numpy())
51+
52+
print("Successful!")
53+
54+
55+
if __name__ == "__main__":
56+
quickstart_tutorial()

third_party/CMakeLists.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ include(FetchContent)
1111
FetchContent_Declare(
1212
mscclpp
1313
GIT_REPOSITORY https://github.com/microsoft/mscclpp
14-
GIT_TAG v0.5.2
14+
GIT_TAG v0.6.0
1515
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/mscclpp
1616
)
17-
set(BUILD_TESTS OFF CACHE BOOL "" FORCE)
18-
set(BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE)
19-
set(BUILD_APPS_NCCL OFF CACHE BOOL "" FORCE)
20-
set(USE_CUDA ${ARK_USE_CUDA} CACHE BOOL "" FORCE)
21-
set(USE_ROCM ${ARK_USE_ROCM} CACHE BOOL "" FORCE)
22-
set(BYPASS_GPU_CHECK ON CACHE BOOL "" FORCE)
17+
set(MSCCLPP_BUILD_TESTS OFF CACHE BOOL "" FORCE)
18+
set(MSCCLPP_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE)
19+
set(MSCCLPP_BUILD_APPS_NCCL OFF CACHE BOOL "" FORCE)
20+
set(MSCCLPP_USE_CUDA ${ARK_USE_CUDA} CACHE BOOL "" FORCE)
21+
set(MSCCLPP_USE_ROCM ${ARK_USE_ROCM} CACHE BOOL "" FORCE)
22+
set(MSCCLPP_BYPASS_GPU_CHECK ON CACHE BOOL "" FORCE)
2323
set(INSTALL_PREFIX "ark")
2424
FetchContent_GetProperties(mscclpp)
2525
if (NOT mscclpp_POPULATED)

third_party/mscclpp

Submodule mscclpp updated 117 files

0 commit comments

Comments
 (0)