Skip to content

Commit ceaf4d6

Browse files
committed
wip
1 parent 42ef9cc commit ceaf4d6

File tree

4 files changed

+102
-13
lines changed

4 files changed

+102
-13
lines changed

include/mscclpp/executor.hpp

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,51 @@
1111

1212
namespace mscclpp {
1313

14+
/// Data types supported by the executor.
1415
enum class DataType {
15-
INT32,
16-
UINT32,
17-
FLOAT16,
18-
FLOAT32,
19-
BFLOAT16,
20-
FP8_E4M3, // Add FP8 E4M3 type
21-
FP8_E5M2, // Add FP8 E5M2 type
16+
INT32, // 32-bit signed integer.
17+
UINT32, // 32-bit unsigned integer.
18+
FLOAT16, // IEEE 754 half precision.
19+
FLOAT32, // IEEE 754 single precision.
20+
BFLOAT16, // bfloat16 precision.
21+
FP8_E4M3, // FP8 with E4M3 layout.
22+
FP8_E5M2, // FP8 with E5M2 layout.
2223
};
2324

25+
/// Packet formats used by low-latency transport.
2426
enum class PacketType {
25-
LL8,
26-
LL16,
27+
LL8, // 8-byte low-latency packet.
28+
LL16, // 16-byte low-latency packet.
2729
};
2830

31+
/// Represents a compiled execution plan loaded from disk.
32+
///
33+
/// An ExecutionPlan encapsulates metadata about a collective algorithm such as its name, the
34+
/// collective it implements, and the supported message-size range. The concrete implementation
35+
/// is hidden behind the PIMPL pointer.
2936
class ExecutionPlan {
3037
public:
38+
/// Construct an ExecutionPlan by loading the plan file at `planPath`.
39+
/// @param planPath Filesystem path to the serialized plan.
40+
/// @param rank The rank of the current process.
3141
ExecutionPlan(const std::string& planPath, int rank);
42+
43+
/// Destructor.
3244
~ExecutionPlan() = default;
3345

46+
/// Return the human-readable name of the plan.
3447
std::string name() const;
48+
49+
/// Return the collective implemented by this plan (e.g., "allreduce", "allgather").
3550
std::string collective() const;
51+
52+
/// Minimum message size (in bytes) for which this plan is valid.
3653
size_t minMessageSize() const;
54+
55+
/// Maximum message size (in bytes) for which this plan is valid.
3756
size_t maxMessageSize() const;
57+
58+
/// Whether this plan performs the operation in-place.
3859
bool isInPlace() const;
3960

4061
private:
@@ -44,6 +65,7 @@ class ExecutionPlan {
4465
friend class Executor;
4566
};
4667

68+
/// Request parameters provided when executing a plan.
4769
struct ExecutionRequest {
4870
int worldSize;
4971
int nRanksPerNode;
@@ -54,41 +76,76 @@ struct ExecutionRequest {
5476
const std::string& collective;
5577
const std::unordered_map<std::string, std::vector<uint64_t>>& hints;
5678

79+
/// Whether the request indicates an in-place operation.
5780
bool isInPlace() const;
5881
};
5982

83+
/// A handle representing a specific execution plan along with its constraints and metadata.
6084
struct ExecutionPlanHandle {
85+
/// Constraints that must be satisfied for the plan to be valid.
6186
struct Constraint {
6287
int worldSize;
6388
int nRanksPerNode;
6489
};
65-
std::string id;
66-
Constraint constraint;
67-
std::shared_ptr<ExecutionPlan> plan;
68-
std::unordered_map<std::string, uint64_t> tags;
6990

91+
std::string id; /// Unique identifier for the handle.
92+
Constraint constraint; /// Constraints for plan applicability.
93+
std::shared_ptr<ExecutionPlan> plan; /// Backing ExecutionPlan instance.
94+
std::unordered_map<std::string, uint64_t> tags; /// Optional tags/metadata used by selector.
95+
96+
/// Create a new ExecutionPlanHandle.
97+
/// @param id Unique id for the handle.
98+
/// @param worldSize Required world size for the plan.
99+
/// @param nRanksPerNode Required ranks-per-node for the plan.
100+
/// @param plan The associated ExecutionPlan.
101+
/// @param tags Optional tags used for selection.
70102
static std::shared_ptr<ExecutionPlanHandle> create(const std::string& id, int worldSize, int nRanksPerNode,
71103
std::shared_ptr<ExecutionPlan> plan,
72104
const std::unordered_map<std::string, uint64_t>& tags = {});
105+
106+
/// Check whether the given ExecutionRequest satisfies this handle's parameters.
107+
/// @param request The execution request to evaluate.
108+
/// @return True if the request matches the handle parameters, false otherwise.
73109
bool match(const ExecutionRequest& request);
74110
};
75111

112+
/// Selector function type used to pick an ExecutionPlanHandle from a list of candidates.
76113
using ExecutionPlanSelector = std::function<std::shared_ptr<ExecutionPlanHandle>(
77114
const std::vector<std::shared_ptr<ExecutionPlanHandle>> plans, const ExecutionRequest& request)>;
115+
116+
/// Registry that holds available execution plans and performs selection logic.
78117
class ExecutionPlanRegistry {
79118
public:
119+
/// Retrieve the singleton instance of the registry.
80120
static std::shared_ptr<ExecutionPlanRegistry> getInstance();
121+
122+
/// Destructor.
81123
~ExecutionPlanRegistry();
82124

125+
/// Register a plan handle with the registry.
83126
void registerPlan(const std::shared_ptr<ExecutionPlanHandle> planHandle);
127+
128+
/// Get all plan handles for a given collective name.
84129
std::vector<std::shared_ptr<ExecutionPlanHandle>> getPlans(const std::string& collective);
130+
131+
/// Lookup a plan handle by id.
85132
std::shared_ptr<ExecutionPlanHandle> get(const std::string& id);
133+
134+
/// Select a suitable plan handle for the given parameters.
86135
std::shared_ptr<ExecutionPlanHandle> select(const std::string& collective, int worldSize, int nRanksPerNode, int rank,
87136
const void* sendBuffer, void* recvBuffer, size_t messageSize,
88137
const std::unordered_map<std::string, std::vector<uint64_t>>& hints);
138+
139+
/// Provide a custom selector function.
89140
void setSelector(ExecutionPlanSelector selector);
141+
142+
/// Set the default selector used when no custom selector is provided.
90143
void setDefaultSelector(ExecutionPlanSelector selector);
144+
145+
/// Load built-in/default plans for the given rank.
91146
void loadDefaultPlans(int rank);
147+
148+
/// Clear all registered plans from the registry.
92149
void clear();
93150

94151
private:
@@ -97,13 +154,36 @@ class ExecutionPlanRegistry {
97154
ExecutionPlanRegistry();
98155
};
99156

157+
/// High-level executor responsible for invoking execution plans on a communicator.
100158
class Executor {
101159
public:
160+
/// Construct an Executor using the provided communicator.
161+
/// @param comm Communicator instance used for underlying communication.
162+
/// @param defaultScratchBuffer Optional scratch buffer used by some plans (may be nullptr).
102163
Executor(std::shared_ptr<Communicator> comm, std::shared_ptr<char> defaultScratchBuffer = nullptr);
164+
165+
/// Copy construction is disabled for Executor.
103166
Executor(const Executor&) = delete;
167+
168+
/// Copy assignment is disabled for Executor.
104169
Executor& operator=(const Executor&) = delete;
170+
171+
/// Destructor. Cleans up internal resources held by the Executor.
105172
~Executor();
106173

174+
/// Execute a plan.
175+
///
176+
/// This method dispatches the given plan on the provided CUDA stream.
177+
///
178+
/// @param rank Rank of the calling process.
179+
/// @param sendbuff Pointer to the send buffer.
180+
/// @param recvBuff Pointer to the receive buffer.
181+
/// @param sendBuffSize Size of the send buffer in bytes.
182+
/// @param recvBuffSize Size of the receive buffer in bytes.
183+
/// @param dataType Data type of elements in the buffers.
184+
/// @param plan The execution plan to run.
185+
/// @param stream CUDA stream to execute kernels/operations on.
186+
/// @param packetType Packet type used for low-latency transports (default: LL16).
107187
void execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType,
108188
const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);
109189

python/mscclpp/__main__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
14
import os
25
import shutil
36
import argparse
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
14
from mscclpp.language.default_algos.allreduce_2nodes import allreduce_2nodes
25

36
__all__ = ["allreduce_2nodes"]

python/mscclpp/language/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
14
from enum import Enum
25
from dataclasses import dataclass, field
36
from mscclpp.language.collectives import Collective

0 commit comments

Comments
 (0)