Skip to content

Commit c6e06cf

Browse files
authored
Executor AllGather In-Place Support (#365)
1 parent 4136153 commit c6e06cf

File tree

4 files changed

+95
-41
lines changed

4 files changed

+95
-41
lines changed

python/test/executor_test.py

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,15 @@ def dtype_to_mscclpp_dtype(dtype):
7777
raise ValueError(f"Unknown data type: {dtype}")
7878

7979

80+
def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):
    """Return the buffer that should hold the result of the executed plan.

    Plans whose name contains "allgather" always yield ``recvbuf``, whether or
    not ``in_place`` is set. Any other plan yields ``sendbuf`` when
    ``in_place`` is truthy and ``recvbuf`` otherwise.
    """
    result_in_sendbuf = in_place and "allgather" not in execution_plan_name
    return sendbuf if result_in_sendbuf else recvbuf
87+
88+
8089
def main(
8190
execution_plan_name: str,
8291
execution_plan_path: str,
@@ -104,9 +113,11 @@ def main(
104113

105114
if "allgather" in execution_plan_name:
106115
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
116+
if in_place:
117+
for i in range(nelems):
118+
recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i]
107119
expected = buffer
108120
else:
109-
cp.random.seed(seed)
110121
recvbuf = cp.zeros(nelems, dtype=dtype)
111122
expected = cp.zeros_like(sendbuf, dtype=dtype)
112123
for i in range(mscclpp_group.nranks):
@@ -116,9 +127,9 @@ def main(
116127
executor_func = lambda stream: executor.execute(
117128
MPI.COMM_WORLD.rank,
118129
sendbuf.data.ptr,
119-
sendbuf.data.ptr if in_place else recvbuf.data.ptr,
130+
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
120131
sendbuf.nbytes,
121-
sendbuf.nbytes if in_place else recvbuf.nbytes,
132+
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
122133
dtype_to_mscclpp_dtype(dtype),
123134
execution_plan,
124135
stream.ptr,
@@ -129,10 +140,14 @@ def main(
129140
executor_func(stream)
130141
stream.synchronize()
131142

132-
assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks)
143+
assert cp.allclose(
144+
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name),
145+
expected,
146+
atol=1e-2 * mscclpp_group.nranks,
147+
)
133148

134149
mscclpp_group.barrier()
135-
execution_time = bench_time(100, 10, executor_func)
150+
execution_time = bench_time(10, 10, executor_func)
136151
if npkit_dump_dir is not None:
137152
npkit.dump(npkit_dump_dir)
138153
npkit.shutdown()

src/executor/execution_plan.cc

Lines changed: 58 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -148,12 +148,19 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
148148
}
149149
return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
150150
}
151-
size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize) const {
151+
size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const {
152+
size_t sizePerRank;
153+
if (this->inputChunks.at(rank) != 0)
154+
sizePerRank = inputSize / this->inputChunks.at(rank);
155+
else if (this->outputChunks.at(rank) != 0)
156+
sizePerRank = outputSize / this->outputChunks.at(rank);
157+
else
158+
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
159+
152160
if (this->isUsingPacket) {
153-
return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank) * 2 /* data + flag*/ *
154-
2 /*double buffer*/;
161+
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
155162
}
156-
return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank);
163+
return sizePerRank * this->scratchChunks.at(rank);
157164
}
158165
std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
159166
return this->operations.at(rank)[threadblock];
@@ -163,7 +170,8 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper
163170

164171
int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }
165172

166-
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
173+
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
174+
size_t constDstOffset) {
167175
std::ifstream file(this->planPath);
168176
json obj = json::parse(file);
169177
if (this->name != obj["name"]) {
@@ -186,10 +194,12 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOff
186194
this->setupChannels(gpus);
187195

188196
this->inputSize = inputSize;
197+
this->outputSize = outputSize;
189198
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
190199
}
191200

192-
void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
201+
void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
202+
size_t constDstOffset) {
193203
std::ifstream file(this->planPath);
194204
json obj = json::parse(file);
195205
if (this->name != obj["name"]) {
@@ -210,6 +220,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsS
210220
}
211221

212222
this->inputSize = inputSize;
223+
this->outputSize = outputSize;
213224
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
214225
}
215226

@@ -313,8 +324,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
313324
// Get the relevant channel index in rank channelInfos
314325
operation.inputChannelIndexes[i] =
315326
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]];
316-
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["i_cids"][i]["off"]) +
317-
(srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
327+
operation.inputOffsets[i] =
328+
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) +
329+
(srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
318330
chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]);
319331
}
320332
}
@@ -323,8 +335,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
323335
operation.nInputs = op["srcs"].size();
324336
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
325337
for (int i = 0; i < operation.nInputs; i++) {
326-
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["srcs"][i]["off"]) +
327-
(operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
338+
operation.inputOffsets[i] =
339+
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcs"][i]["off"]) +
340+
(operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
328341
chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]);
329342
}
330343
}
@@ -335,8 +348,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
335348
BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
336349
operation.outputChannelIndexes[i] =
337350
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]];
338-
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["o_cids"][i]["off"]) +
339-
(dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
351+
operation.outputOffsets[i] =
352+
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) +
353+
(dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
340354
chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]);
341355
}
342356
}
@@ -345,27 +359,29 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
345359
operation.nOutputs = op["dsts"].size();
346360
operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
347361
for (int i = 0; i < operation.nOutputs; i++) {
348-
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["dsts"][i]["off"]) +
349-
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
362+
operation.outputOffsets[i] =
363+
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dsts"][i]["off"]) +
364+
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
350365
chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]);
351366
}
352367
}
353368
if (op.contains("srcbuff")) {
354369
operation.srcBufferType = convertToBufferType(op["srcbuff"]);
355370
}
356371
if (op.contains("srcoff")) {
357-
operation.srcOffset = this->getOffset(rank, this->inputSize, (uint32_t)op["srcoff"]);
372+
operation.srcOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcoff"]);
358373
chunkIndexes.push_back((uint32_t)op["srcoff"]);
359374
}
360375
if (op.contains("dstbuff")) {
361376
operation.dstBufferType = convertToBufferType(op["dstbuff"]);
362377
}
363378
if (op.contains("dstoff")) {
364-
operation.dstOffset = this->getOffset(rank, this->inputSize, (uint32_t)op["dstoff"]);
379+
operation.dstOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dstoff"]);
365380
chunkIndexes.push_back((uint32_t)op["dstoff"]);
366381
}
367382
if (op.contains("cnt")) {
368-
operation.size = this->getNChunkSize(rank, this->inputSize, (uint32_t)op["cnt"], chunkIndexes);
383+
operation.size =
384+
this->getNChunkSize(rank, this->inputSize, this->outputSize, (uint32_t)op["cnt"], chunkIndexes);
369385
}
370386
ops.push_back(operation);
371387
}
@@ -374,14 +390,33 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
374390
}
375391
}
376392

377-
size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment) const {
393+
std::pair<size_t, u_int32_t> ExecutionPlan::Impl::calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const {
394+
std::pair<size_t, u_int32_t> sizePerRank;
395+
if (this->inputChunks.at(rank) == 0 && this->outputChunks.at(rank) == 0) {
396+
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
397+
} else if (this->inputChunks.at(rank) != 0 && this->outputChunks.at(rank) != 0) {
398+
if (inputSize / this->inputChunks.at(rank) != outputSize / this->outputChunks.at(rank))
399+
throw mscclpp::Error("Size per chunks inconsistent", mscclpp::ErrorCode::ExecutorError);
400+
else
401+
sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank));
402+
} else if (this->inputChunks.at(rank) != 0) {
403+
sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank));
404+
} else if (this->outputChunks.at(rank) != 0) {
405+
sizePerRank = std::make_pair(outputSize, this->outputChunks.at(rank));
406+
}
407+
return sizePerRank;
408+
}
409+
410+
size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex,
411+
uint32_t alignment) const {
378412
if (inputSize % alignment != 0) {
379413
throw Error("inputSize must be a multiple of alignment", ErrorCode::ExecutorError);
380414
}
381415

382416
const int nGroups = this->chunkGroups.at(rank);
383-
uint32_t nInputChunks = this->inputChunks.at(rank);
384-
uint32_t nelems = inputSize / (alignment * sizeof(uint8_t));
417+
auto sizePerRank = calcSizePerRank(rank, inputSize, outputSize);
418+
uint32_t nInputChunks = sizePerRank.second;
419+
uint32_t nelems = sizePerRank.first / (alignment * sizeof(uint8_t));
385420
if (nelems % nGroups != 0) {
386421
throw Error("Input size must be a multiple of nGroups", ErrorCode::ExecutorError);
387422
}
@@ -397,12 +432,12 @@ size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunk
397432
return static_cast<size_t>(offset) * alignment;
398433
}
399434

400-
size_t ExecutionPlan::Impl::getNChunkSize(int rank, size_t inputSize, uint32_t nChunks,
435+
size_t ExecutionPlan::Impl::getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks,
401436
const std::vector<uint32_t> chunkIndexes) const {
402437
size_t nChunkSize = 0;
403438
for (uint32_t index : chunkIndexes) {
404-
uint32_t beginOff = getOffset(rank, inputSize, index);
405-
uint32_t endOff = getOffset(rank, inputSize, index + nChunks);
439+
uint32_t beginOff = getOffset(rank, inputSize, outputSize, index);
440+
uint32_t endOff = getOffset(rank, inputSize, outputSize, index + nChunks);
406441
if (nChunkSize == 0) {
407442
nChunkSize = endOff - beginOff;
408443
} else if (nChunkSize != endOff - beginOff) {

src/executor/executor.cc

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -80,13 +80,13 @@ struct Executor::Impl {
8080
}
8181
~Impl() = default;
8282

83-
ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t messageSize,
84-
size_t contsSrcOffset, size_t constDstOffset, size_t sendBufferSize,
85-
size_t recvBufferSize, const ExecutionPlan& plan) {
83+
ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize,
84+
size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset,
85+
size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) {
8686
ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name};
8787
if (this->contexts.find(key) != this->contexts.end()) {
8888
plan.impl_->operationsReset();
89-
plan.impl_->lightLoadExecutionPlan(messageSize, contsSrcOffset, constDstOffset);
89+
plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
9090
this->setupDeviceExecutionPlan(this->contexts[key], rank, plan);
9191
this->contexts[key].deviceExecutionPlansBuffer =
9292
allocExtSharedCuda<char>(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
@@ -97,10 +97,10 @@ struct Executor::Impl {
9797
}
9898

9999
plan.impl_->reset();
100-
plan.impl_->loadExecutionPlan(messageSize, contsSrcOffset, constDstOffset);
100+
plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
101101

102102
ExecutionContext context;
103-
size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize);
103+
size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize);
104104
std::shared_ptr<char> scratchBuffer = allocExtSharedCuda<char>(scratchBufferSize);
105105
context.scratchBuffer = scratchBuffer;
106106
context.scratchBufferSize = scratchBufferSize;
@@ -350,8 +350,9 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff
350350
size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
351351
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
352352

353-
ExecutionContext context = this->impl_->setupExecutionContext(
354-
rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan);
353+
ExecutionContext context =
354+
this->impl_->setupExecutionContext(rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, recvBuffSize,
355+
offsetIn, offsetOut, sendBytes, recvBytes, plan);
355356
this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
356357
}
357358

src/include/execution_plan.hpp

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -65,13 +65,13 @@ struct ExecutionPlan::Impl {
6565
std::vector<ChannelInfo> getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType);
6666
std::vector<int> getConnectedPeers(int rank) const;
6767
std::vector<BufferType> getConnectedBufferTypes(int rank) const;
68-
size_t getScratchBufferSize(int rank, size_t inputSize) const;
68+
size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;
6969
std::vector<Operation> getOperations(int rank, int threadblock) const;
7070
int getThreadblockCount(int rank) const;
7171
int getNThreadsPerBlock() const;
7272

73-
void loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
74-
void lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
73+
void loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
74+
void lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
7575
void setupChannels(const nlohmann::json& gpus);
7676
void setupOperations(const nlohmann::json& gpus, size_t contsSrcOffset, size_t constDstOffset);
7777

@@ -94,11 +94,14 @@ struct ExecutionPlan::Impl {
9494
std::unordered_map<int, uint32_t> scratchChunks;
9595
std::unordered_map<int, uint32_t> chunkGroups;
9696
size_t inputSize;
97+
size_t outputSize;
9798
int nThreadsPerBlock;
9899

99100
private:
100-
size_t getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment = 16) const;
101-
size_t getNChunkSize(int rank, size_t inputSize, uint32_t nChunks, const std::vector<uint32_t> offsets) const;
101+
std::pair<size_t, u_int32_t> calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const;
102+
size_t getOffset(int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex, uint32_t alignment = 16) const;
103+
size_t getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks,
104+
const std::vector<uint32_t> offsets) const;
102105
};
103106

104107
} // namespace mscclpp

0 commit comments

Comments
 (0)