1111
1212namespace mscclpp {
1313
14+ // / Data types supported by the executor.
1415enum class DataType {
15- INT32,
16- UINT32,
17- FLOAT16,
18- FLOAT32,
19- BFLOAT16,
20- FP8_E4M3, // Add FP8 E4M3 type
21- FP8_E5M2, // Add FP8 E5M2 type
16+ INT32, // 32-bit signed integer.
17+ UINT32, // 32-bit unsigned integer.
18+ FLOAT16, // IEEE 754 half precision.
19+ FLOAT32, // IEEE 754 single precision.
20+ BFLOAT16, // bfloat16 precision.
21+ FP8_E4M3, // FP8 with E4M3 layout.
22+ FP8_E5M2, // FP8 with E5M2 layout.
2223};
2324
25+ // / Packet formats used by low-latency transport.
2426enum class PacketType {
25- LL8,
26- LL16,
27+ LL8, // 8-byte low-latency packet.
28+ LL16, // 16-byte low-latency packet.
2729};
2830
31+ // / Represents a compiled execution plan loaded from disk.
32+ // /
33+ // / An ExecutionPlan encapsulates metadata about a collective algorithm such as its name, the
34+ // / collective it implements, and the supported message-size range. The concrete implementation
35+ // / is hidden behind the PIMPL pointer.
2936class ExecutionPlan {
3037 public:
38+ // / Construct an ExecutionPlan by loading the plan file at `planPath`.
39+ // / @param planPath Filesystem path to the serialized plan.
40+ // / @param rank The rank of the current process.
3141 ExecutionPlan (const std::string& planPath, int rank);
42+
43+ // / Destructor.
3244 ~ExecutionPlan () = default ;
3345
46+ // / Return the human-readable name of the plan.
3447 std::string name () const ;
48+
49+ // / Return the collective implemented by this plan (e.g., "allreduce", "allgather").
3550 std::string collective () const ;
51+
52+ // / Minimum message size (in bytes) for which this plan is valid.
3653 size_t minMessageSize () const ;
54+
55+ // / Maximum message size (in bytes) for which this plan is valid.
3756 size_t maxMessageSize () const ;
57+
58+ // / Whether this plan performs the operation in-place.
3859 bool isInPlace () const ;
3960
4061 private:
@@ -44,6 +65,7 @@ class ExecutionPlan {
4465 friend class Executor ;
4566};
4667
68+ // / Request parameters provided when executing a plan.
4769struct ExecutionRequest {
4870 int worldSize;
4971 int nRanksPerNode;
@@ -54,41 +76,76 @@ struct ExecutionRequest {
5476 const std::string& collective;
5577 const std::unordered_map<std::string, std::vector<uint64_t >>& hints;
5678
79+ // / Whether the request indicates an in-place operation.
5780 bool isInPlace () const ;
5881};
5982
83+ // / A handle representing a specific execution plan along with its constraints and metadata.
6084struct ExecutionPlanHandle {
85+ // / Constraints that must be satisfied for the plan to be valid.
6186 struct Constraint {
6287 int worldSize;
6388 int nRanksPerNode;
6489 };
65- std::string id;
66- Constraint constraint;
67- std::shared_ptr<ExecutionPlan> plan;
68- std::unordered_map<std::string, uint64_t > tags;
6990
91+ std::string id; // / Unique identifier for the handle.
92+ Constraint constraint; // / Constraints for plan applicability.
93+ std::shared_ptr<ExecutionPlan> plan; // / Backing ExecutionPlan instance.
94+ std::unordered_map<std::string, uint64_t > tags; // / Optional tags/metadata used by selector.
95+
96+ // / Create a new ExecutionPlanHandle.
97+ // / @param id Unique id for the handle.
98+ // / @param worldSize Required world size for the plan.
99+ // / @param nRanksPerNode Required ranks-per-node for the plan.
100+ // / @param plan The associated ExecutionPlan.
101+ // / @param tags Optional tags used for selection.
70102 static std::shared_ptr<ExecutionPlanHandle> create (const std::string& id, int worldSize, int nRanksPerNode,
71103 std::shared_ptr<ExecutionPlan> plan,
72104 const std::unordered_map<std::string, uint64_t >& tags = {});
105+
106+ // / Check whether the given ExecutionRequest satisfies this handle's parameters.
107+ // / @param request The execution request to evaluate.
108+ // / @return True if the request matches the handle parameters, false otherwise.
73109 bool match (const ExecutionRequest& request);
74110};
75111
112+ // / Selector function type used to pick an ExecutionPlanHandle from a list of candidates.
76113using ExecutionPlanSelector = std::function<std::shared_ptr<ExecutionPlanHandle>(
77114 const std::vector<std::shared_ptr<ExecutionPlanHandle>> plans, const ExecutionRequest& request)>;
115+
116+ // / Registry that holds available execution plans and performs selection logic.
78117class ExecutionPlanRegistry {
79118 public:
119+ // / Retrieve the singleton instance of the registry.
80120 static std::shared_ptr<ExecutionPlanRegistry> getInstance ();
121+
122+ // / Destructor.
81123 ~ExecutionPlanRegistry ();
82124
125+ // / Register a plan handle with the registry.
83126 void registerPlan (const std::shared_ptr<ExecutionPlanHandle> planHandle);
127+
128+ // / Get all plan handles for a given collective name.
84129 std::vector<std::shared_ptr<ExecutionPlanHandle>> getPlans (const std::string& collective);
130+
131+ // / Lookup a plan handle by id.
85132 std::shared_ptr<ExecutionPlanHandle> get (const std::string& id);
133+
134+ // / Select a suitable plan handle for the given parameters.
86135 std::shared_ptr<ExecutionPlanHandle> select (const std::string& collective, int worldSize, int nRanksPerNode, int rank,
87136 const void * sendBuffer, void * recvBuffer, size_t messageSize,
88137 const std::unordered_map<std::string, std::vector<uint64_t >>& hints);
138+
139+ // / Provide a custom selector function.
89140 void setSelector (ExecutionPlanSelector selector);
141+
142+ // / Set the default selector used when no custom selector is provided.
90143 void setDefaultSelector (ExecutionPlanSelector selector);
144+
145+ // / Load built-in/default plans for the given rank.
91146 void loadDefaultPlans (int rank);
147+
148+ // / Clear all registered plans from the registry.
92149 void clear ();
93150
94151 private:
@@ -97,13 +154,36 @@ class ExecutionPlanRegistry {
97154 ExecutionPlanRegistry ();
98155};
99156
157+ // / High-level executor responsible for invoking execution plans on a communicator.
100158class Executor {
101159 public:
160+ // / Construct an Executor using the provided communicator.
161+ // / @param comm Communicator instance used for underlying communication.
162+ // / @param defaultScratchBuffer Optional scratch buffer used by some plans (may be nullptr).
102163 Executor (std::shared_ptr<Communicator> comm, std::shared_ptr<char > defaultScratchBuffer = nullptr );
164+
165+ // / Copy construction is disabled for Executor.
103166 Executor (const Executor&) = delete ;
167+
168+ // / Copy assignment is disabled for Executor.
104169 Executor& operator =(const Executor&) = delete ;
170+
171+ // / Destructor. Cleans up internal resources held by the Executor.
105172 ~Executor ();
106173
174+ // / Execute a plan.
175+ // /
176+ // / This method dispatches the given plan on the provided CUDA stream.
177+ // /
178+ // / @param rank Rank of the calling process.
179+ // / @param sendbuff Pointer to the send buffer.
180+ // / @param recvBuff Pointer to the receive buffer.
181+ // / @param sendBuffSize Size of the send buffer in bytes.
182+ // / @param recvBuffSize Size of the receive buffer in bytes.
183+ // / @param dataType Data type of elements in the buffers.
184+ // / @param plan The execution plan to run.
185+ // / @param stream CUDA stream to execute kernels/operations on.
186+ // / @param packetType Packet type used for low-latency transports (default: LL16).
107187 void execute (int rank, void * sendbuff, void * recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType,
108188 const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);
109189
0 commit comments