@@ -91,6 +91,9 @@ void AlgorithmCollection::registerAlgorithm(const std::string collective, const
9191
9292std::shared_ptr<Algorithm> AlgorithmCollection::selectAlgorithm (const CollectiveRequest& request) {
9393 std::shared_ptr<Algorithm> algo;
94+ if (!algoSelector_ && !fallbackAlgoSelector_) {
95+ THROW (ALGO, Error, ErrorCode::InvalidUsage, " No algorithm selector is set in AlgorithmCollection." );
96+ }
9497 if (algoSelector_) {
9598 algo = algoSelector_ (algoMapByCollective_, request);
9699 }
@@ -108,6 +111,16 @@ void AlgorithmCollection::extend(const AlgorithmCollection& other) {
108111 }
109112}
110113
114+ std::vector<std::shared_ptr<Algorithm>> AlgorithmCollection::getAllAlgorithms () const {
115+ std::vector<std::shared_ptr<Algorithm>> allAlgos;
116+ for (const auto & [collective, algoMap] : algoMapByCollective_) {
117+ for (const auto & [algoName, algorithm] : algoMap) {
118+ allAlgos.push_back (algorithm);
119+ }
120+ }
121+ return allAlgos;
122+ }
123+
111124std::unordered_map<std::string, std::shared_ptr<Algorithm>> AlgorithmCollection::getAlgorithmsByCollective (
112125 const std::string& collective) const {
113126 auto it = algoMapByCollective_.find (collective);
@@ -130,7 +143,16 @@ void AlgorithmCollectionBuilder::addAlgorithmBuilder(std::shared_ptr<AlgorithmBu
130143 this ->algoBuilders_ .push_back (builder);
131144}
132145
133- std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollectionWithDefaultNativeAlgorithms (
146+ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildDefaultAlgorithms (uintptr_t scratchBuffer,
147+ size_t scratchBufferSize,
148+ int rank) {
149+ auto nativeCollection = buildDefaultNativeAlgorithms (scratchBuffer, scratchBufferSize);
150+ auto dslCollection = buildDefaultDslAlgorithms (rank);
151+ nativeCollection->extend (*dslCollection);
152+ return nativeCollection;
153+ }
154+
155+ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms (
134156 uintptr_t scratchBuffer, size_t scratchBufferSize) {
135157 auto collection = std::make_shared<AlgorithmCollection>();
136158 auto allreduceAllpairPkt =
@@ -163,7 +185,7 @@ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollection
163185 return collection;
164186}
165187
166- std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollectionWithDefaultDslAlgorithms (int rank) {
188+ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildDefaultDslAlgorithms (int rank) {
167189 struct DslAlgoConfig {
168190 std::string filename;
169191 std::string collective;
@@ -188,14 +210,14 @@ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollection
188210
189211 std::string planDir = mscclpp::env ()->executionPlanDir ;
190212 if (!std::filesystem::exists (planDir)) {
191- INFO (EXEC , " Plan directory does not exist: " , planDir);
213+ INFO (ALGO , " Plan directory does not exist: " , planDir);
192214 return collection;
193215 }
194216 for (const auto & config : defaultAlgoConfigs) {
195217 std::string planPath = planDir + " /" + config.filename ;
196- INFO (EXEC , " Loading plan: " , planPath);
218+ INFO (ALGO , " Loading plan: " , planPath);
197219 if (!std::filesystem::exists (planPath)) {
198- INFO (EXEC , " Plan file does not exist: " , planPath);
220+ INFO (ALGO , " Plan file does not exist: " , planPath);
199221 continue ;
200222 }
201223 std::string planId = generateFileId (planPath);
@@ -205,9 +227,9 @@ std::shared_ptr<AlgorithmCollection> AlgorithmCollectionBuilder::buildCollection
205227 auto algoBuilder = std::make_shared<mscclpp::DslAlgorithm>(
206228 planId, executionPlan, config.tags , mscclpp::Algorithm::Constraint{config.worldSize , config.nRanksPerNode });
207229 collectionBuilder->addAlgorithmBuilder (algoBuilder);
208- INFO (EXEC , " Successfully loaded plan: " , planId, " for collective: " , config.collective );
230+ INFO (ALGO , " Successfully loaded plan: " , planId, " for collective: " , config.collective );
209231 } catch (const std::exception& e) {
210- WARN (EXEC , " Failed to load plan : " , planPath, " " , e.what ());
232+ WARN (ALGO , " Failed to load plan : " , planPath, " " , e.what ());
211233 }
212234 }
213235 return collection;
@@ -290,7 +312,7 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
290312 executor->execute (rank, (int *)input, (int *)output, inputSize, outputSize, DataType::UINT32, plan_, stream);
291313 break ;
292314 default :
293- WARN (EXEC , " Unsupported data type: " , static_cast <int >(dtype), " in DslAlgorithm" );
315+ WARN (ALGO , " Unsupported data type: " , static_cast <int >(dtype), " in DslAlgorithm" );
294316 return CommResult::commInvalidArgument;
295317 }
296318 return CommResult::commSuccess;
0 commit comments