Skip to content

Commit aaff57f

Browse files
committed
fix reduce primitive and plan validation
1 parent 0d43a21 commit aaff57f

File tree

4 files changed

+49
-14
lines changed

4 files changed

+49
-14
lines changed

src/executor/execution_plan.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,15 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) {
305305
}
306306
}
307307

308+
void ExecutionPlan::Impl::checkChannelsPerOperation(int channels) {
309+
if (channels > MAX_CHANNEL_PER_OPERATION) {
310+
throw Error("Executor plan has " + std::to_string(channels) +
311+
" channels per operation, exceeding executor support (" +
312+
std::to_string(MAX_CHANNEL_PER_OPERATION) + ")",
313+
ErrorCode::ExecutorError);
314+
}
315+
}
316+
308317
void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffset, size_t constDstOffset) {
309318
// setup threadblocks and operations
310319
for (const auto& gpu : gpus) {
@@ -332,6 +341,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
332341
}
333342
if (op.contains("i_cids")) {
334343
operation.nInputs = op["i_cids"].size();
344+
checkChannelsPerOperation(operation.nInputs);
335345
for (int i = 0; i < operation.nInputs; i++) {
336346
BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]);
337347
BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]);
@@ -347,6 +357,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
347357
// will have either srcs or i_cids
348358
if (op.contains("srcs")) {
349359
operation.nInputs = op["srcs"].size();
360+
checkChannelsPerOperation(operation.nInputs);
350361
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
351362
for (int i = 0; i < operation.nInputs; i++) {
352363
operation.inputOffsets[i] =
@@ -357,6 +368,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
357368
}
358369
if (op.contains("o_cids")) {
359370
operation.nOutputs = op["o_cids"].size();
371+
checkChannelsPerOperation(operation.nOutputs);
360372
for (int i = 0; i < operation.nOutputs; i++) {
361373
BufferType srcBufferType = convertToBufferType(op["o_buff"]["src"]);
362374
BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
@@ -371,6 +383,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
371383
// will have either dsts or o_cids
372384
if (op.contains("dsts")) {
373385
operation.nOutputs = op["dsts"].size();
386+
checkChannelsPerOperation(operation.nOutputs);
374387
operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
375388
for (int i = 0; i < operation.nOutputs; i++) {
376389
operation.outputOffsets[i] =

src/executor/executor.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,19 @@ struct Executor::Impl {
287287
DeviceExecutionPlan deviceExecutionPlan = {};
288288
std::vector<Operation> ops = plan.impl_->getOperations(rank, threadblock);
289289
deviceExecutionPlan.nOperations = ops.size();
290+
if (deviceExecutionPlan.nOperations > MAX_OPERATION) {
291+
throw Error("Executor plan has " + std::to_string(deviceExecutionPlan.nOperations) +
292+
" operations, exceeding executor support (" + std::to_string(MAX_OPERATION) + ")",
293+
ErrorCode::ExecutorError);
294+
}
290295
deviceExecutionPlan.nSmChannels = plan.impl_->threadblockSMChannelMap.at(rank).at(threadblock).size();
291296
deviceExecutionPlan.nProxyChannels = plan.impl_->threadblockProxyChannelMap.at(rank).at(threadblock).size();
297+
if (deviceExecutionPlan.nSmChannels > MAX_CHANNEL || deviceExecutionPlan.nProxyChannels > MAX_CHANNEL) {
298+
throw Error("Executor plan has " +
299+
std::to_string(std::max(deviceExecutionPlan.nSmChannels, deviceExecutionPlan.nProxyChannels)) +
300+
" channels, exceeding executor support (" + std::to_string(MAX_CHANNEL) + ")",
301+
ErrorCode::ExecutorError);
302+
}
292303
int chanIndex = 0;
293304
for (const auto& [index, _] : plan.impl_->threadblockSMChannelMap.at(rank).at(threadblock)) {
294305
deviceExecutionPlan.channels.smChannels[chanIndex++] = mscclpp::deviceHandle(context.smChannels[index]);

src/include/execution_kernel.hpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,11 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, uint32_
349349
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
350350
}
351351

352-
template <typename T>
352+
template <typename T, bool SendToRemote = true>
353353
MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
354-
T* input, uint32_t* inputOffsets, DeviceHandle<SmChannel>* smChannels,
355-
uint8_t* outputChannelIndexes, uint32_t* outputOffsets, int nOutChannels,
356-
uint32_t size) {
354+
T* input, uint32_t* inputOffsets, int nInputs,
355+
DeviceHandle<SmChannel>* smChannels, uint8_t* outputChannelIndexes,
356+
uint32_t* outputOffsets, int nOutChannels, uint32_t size) {
357357
const size_t nInt4 = size / sizeof(int4);
358358
const size_t srcOffset4 = srcOffsetByBytes / sizeof(int4);
359359
const size_t dstOffset4 = dstOffsetByBytes / sizeof(int4);
@@ -362,15 +362,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
362362
int4* input4 = (int4*)input;
363363
for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) {
364364
int4 tmp = src4[srcOffset4 + idx];
365-
for (int index = 0; index < nOutChannels; ++index) {
365+
for (int index = 0; index < nInputs; ++index) {
366366
size_t offset = inputOffsets[index] / sizeof(int4);
367367
int4 val = input4[offset + idx];
368368
tmp = add_vectors<T>(tmp, val);
369369
}
370370
dst4[dstOffset4 + idx] = tmp;
371-
for (int index = 0; index < nOutChannels; ++index) {
372-
size_t offset = outputOffsets[index] / sizeof(int4);
373-
smChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
371+
if constexpr (SendToRemote) {
372+
for (int index = 0; index < nOutChannels; ++index) {
373+
size_t offset = outputOffsets[index] / sizeof(int4);
374+
smChannels[outputChannelIndexes[index]].write<int4>(offset + idx, tmp);
375+
}
374376
}
375377
}
376378
// handle rest of data
@@ -379,14 +381,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
379381
const size_t endIdx = (srcOffsetByBytes + size) / sizeof(T);
380382
for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) {
381383
T tmp = src[idx];
382-
for (int index = 0; index < nOutChannels; ++index) {
384+
for (int index = 0; index < nInputs; ++index) {
383385
size_t offset = inputOffsets[index] / sizeof(T);
384386
tmp = add_elements(tmp, input[offset + idx]);
385387
}
386388
dst[idx] = tmp;
387-
for (int index = 0; index < nOutChannels; ++index) {
388-
size_t offset = outputOffsets[index] / sizeof(T);
389-
smChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
389+
if constexpr (SendToRemote) {
390+
for (int index = 0; index < nOutChannels; ++index) {
391+
size_t offset = outputOffsets[index] / sizeof(T);
392+
smChannels[outputChannelIndexes[index]].write<T>(offset + idx, tmp);
393+
}
390394
}
391395
}
392396
}
@@ -523,8 +527,14 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
523527
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
524528
T* src = getBuffer(input, output, scratch, op.srcBufferType);
525529
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
526-
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, smChannels, op.outputChannelIndexes,
527-
op.outputOffsets, op.nOutputs, op.size);
530+
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, smChannels,
531+
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
532+
} else if (op.type == OperationType::REDUCE) {
533+
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
534+
T* src = getBuffer(input, output, scratch, op.srcBufferType);
535+
T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
536+
handleReduceSend<T, false>(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, smChannels,
537+
op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
528538
}
529539

530540
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT)

src/include/execution_plan.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct ExecutionPlan::Impl {
108108
size_t getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks,
109109
const std::vector<uint32_t> offsets) const;
110110
void calcScratchBufferSizeAndOffset(int rank, size_t inputSize, size_t outputSize, int flag);
111+
void checkChannelsPerOperation(int channels);
111112
};
112113

113114
} // namespace mscclpp

0 commit comments

Comments
 (0)