Skip to content

Commit 96a6d56

Browse files
committed
Add separate option for double scratch buffer
1 parent b72decb commit 96a6d56

File tree

4 files changed

+59
-43
lines changed

4 files changed

+59
-43
lines changed

src/executor/execution_plan.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,14 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
157157
else
158158
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
159159

160+
size_t scratchBufferSize = sizePerRank * this->scratchChunks.at(rank);
160161
if (this->isUsingPacket) {
161-
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
162+
scratchBufferSize *= 2; // data + flag
162163
}
163-
return sizePerRank * this->scratchChunks.at(rank);
164+
if (this->isUsingDoubleScratchBuffer) {
165+
scratchBufferSize *= 2; // double buffer
166+
}
167+
return scratchBufferSize;
164168
}
165169
std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
166170
return this->operations.at(rank)[threadblock];
@@ -170,6 +174,8 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper
170174

171175
int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }
172176

177+
// Returns the plan's isUsingDoubleScratchBuffer flag (the data member, not this accessor's own name).
bool ExecutionPlan::Impl::getIsUsingDoubleScratchBuffer() const { return this->isUsingDoubleScratchBuffer; }
178+
173179
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
174180
size_t constDstOffset) {
175181
std::ifstream file(this->planPath);
@@ -182,6 +188,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
182188
this->isUsingPacket = true;
183189
}
184190
this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
191+
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
185192
const auto& gpus = obj["gpus"];
186193

187194
for (const auto& gpu : gpus) {
@@ -209,6 +216,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output
209216
if (protocol == "LL") {
210217
this->isUsingPacket = true;
211218
}
219+
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
212220
const auto& gpus = obj["gpus"];
213221

214222
for (const auto& gpu : gpus) {

src/executor/executor.cc

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ struct ExecutionContext {
6666
size_t scratchBufferSize;
6767
std::shared_ptr<char> deviceExecutionPlansBuffer;
6868
int nthreadsPerBlock;
69+
bool isUsingDoubleScratchBuffer;
6970
};
7071

7172
struct Executor::Impl {
@@ -106,6 +107,7 @@ struct Executor::Impl {
106107
context.scratchBufferSize = scratchBufferSize;
107108
context.proxyService = std::make_shared<ProxyService>();
108109
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
110+
context.isUsingDoubleScratchBuffer = plan.impl_->getIsUsingDoubleScratchBuffer();
109111
this->setupConnections(context, rank, plan);
110112
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
111113
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
@@ -306,6 +308,14 @@ struct Executor::Impl {
306308
cudaStream_t stream, PacketType packetType) {
307309
static uint32_t flag = 0;
308310
int nthreadblocks = context.deviceExecutionPlans.size();
311+
char* kernelScratchBufferPtr = context.scratchBuffer.get();
312+
size_t kernelScratchBufferSize = context.scratchBufferSize;
313+
if (context.isUsingDoubleScratchBuffer) {
314+
kernelScratchBufferSize /= 2;
315+
if (flag % 2) {
316+
kernelScratchBufferPtr += kernelScratchBufferSize;
317+
}
318+
}
309319
#if defined(ENABLE_NPKIT)
310320
#if defined(__HIP_PLATFORM_AMD__)
311321
if (nthreadblocks > NPKIT_MAX_NUM_GPU_THREADBLOCKS) {
@@ -321,16 +331,16 @@ struct Executor::Impl {
321331
#endif
322332
switch (packetType) {
323333
case PacketType::LL16:
324-
ExecutionKernel::launchKernel<LL16Packet>(
325-
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
326-
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
327-
sharedMemSize, stream, ++flag);
334+
ExecutionKernel::launchKernel<LL16Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
335+
(void*)kernelScratchBufferPtr, kernelScratchBufferSize, dataType,
336+
(DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
337+
sharedMemSize, stream, ++flag);
328338
break;
329339
case PacketType::LL8:
330-
ExecutionKernel::launchKernel<LL8Packet>(
331-
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
332-
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
333-
sharedMemSize, stream, ++flag);
340+
ExecutionKernel::launchKernel<LL8Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
341+
(void*)kernelScratchBufferPtr, kernelScratchBufferSize, dataType,
342+
(DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
343+
sharedMemSize, stream, ++flag);
334344
break;
335345
default:
336346
throw Error("Invalid packet type", ErrorCode::ExecutorError);

src/include/execution_kernel.hpp

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,14 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
277277
}
278278

279279
template <typename PacketType>
280-
MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmChannel>* smChannels,
280+
MSCCLPP_DEVICE_INLINE void handlePutPacket(DeviceHandle<SmChannel>* smChannels,
281281
DeviceHandle<SimpleProxyChannel>* proxyChannels, uint8_t* dstChannelIndexes,
282282
uint32_t* dstOffsets, uint32_t* srcOffsets, int nDstChannels, uint32_t size,
283283
ChannelType chType, uint32_t flag) {
284-
const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1;
285284
if (chType == ChannelType::SM) {
286285
for (int index = 0; index < nDstChannels; ++index) {
287-
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(
288-
scratchBaseOffset + dstOffsets[index] * 2, srcOffsets[index], size, threadIdx.x, blockDim.x, flag);
286+
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(dstOffsets[index] * 2, srcOffsets[index], size,
287+
threadIdx.x, blockDim.x, flag);
289288
}
290289
}
291290
if (chType == ChannelType::PROXY) {
@@ -294,8 +293,8 @@ MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmCh
294293
return;
295294
}
296295
// For proxy channel, we assume src and dst are in packet format
297-
uint32_t dstOffset = (dstOffsets[tid] << 1) + scratchBaseOffset;
298-
uint32_t srcOffset = (srcOffsets[tid] << 1) + scratchBaseOffset;
296+
uint32_t dstOffset = dstOffsets[tid] << 1;
297+
uint32_t srcOffset = srcOffsets[tid] << 1;
299298
proxyChannels[dstChannelIndexes[tid]].put(dstOffset, srcOffset, size << 1);
300299
}
301300
}
@@ -307,15 +306,14 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
307306
uint32_t* outputOffsets, int nDstChannels, size_t size,
308307
uint32_t flag) {
309308
size_t nPackets = size * 2 / sizeof(PacketType);
310-
const size_t intputBaseOffset = flag & 0x1 ? 0 : inputBuffSize >> 1;
311309
const uint32_t srcOffset = srcOffsetByBytes / sizeof(PacketPayload<PacketType>);
312310
const uint32_t dstOffset = dstOffsetByBytes / sizeof(PacketPayload<PacketType>);
313311
PacketPayload<PacketType>* srcPacketPayload = (PacketPayload<PacketType>*)src + srcOffset;
314312
PacketPayload<PacketType>* dstPacketPayload = (PacketPayload<PacketType>*)dst + dstOffset;
315313
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
316314
PacketPayload<PacketType> data = {};
317315
for (int index = 0; index < nSrcs; ++index) {
318-
PacketType* pkt = (PacketType*)((char*)inputBuff + intputBaseOffset + 2 * inputOffsets[index]);
316+
PacketType* pkt = (PacketType*)((char*)inputBuff + 2 * inputOffsets[index]);
319317
PacketPayload<PacketType> val = pkt[idx].read(flag);
320318
data = add_vectors<T>(data, val);
321319
}
@@ -325,7 +323,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
325323
if (SendToRemote) {
326324
PacketType pkt(data, flag);
327325
for (int index = 0; index < nDstChannels; ++index) {
328-
size_t offset = (intputBaseOffset + outputOffsets[index] * 2) / sizeof(PacketType);
326+
size_t offset = outputOffsets[index] * 2 / sizeof(PacketType);
329327
smChannels[outputChannelIndexes[index]].write(offset + idx, pkt);
330328
}
331329
}
@@ -335,8 +333,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
335333
template <typename PacketType>
336334
MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset,
337335
uint32_t srcOffset, size_t size, uint32_t flag) {
338-
const size_t inputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1;
339-
PacketType* srcPackets = (PacketType*)((char*)src + inputScratchBaseOffset + 2 * srcOffset);
336+
PacketType* srcPackets = (PacketType*)((char*)src + 2 * srcOffset);
340337
PacketPayload<PacketType>* result = (PacketPayload<PacketType>*)((char*)dst + dstOffset);
341338
size_t nPackets = size * 2 / sizeof(PacketType);
342339
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
@@ -348,8 +345,7 @@ MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize
348345
template <typename PacketType>
349346
MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t dstSize, uint32_t dstOffset,
350347
uint32_t srcOffset, size_t size, uint32_t flag) {
351-
const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : dstSize >> 1;
352-
dstOffset = dstOffset * 2 + outputScratchBaseOffset;
348+
dstOffset = dstOffset * 2;
353349
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
354350
}
355351

@@ -403,7 +399,7 @@ MSCCLPP_DEVICE_INLINE void handleCopy(void* dst, void* src, uint32_t dstOffset,
403399

404400
template <typename T, typename PacketType = LL16Packet>
405401
__global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch,
406-
size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag
402+
DeviceExecutionPlan* plan, uint32_t flag
407403
#if defined(ENABLE_NPKIT)
408404
,
409405
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
@@ -501,28 +497,28 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
501497
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
502498
op.size, false);
503499
} else if (op.type == OperationType::PUT_PACKET) {
504-
handlePutPacket<PacketType>(scratchSize, smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets,
505-
op.inputOffsets, op.nOutputs, op.size, op.channelType, flag);
500+
handlePutPacket<PacketType>(smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets,
501+
op.nOutputs, op.size, op.channelType, flag);
506502
} else if (op.type == OperationType::REDUCE_SEND_PACKET) {
507503
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
508504
T* src = getBuffer(input, output, scratch, op.srcBufferType);
509-
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets,
510-
op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
511-
op.nOutputs, op.size, flag);
505+
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, op.inputOffsets, op.nInputs,
506+
smChannels, op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size,
507+
flag);
512508
} else if (op.type == OperationType::REDUCE_PACKET) {
513509
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
514510
T* src = getBuffer(input, output, scratch, op.srcBufferType);
515-
handleReduceSendPacket<T, PacketType, false>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize,
516-
op.inputOffsets, op.nInputs, smChannels, op.outputChannelIndexes,
517-
op.outputOffsets, op.nOutputs, op.size, flag);
511+
handleReduceSendPacket<T, PacketType, false>(dst, op.dstOffset, src, op.srcOffset, scratch, op.inputOffsets,
512+
op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
513+
op.nOutputs, op.size, flag);
518514
} else if (op.type == OperationType::COPY_PACKET) {
519515
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
520516
T* src = getBuffer(input, output, scratch, op.srcBufferType);
521-
handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
517+
handleCopyPacket<PacketType>(dst, src, op.dstOffset, op.srcOffset, op.size, flag);
522518
} else if (op.type == OperationType::TRANSFORM_TO_PACKET) {
523519
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
524520
T* src = getBuffer(input, output, scratch, op.srcBufferType);
525-
handleTransformToPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
521+
handleTransformToPacket<PacketType>(dst, src, op.dstOffset, op.srcOffset, op.size, flag);
526522
} else if (op.type == OperationType::REDUCE_SEND) {
527523
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
528524
T* src = getBuffer(input, output, scratch, op.srcBufferType);
@@ -548,12 +544,12 @@ class ExecutionKernel {
548544
#if defined(MSCCLPP_DEVICE_HIP)
549545
template <typename PacketType>
550546
static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
551-
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
552-
cudaStream_t stream, uint32_t flag = 0) {
547+
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream,
548+
uint32_t flag = 0) {
553549
switch (dataType) {
554550
case DataType::INT32:
555551
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
556-
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
552+
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
557553
#if defined(ENABLE_NPKIT)
558554
,
559555
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -563,7 +559,7 @@ class ExecutionKernel {
563559
break;
564560
case DataType::UINT32:
565561
executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
566-
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
562+
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
567563
#if defined(ENABLE_NPKIT)
568564
,
569565
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -573,7 +569,7 @@ class ExecutionKernel {
573569
break;
574570
case DataType::FLOAT16:
575571
executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
576-
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
572+
rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
577573
#if defined(ENABLE_NPKIT)
578574
,
579575
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -583,7 +579,7 @@ class ExecutionKernel {
583579
break;
584580
case DataType::FLOAT32:
585581
executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
586-
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
582+
rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
587583
#if defined(ENABLE_NPKIT)
588584
,
589585
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -593,7 +589,7 @@ class ExecutionKernel {
593589
break;
594590
case DataType::BFLOAT16:
595591
executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
596-
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
592+
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
597593
#if defined(ENABLE_NPKIT)
598594
,
599595
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -606,8 +602,8 @@ class ExecutionKernel {
606602
#else // !defined(MSCCLPP_DEVICE_HIP)
607603
template <typename PacketType>
608604
static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
609-
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
610-
cudaStream_t stream, uint32_t flag = 0);
605+
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream,
606+
uint32_t flag = 0);
611607
#endif // !defined(MSCCLPP_DEVICE_HIP)
612608
};
613609
} // namespace mscclpp

src/include/execution_plan.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ struct ExecutionPlan::Impl {
6969
std::vector<Operation> getOperations(int rank, int threadblock) const;
7070
int getThreadblockCount(int rank) const;
7171
int getNThreadsPerBlock() const;
72+
bool getIsUsingDoubleScratchBuffer() const;
7273

7374
void loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
7475
void lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
@@ -96,6 +97,7 @@ struct ExecutionPlan::Impl {
9697
size_t inputSize;
9798
size_t outputSize;
9899
int nThreadsPerBlock;
100+
bool isUsingDoubleScratchBuffer;
99101

100102
private:
101103
std::pair<size_t, u_int32_t> calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const;

0 commit comments

Comments
 (0)