@@ -277,15 +277,14 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
277277}
278278
279279template <typename PacketType>
280- MSCCLPP_DEVICE_INLINE void handlePutPacket (size_t scratchSize, DeviceHandle<SmChannel>* smChannels,
280+ MSCCLPP_DEVICE_INLINE void handlePutPacket (DeviceHandle<SmChannel>* smChannels,
281281 DeviceHandle<SimpleProxyChannel>* proxyChannels, uint8_t * dstChannelIndexes,
282282 uint32_t * dstOffsets, uint32_t * srcOffsets, int nDstChannels, uint32_t size,
283283 ChannelType chType, uint32_t flag) {
284- const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1 ;
285284 if (chType == ChannelType::SM) {
286285 for (int index = 0 ; index < nDstChannels; ++index) {
287- smChannels[dstChannelIndexes[index]].putPackets <PacketType>(
288- scratchBaseOffset + dstOffsets[index] * 2 , srcOffsets[index], size, threadIdx.x , blockDim.x , flag);
286+ smChannels[dstChannelIndexes[index]].putPackets <PacketType>(dstOffsets[index] * 2 , srcOffsets[index], size,
287+ threadIdx.x , blockDim.x , flag);
289288 }
290289 }
291290 if (chType == ChannelType::PROXY) {
@@ -294,8 +293,8 @@ MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmCh
294293 return ;
295294 }
296295 // For proxy channel, we assume src and dst are in packet format
297- uint32_t dstOffset = ( dstOffsets[tid] << 1 ) + scratchBaseOffset ;
298- uint32_t srcOffset = ( srcOffsets[tid] << 1 ) + scratchBaseOffset ;
296+ uint32_t dstOffset = dstOffsets[tid] << 1 ;
297+ uint32_t srcOffset = srcOffsets[tid] << 1 ;
299298 proxyChannels[dstChannelIndexes[tid]].put (dstOffset, srcOffset, size << 1 );
300299 }
301300}
@@ -307,15 +306,14 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
307306 uint32_t * outputOffsets, int nDstChannels, size_t size,
308307 uint32_t flag) {
309308 size_t nPackets = size * 2 / sizeof (PacketType);
310- const size_t intputBaseOffset = flag & 0x1 ? 0 : inputBuffSize >> 1 ;
311309 const uint32_t srcOffset = srcOffsetByBytes / sizeof (PacketPayload<PacketType>);
312310 const uint32_t dstOffset = dstOffsetByBytes / sizeof (PacketPayload<PacketType>);
313311 PacketPayload<PacketType>* srcPacketPayload = (PacketPayload<PacketType>*)src + srcOffset;
314312 PacketPayload<PacketType>* dstPacketPayload = (PacketPayload<PacketType>*)dst + dstOffset;
315313 for (size_t idx = threadIdx.x ; idx < nPackets; idx += blockDim.x ) {
316314 PacketPayload<PacketType> data = {};
317315 for (int index = 0 ; index < nSrcs; ++index) {
318- PacketType* pkt = (PacketType*)((char *)inputBuff + intputBaseOffset + 2 * inputOffsets[index]);
316+ PacketType* pkt = (PacketType*)((char *)inputBuff + 2 * inputOffsets[index]);
319317 PacketPayload<PacketType> val = pkt[idx].read (flag);
320318 data = add_vectors<T>(data, val);
321319 }
@@ -325,7 +323,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
325323 if (SendToRemote) {
326324 PacketType pkt (data, flag);
327325 for (int index = 0 ; index < nDstChannels; ++index) {
328- size_t offset = (intputBaseOffset + outputOffsets[index] * 2 ) / sizeof (PacketType);
326+ size_t offset = outputOffsets[index] * 2 / sizeof (PacketType);
329327 smChannels[outputChannelIndexes[index]].write (offset + idx, pkt);
330328 }
331329 }
@@ -335,8 +333,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
335333template <typename PacketType>
336334MSCCLPP_DEVICE_INLINE void handleCopyPacket (void * dst, void * src, size_t srcSize, uint32_t dstOffset,
337335 uint32_t srcOffset, size_t size, uint32_t flag) {
338- const size_t inputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1 ;
339- PacketType* srcPackets = (PacketType*)((char *)src + inputScratchBaseOffset + 2 * srcOffset);
336+ PacketType* srcPackets = (PacketType*)((char *)src + 2 * srcOffset);
340337 PacketPayload<PacketType>* result = (PacketPayload<PacketType>*)((char *)dst + dstOffset);
341338 size_t nPackets = size * 2 / sizeof (PacketType);
342339 for (size_t idx = threadIdx.x ; idx < nPackets; idx += blockDim.x ) {
@@ -348,8 +345,7 @@ MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize
348345template <typename PacketType>
349346MSCCLPP_DEVICE_INLINE void handleTransformToPacket (void * dst, void * src, size_t dstSize, uint32_t dstOffset,
350347 uint32_t srcOffset, size_t size, uint32_t flag) {
351- const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : dstSize >> 1 ;
352- dstOffset = dstOffset * 2 + outputScratchBaseOffset;
348+ dstOffset = dstOffset * 2 ;
353349 mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x , blockDim.x , flag);
354350}
355351
@@ -403,7 +399,7 @@ MSCCLPP_DEVICE_INLINE void handleCopy(void* dst, void* src, uint32_t dstOffset,
403399
404400template <typename T, typename PacketType = LL16Packet>
405401__global__ void executionKernel ([[maybe_unused]] int rank /* for debug*/ , T* input, T* output, T* scratch,
406- size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag
402+ DeviceExecutionPlan* plan, uint32_t flag
407403#if defined(ENABLE_NPKIT)
408404 ,
409405 NpKitEventCollectContext* npKitEventCollectContexts, uint64_t * cpuTimestamp) {
@@ -501,28 +497,28 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
501497 op.inputChannelIndexes , op.outputOffsets , op.inputOffsets , op.nOutputs , op.nInputs ,
502498 op.size , false );
503499 } else if (op.type == OperationType::PUT_PACKET) {
504- handlePutPacket<PacketType>(scratchSize, smChannels, proxyChannels, op.outputChannelIndexes , op.outputOffsets ,
505- op.inputOffsets , op. nOutputs , op.size , op.channelType , flag);
500+ handlePutPacket<PacketType>(smChannels, proxyChannels, op.outputChannelIndexes , op.outputOffsets , op. inputOffsets ,
501+ op.nOutputs , op.size , op.channelType , flag);
506502 } else if (op.type == OperationType::REDUCE_SEND_PACKET) {
507503 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
508504 T* src = getBuffer (input, output, scratch, op.srcBufferType );
509- handleReduceSendPacket<T, PacketType>(dst, op.dstOffset , src, op.srcOffset , scratch, scratchSize , op.inputOffsets ,
510- op.nInputs , smChannels , op.outputChannelIndexes , op.outputOffsets ,
511- op. nOutputs , op. size , flag);
505+ handleReduceSendPacket<T, PacketType>(dst, op.dstOffset , src, op.srcOffset , scratch, op. inputOffsets , op.nInputs ,
506+ smChannels, op.outputChannelIndexes , op. outputOffsets , op.nOutputs , op.size ,
507+ flag);
512508 } else if (op.type == OperationType::REDUCE_PACKET) {
513509 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
514510 T* src = getBuffer (input, output, scratch, op.srcBufferType );
515- handleReduceSendPacket<T, PacketType, false >(dst, op.dstOffset , src, op.srcOffset , scratch, scratchSize ,
516- op.inputOffsets , op. nInputs , smChannels, op.outputChannelIndexes ,
517- op.outputOffsets , op. nOutputs , op.size , flag);
511+ handleReduceSendPacket<T, PacketType, false >(dst, op.dstOffset , src, op.srcOffset , scratch, op. inputOffsets ,
512+ op.nInputs , smChannels, op.outputChannelIndexes , op. outputOffsets ,
513+ op.nOutputs , op.size , flag);
518514 } else if (op.type == OperationType::COPY_PACKET) {
519515 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
520516 T* src = getBuffer (input, output, scratch, op.srcBufferType );
521- handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset , op.srcOffset , op.size , flag);
517+ handleCopyPacket<PacketType>(dst, src, op.dstOffset , op.srcOffset , op.size , flag);
522518 } else if (op.type == OperationType::TRANSFORM_TO_PACKET) {
523519 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
524520 T* src = getBuffer (input, output, scratch, op.srcBufferType );
525- handleTransformToPacket<PacketType>(dst, src, scratchSize, op.dstOffset , op.srcOffset , op.size , flag);
521+ handleTransformToPacket<PacketType>(dst, src, op.dstOffset , op.srcOffset , op.size , flag);
526522 } else if (op.type == OperationType::REDUCE_SEND) {
527523 T* dst = getBuffer (input, output, scratch, op.dstBufferType );
528524 T* src = getBuffer (input, output, scratch, op.srcBufferType );
@@ -548,12 +544,12 @@ class ExecutionKernel {
548544#if defined(MSCCLPP_DEVICE_HIP)
549545 template <typename PacketType>
550546 static void launchKernel (int rank, int nthreadblocks, int nthreads, void * src, void * dst, void * scratch,
551- size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
552- cudaStream_t stream, uint32_t flag = 0 ) {
547+ DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream ,
548+ uint32_t flag = 0 ) {
553549 switch (dataType) {
554550 case DataType::INT32:
555551 executionKernel<int32_t , PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
556- rank, (int32_t *)src, (int32_t *)dst, (int32_t *)scratch, scratchSize, plan, flag
552+ rank, (int32_t *)src, (int32_t *)dst, (int32_t *)scratch, plan, flag
557553#if defined(ENABLE_NPKIT)
558554 ,
559555 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -563,7 +559,7 @@ class ExecutionKernel {
563559 break ;
564560 case DataType::UINT32:
565561 executionKernel<uint32_t , PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
566- rank, (uint32_t *)src, (uint32_t *)dst, (uint32_t *)scratch, scratchSize, plan, flag
562+ rank, (uint32_t *)src, (uint32_t *)dst, (uint32_t *)scratch, plan, flag
567563#if defined(ENABLE_NPKIT)
568564 ,
569565 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -573,7 +569,7 @@ class ExecutionKernel {
573569 break ;
574570 case DataType::FLOAT16:
575571 executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
576- rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
572+ rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
577573#if defined(ENABLE_NPKIT)
578574 ,
579575 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -583,7 +579,7 @@ class ExecutionKernel {
583579 break ;
584580 case DataType::FLOAT32:
585581 executionKernel<float , PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
586- rank, (float *)src, (float *)dst, (float *)scratch, scratchSize, plan, flag
582+ rank, (float *)src, (float *)dst, (float *)scratch, plan, flag
587583#if defined(ENABLE_NPKIT)
588584 ,
589585 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -593,7 +589,7 @@ class ExecutionKernel {
593589 break ;
594590 case DataType::BFLOAT16:
595591 executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
596- rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
592+ rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
597593#if defined(ENABLE_NPKIT)
598594 ,
599595 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -606,8 +602,8 @@ class ExecutionKernel {
606602#else // !defined(MSCCLPP_DEVICE_HIP)
607603 template <typename PacketType>
608604 static void launchKernel (int rank, int nthreadblocks, int nthreads, void * src, void * dst, void * scratch,
609- size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
610- cudaStream_t stream, uint32_t flag = 0 );
605+ DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream ,
606+ uint32_t flag = 0 );
611607#endif // !defined(MSCCLPP_DEVICE_HIP)
612608};
613609} // namespace mscclpp
0 commit comments