Skip to content

Commit e2a5d69

Browse files
committed
WIP
1 parent 8d3c5d3 commit e2a5d69

File tree

5 files changed

+43
-21
lines changed

5 files changed

+43
-21
lines changed

apps/nccl/src/nccl.cu

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,8 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
326326
commPtr->nRanksPerNode = mscclppComm->bootstrap()->getNranksPerNode();
327327
commPtr->worldSize = mscclppComm->bootstrap()->getNranks();
328328
mscclpp::AlgorithmCollectionBuilder::getInstance()->setFallbackAlgorithmSelector(algoSelector);
329-
commPtr->algorithmCollection =
330-
mscclpp::AlgorithmCollectionBuilder::getInstance()->buildCollectionWithDefaultNativeAlgorithms(
331-
reinterpret_cast<uintptr_t>(commPtr->scratchBuffer_.get()), commPtr->scratchBufferSize_);
332-
auto dslAlgoCollection =
333-
mscclpp::AlgorithmCollectionBuilder::getInstance()->buildCollectionWithDefaultDslAlgorithms(rank);
334-
commPtr->algorithmCollection->extend(*dslAlgoCollection);
329+
commPtr->algorithmCollection = mscclpp::AlgorithmCollectionBuilder::getInstance()->buildDefaultAlgorithms(
330+
reinterpret_cast<uintptr_t>(commPtr->scratchBuffer_.get()), commPtr->scratchBufferSize_, rank);
335331

336332
*comm = commPtr;
337333
#if defined(ENABLE_NPKIT)

include/mscclpp/algorithm.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class AlgorithmCollection {
245245

246246
std::unordered_map<std::string, std::shared_ptr<Algorithm>> getAlgorithmsByCollective(
247247
const std::string& collective) const;
248-
248+
std::vector<std::shared_ptr<Algorithm>> getAllAlgorithms() const;
249249
void extend(const AlgorithmCollection& other);
250250

251251
private:
@@ -265,10 +265,8 @@ class AlgorithmCollectionBuilder {
265265
/// @param builder The algorithm builder.
266266
void addAlgorithmBuilder(std::shared_ptr<AlgorithmBuilder> builder);
267267

268-
std::shared_ptr<AlgorithmCollection> buildCollectionWithDefaultNativeAlgorithms(uintptr_t scratchBuffer,
269-
size_t scratchBufferSize);
270-
271-
std::shared_ptr<AlgorithmCollection> buildCollectionWithDefaultDslAlgorithms(int rank);
268+
std::shared_ptr<AlgorithmCollection> buildDefaultAlgorithms(uintptr_t scratchBuffer, size_t scratchBufferSize,
269+
int rank);
272270

273271
/// @brief Set a new algorithm selection function.
274272
/// @param selector The algorithm selection function.
@@ -291,6 +289,8 @@ class AlgorithmCollectionBuilder {
291289
AlgoSelectFunc fallbackAlgoSelector_ = nullptr;
292290

293291
static std::shared_ptr<AlgorithmCollectionBuilder> gAlgorithmCollectionBuilder_;
292+
std::shared_ptr<AlgorithmCollection> buildDefaultNativeAlgorithms(uintptr_t scratchBuffer, size_t scratchBufferSize);
293+
std::shared_ptr<AlgorithmCollection> buildDefaultDslAlgorithms(int rank);
294294
};
295295

296296
} // namespace mscclpp

python/csrc/algorithm.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <nanobind/stl/shared_ptr.h>
88
#include <nanobind/stl/string.h>
99
#include <nanobind/stl/unordered_map.h>
10+
#include <nanobind/stl/vector.h>
1011

1112
#include <mscclpp/algorithm.hpp>
1213

@@ -84,12 +85,15 @@ void register_algorithm(nb::module_& m) {
8485
.def("set_fallback_algorithm_selector", &AlgorithmCollectionBuilder::setFallbackAlgorithmSelector,
8586
nb::arg("selector"))
8687
.def("build", &AlgorithmCollectionBuilder::build)
88+
.def("build_default_algorithms", &AlgorithmCollectionBuilder::buildDefaultAlgorithms,
89+
nb::arg("scratch_buffer"), nb::arg("scratch_buffer_size"), nb::arg("rank"))
8790
.def_static("reset", &AlgorithmCollectionBuilder::reset);
8891

8992
nb::class_<AlgorithmCollection>(m, "AlgorithmCollection")
9093
.def("register_algorithm", &AlgorithmCollection::registerAlgorithm, nb::arg("collective"), nb::arg("algo_name"),
9194
nb::arg("algorithm"))
92-
.def("get_algorithms_by_collective", &AlgorithmCollection::getAlgorithmsByCollective, nb::arg("collective"));
95+
.def("get_algorithms_by_collective", &AlgorithmCollection::getAlgorithmsByCollective, nb::arg("collective"))
96+
.def("to_list", &AlgorithmCollection::getAllAlgorithms);
9397

9498
nb::class_<CollectiveRequest>(m, "CollectiveRequest")
9599
.def_ro("world_size", &CollectiveRequest::worldSize)

src/algorithms/algorithm.cc

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ void AlgorithmCollection::registerAlgorithm(const std::string collective, const
9191

9292
std::shared_ptr<Algorithm> AlgorithmCollection::selectAlgorithm(const CollectiveRequest& request) {
9393
std::shared_ptr<Algorithm> algo;
94+
if (!algoSelector_ && !fallbackAlgoSelector_) {
95+
THROW(ALGO, Error, ErrorCode::InvalidUsage, "No algorithm selector is set in AlgorithmCollection.");
96+
}
9497
if (algoSelector_) {
9598
algo = algoSelector_(algoMapByCollective_, request);
9699
}
@@ -108,6 +111,16 @@ void AlgorithmCollection::extend(const AlgorithmCollection& other) {
108111
}
109112
}
110113

114+
std::vector<std::shared_ptr<Algorithm>> AlgorithmCollection::getAllAlgorithms() const {
115+
std::vector<std::shared_ptr<Algorithm>> allAlgos;
116+
for (const auto& [collective, algoMap] : algoMapByCollective_) {
117+
for (const auto& [algoName, algorithm] : algoMap) {
118+
allAlgos.push_back(algorithm);
119+
}
120+
}
121+
return allAlgos;
122+
}
123+
111124
std::unordered_map<std::string, std::shared_ptr<Algorithm>> AlgorithmCollection::getAlgorithmsByCollective(
112125
const std::string& collective) const {
113126
auto it = algoMapByCollective_.find(collective);
@@ -130,7 +143,16 @@ void AlgorithmCollectionBuilder::addAlgorithmBuilder(std::shared_ptr<AlgorithmBu
130143
this->algoBuilders_.push_back(builder);
131144
}
132145

133-
std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollectionWithDefaultNativeAlgorithms(
146+
std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildDefaultAlgorithms(uintptr_t scratchBuffer,
147+
size_t scratchBufferSize,
148+
int rank) {
149+
auto nativeCollection = buildDefaultNativeAlgorithms(scratchBuffer, scratchBufferSize);
150+
auto dslCollection = buildDefaultDslAlgorithms(rank);
151+
nativeCollection->extend(*dslCollection);
152+
return nativeCollection;
153+
}
154+
155+
std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms(
134156
uintptr_t scratchBuffer, size_t scratchBufferSize) {
135157
auto collection = std::make_shared<AlgorithmCollection>();
136158
auto allreduceAllpairPkt =
@@ -163,7 +185,7 @@ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollection
163185
return collection;
164186
}
165187

166-
std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollectionWithDefaultDslAlgorithms(int rank) {
188+
std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildDefaultDslAlgorithms(int rank) {
167189
struct DslAlgoConfig {
168190
std::string filename;
169191
std::string collective;
@@ -188,14 +210,14 @@ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollection
188210

189211
std::string planDir = mscclpp::env()->executionPlanDir;
190212
if (!std::filesystem::exists(planDir)) {
191-
INFO(EXEC, "Plan directory does not exist: ", planDir);
213+
INFO(ALGO, "Plan directory does not exist: ", planDir);
192214
return collection;
193215
}
194216
for (const auto& config : defaultAlgoConfigs) {
195217
std::string planPath = planDir + "/" + config.filename;
196-
INFO(EXEC, "Loading plan: ", planPath);
218+
INFO(ALGO, "Loading plan: ", planPath);
197219
if (!std::filesystem::exists(planPath)) {
198-
INFO(EXEC, "Plan file does not exist: ", planPath);
220+
INFO(ALGO, "Plan file does not exist: ", planPath);
199221
continue;
200222
}
201223
std::string planId = generateFileId(planPath);
@@ -205,9 +227,9 @@ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollection
205227
auto algoBuilder = std::make_shared<mscclpp::DslAlgorithm>(
206228
planId, executionPlan, config.tags, mscclpp::Algorithm::Constraint{config.worldSize, config.nRanksPerNode});
207229
collectionBuilder->addAlgorithmBuilder(algoBuilder);
208-
INFO(EXEC, "Successfully loaded plan: ", planId, " for collective: ", config.collective);
230+
INFO(ALGO, "Successfully loaded plan: ", planId, " for collective: ", config.collective);
209231
} catch (const std::exception& e) {
210-
WARN(EXEC, "Failed to load plan : ", planPath, " ", e.what());
232+
WARN(ALGO, "Failed to load plan : ", planPath, " ", e.what());
211233
}
212234
}
213235
return collection;
@@ -290,7 +312,7 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
290312
executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream);
291313
break;
292314
default:
293-
WARN(EXEC, "Unsupported data type: ", static_cast<int>(dtype), " in DslAlgorithm");
315+
WARN(ALGO, "Unsupported data type: ", static_cast<int>(dtype), " in DslAlgorithm");
294316
return CommResult::commInvalidArgument;
295317
}
296318
return CommResult::commSuccess;

src/include/logger.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
namespace mscclpp {
1818

1919
typedef enum : unsigned int { NONE = 0, DEBUG, INFO, WARN, ERROR } LogLevel;
20-
typedef enum : std::size_t { ENV = 0, NET, CONN, EXEC, NCCL, COUNT } LogSubsys;
20+
typedef enum : std::size_t { ENV = 0, NET, CONN, EXEC, NCCL, ALGO, COUNT } LogSubsys;
2121

2222
namespace detail {
2323
std::string guessRemoveProjectPrefix(const std::string& filePathStr);

0 commit comments

Comments
 (0)