Skip to content

Commit 59d0917

Browse files
committed
fix bug
1 parent 96a6d56 commit 59d0917

File tree

4 files changed

+24
-30
lines changed

4 files changed

+24
-30
lines changed

src/executor/execution_kernel.cu

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ namespace mscclpp {
88

99
template <typename PacketType>
1010
void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
11-
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan,
12-
size_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
11+
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
12+
cudaStream_t stream, uint32_t flag) {
1313
switch (dataType) {
1414
case DataType::INT32:
1515
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
16-
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
16+
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
1717
#if defined(ENABLE_NPKIT)
1818
,
1919
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -23,7 +23,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
2323
break;
2424
case DataType::UINT32:
2525
executionKernel<uint32_t><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
26-
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
26+
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
2727
#if defined(ENABLE_NPKIT)
2828
,
2929
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -33,7 +33,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
3333
break;
3434
case DataType::FLOAT16:
3535
executionKernel<half><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
36-
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
36+
rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
3737
#if defined(ENABLE_NPKIT)
3838
,
3939
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -43,7 +43,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
4343
break;
4444
case DataType::FLOAT32:
4545
executionKernel<float><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
46-
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
46+
rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
4747
#if defined(ENABLE_NPKIT)
4848
,
4949
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -53,7 +53,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
5353
break;
5454
case DataType::BFLOAT16:
5555
executionKernel<__bfloat16><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
56-
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
56+
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
5757
#if defined(ENABLE_NPKIT)
5858
,
5959
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -65,12 +65,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
6565
}
6666

6767
template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
68-
void* scratch, size_t scratchSize, DataType dataType,
69-
DeviceExecutionPlan* plan, size_t sharedMemSize,
70-
cudaStream_t stream, uint32_t flag);
68+
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
69+
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
7170
template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
72-
void* scratch, size_t scratchSize, DataType dataType,
73-
DeviceExecutionPlan* plan, size_t sharedMemSize,
74-
cudaStream_t stream, uint32_t flag);
71+
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
72+
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
7573
} // namespace mscclpp
7674
#endif

src/executor/execution_plan.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, BufferTy
100100
}
101101

102102
std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfosByDstRank(int rank, BufferType bufferType) const {
103-
auto pred = [rank, bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; };
103+
auto pred = [bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; };
104104
return filter(this->channelInfosByDstRank.at(rank), pred);
105105
}
106106

@@ -159,10 +159,10 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
159159

160160
size_t scratchBufferSize = sizePerRank * this->scratchChunks.at(rank);
161161
if (this->isUsingPacket) {
162-
scratchBufferSize *= 2; // data + flag
162+
scratchBufferSize *= 2; /* data + flag */
163163
}
164164
if (this->isUsingDoubleScratchBuffer) {
165-
scratchBufferSize *= 2; // double buffer
165+
scratchBufferSize *= 2; /* double buffer */
166166
}
167167
return scratchBufferSize;
168168
}
@@ -174,7 +174,7 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper
174174

175175
int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }
176176

177-
bool ExecutionPlan::Impl::getIsUsingDoubleScratchBuffer() const { return this->getIsUsingDoubleScratchBuffer; }
177+
bool ExecutionPlan::Impl::getIsUsingDoubleScratchBuffer() const { return this->isUsingDoubleScratchBuffer; }
178178

179179
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
180180
size_t constDstOffset) {

src/executor/executor.cc

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,8 @@ struct Executor::Impl {
309309
static uint32_t flag = 0;
310310
int nthreadblocks = context.deviceExecutionPlans.size();
311311
char* kernelScratchBufferPtr = context.scratchBuffer.get();
312-
size_t kernelScratchBufferSize = context.scratchBufferSize;
313-
if (context.isUsingDoubleScratchBuffer) {
314-
kernelScratchBufferSize /= 2;
315-
if (flag % 2) {
316-
kernelScratchBufferPtr += kernelScratchBufferSize;
317-
}
312+
if (context.isUsingDoubleScratchBuffer && (flag % 2)) {
313+
kernelScratchBufferPtr += context.scratchBufferSize / 2;
318314
}
319315
#if defined(ENABLE_NPKIT)
320316
#if defined(__HIP_PLATFORM_AMD__)
@@ -332,13 +328,13 @@ struct Executor::Impl {
332328
switch (packetType) {
333329
case PacketType::LL16:
334330
ExecutionKernel::launchKernel<LL16Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
335-
(void*)kernelScratchBufferPtr, kernelScratchBufferSize, dataType,
331+
(void*)kernelScratchBufferPtr, dataType,
336332
(DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
337333
sharedMemSize, stream, ++flag);
338334
break;
339335
case PacketType::LL8:
340336
ExecutionKernel::launchKernel<LL8Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
341-
(void*)kernelScratchBufferPtr, kernelScratchBufferSize, dataType,
337+
(void*)kernelScratchBufferPtr, dataType,
342338
(DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
343339
sharedMemSize, stream, ++flag);
344340
break;

src/include/execution_kernel.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ MSCCLPP_DEVICE_INLINE void handlePutPacket(DeviceHandle<SmChannel>* smChannels,
301301

302302
template <typename T, typename PacketType, bool SendToRemote = true>
303303
MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
304-
T* inputBuff, size_t inputBuffSize, uint32_t* inputOffsets, int nSrcs,
304+
T* inputBuff, uint32_t* inputOffsets, int nSrcs,
305305
DeviceHandle<SmChannel>* smChannels, uint8_t* outputChannelIndexes,
306306
uint32_t* outputOffsets, int nDstChannels, size_t size,
307307
uint32_t flag) {
@@ -331,8 +331,8 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
331331
}
332332

333333
template <typename PacketType>
334-
MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset,
335-
uint32_t srcOffset, size_t size, uint32_t flag) {
334+
MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset, size_t size,
335+
uint32_t flag) {
336336
PacketType* srcPackets = (PacketType*)((char*)src + 2 * srcOffset);
337337
PacketPayload<PacketType>* result = (PacketPayload<PacketType>*)((char*)dst + dstOffset);
338338
size_t nPackets = size * 2 / sizeof(PacketType);
@@ -343,8 +343,8 @@ MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize
343343
}
344344

345345
template <typename PacketType>
346-
MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t dstSize, uint32_t dstOffset,
347-
uint32_t srcOffset, size_t size, uint32_t flag) {
346+
MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset,
347+
size_t size, uint32_t flag) {
348348
dstOffset = dstOffset * 2;
349349
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
350350
}

0 commit comments

Comments
 (0)