@@ -349,11 +349,11 @@ MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, uint32_
349349 mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x , blockDim.x , flag);
350350}
351351
352- template <typename T>
352+ template <typename T, bool SendToRemote = true >
353353MSCCLPP_DEVICE_INLINE void handleReduceSend (T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
354- T* input, uint32_t * inputOffsets, DeviceHandle<SmChannel>* smChannels ,
355- uint8_t * outputChannelIndexes, uint32_t * outputOffsets, int nOutChannels ,
356- uint32_t size) {
354+ T* input, uint32_t * inputOffsets, int nInputs ,
355+ DeviceHandle<SmChannel>* smChannels, uint8_t * outputChannelIndexes ,
356+ uint32_t * outputOffsets, int nOutChannels, uint32_t size) {
357357 const size_t nInt4 = size / sizeof (int4);
358358 const size_t srcOffset4 = srcOffsetByBytes / sizeof (int4);
359359 const size_t dstOffset4 = dstOffsetByBytes / sizeof (int4);
@@ -362,15 +362,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
362362 int4* input4 = (int4*)input;
363363 for (size_t idx = threadIdx.x ; idx < nInt4; idx += blockDim.x ) {
364364 int4 tmp = src4[srcOffset4 + idx];
365- for (int index = 0 ; index < nOutChannels ; ++index) {
365+ for (int index = 0 ; index < nInputs ; ++index) {
366366 size_t offset = inputOffsets[index] / sizeof (int4);
367367 int4 val = input4[offset + idx];
368368 tmp = add_vectors<T>(tmp, val);
369369 }
370370 dst4[dstOffset4 + idx] = tmp;
371- for (int index = 0 ; index < nOutChannels; ++index) {
372- size_t offset = outputOffsets[index] / sizeof (int4);
373- smChannels[outputChannelIndexes[index]].write <int4>(offset + idx, tmp);
371+ if constexpr (SendToRemote) {
372+ for (int index = 0 ; index < nOutChannels; ++index) {
373+ size_t offset = outputOffsets[index] / sizeof (int4);
374+ smChannels[outputChannelIndexes[index]].write <int4>(offset + idx, tmp);
375+ }
374376 }
375377 }
376378 // handle rest of data
@@ -379,14 +381,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
379381 const size_t endIdx = (srcOffsetByBytes + size) / sizeof (T);
380382 for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x ) {
381383 T tmp = src[idx];
382- for (int index = 0 ; index < nOutChannels ; ++index) {
384+ for (int index = 0 ; index < nInputs ; ++index) {
383385 size_t offset = inputOffsets[index] / sizeof (T);
384386 tmp = add_elements (tmp, input[offset + idx]);
385387 }
386388 dst[idx] = tmp;
387- for (int index = 0 ; index < nOutChannels; ++index) {
388- size_t offset = outputOffsets[index] / sizeof (T);
389- smChannels[outputChannelIndexes[index]].write <T>(offset + idx, tmp);
389+ if constexpr (SendToRemote) {
390+ for (int index = 0 ; index < nOutChannels; ++index) {
391+ size_t offset = outputOffsets[index] / sizeof (T);
392+ smChannels[outputChannelIndexes[index]].write <T>(offset + idx, tmp);
393+ }
390394 }
391395 }
392396}
@@ -523,8 +527,14 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
523527 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
524528 T* src = getBuffer (input, output, scratch, op.srcBufferType );
525529 T* tmp = getBuffer (input, output, scratch, op.inputBufferType );
526- handleReduceSend (dst, op.dstOffset , src, op.srcOffset , tmp, op.inputOffsets , smChannels, op.outputChannelIndexes ,
527- op.outputOffsets , op.nOutputs , op.size );
530+ handleReduceSend (dst, op.dstOffset , src, op.srcOffset , tmp, op.inputOffsets , op.nInputs , smChannels,
531+ op.outputChannelIndexes , op.outputOffsets , op.nOutputs , op.size );
532+ } else if (op.type == OperationType::REDUCE) {
533+ T* dst = getBuffer (input, output, scratch, op.dstBufferType );
534+ T* src = getBuffer (input, output, scratch, op.srcBufferType );
535+ T* tmp = getBuffer (input, output, scratch, op.inputBufferType );
536+ handleReduceSend<T, false >(dst, op.dstOffset , src, op.srcOffset , tmp, op.inputOffsets , op.nInputs , smChannels,
537+ op.outputChannelIndexes , op.outputOffsets , op.nOutputs , op.size );
528538 }
529539
530540#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT)
0 commit comments