@@ -66,7 +66,6 @@ struct ExecutionContext {
6666 size_t scratchBufferSize;
6767 std::shared_ptr<char > deviceExecutionPlansBuffer;
6868 int nthreadsPerBlock;
69- bool isUsingDoubleScratchBuffer;
7069};
7170
7271struct Executor ::Impl {
@@ -83,11 +82,13 @@ struct Executor::Impl {
8382
8483 ExecutionContext setupExecutionContext (int rank, void * sendbuff, void * recvbuff, size_t inputMessageSize,
8584 size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset,
86- size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) {
85+ size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan,
86+ int flag) {
8787 ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_ ->name };
8888 if (this ->contexts .find (key) != this ->contexts .end ()) {
8989 plan.impl_ ->operationsReset ();
90- plan.impl_ ->lightLoadExecutionPlan (inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
90+ plan.impl_ ->lightLoadExecutionPlan (inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset, rank,
91+ sendBufferSize, recvBufferSize, flag);
9192 this ->setupDeviceExecutionPlan (this ->contexts [key], rank, plan);
9293 this ->contexts [key].deviceExecutionPlansBuffer =
9394 allocExtSharedCuda<char >(this ->contexts [key].deviceExecutionPlans .size () * sizeof (DeviceExecutionPlan));
@@ -98,16 +99,16 @@ struct Executor::Impl {
9899 }
99100
100101 plan.impl_ ->reset ();
101- plan.impl_ ->loadExecutionPlan (inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
102+ plan.impl_ ->loadExecutionPlan (inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset, rank,
103+ sendBufferSize, recvBufferSize, flag);
102104
103105 ExecutionContext context;
104- size_t scratchBufferSize = plan.impl_ ->getScratchBufferSize (rank, sendBufferSize, recvBufferSize );
106+ size_t scratchBufferSize = plan.impl_ ->getScratchBufferSize ();
105107 std::shared_ptr<char > scratchBuffer = allocExtSharedCuda<char >(scratchBufferSize);
106108 context.scratchBuffer = scratchBuffer;
107109 context.scratchBufferSize = scratchBufferSize;
108110 context.proxyService = std::make_shared<ProxyService>();
109111 context.nthreadsPerBlock = plan.impl_ ->getNThreadsPerBlock ();
110- context.isUsingDoubleScratchBuffer = plan.impl_ ->getIsUsingDoubleScratchBuffer ();
111112 this ->setupConnections (context, rank, plan);
112113 this ->setupRegisteredMemories (context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
113114 this ->setupChannels (context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
@@ -305,13 +306,8 @@ struct Executor::Impl {
305306 }
306307
307308 void launchKernel (ExecutionContext& context, int rank, void * sendbuff, void * recvbuff, DataType dataType,
308- cudaStream_t stream, PacketType packetType) {
309- static uint32_t flag = 0 ;
309+ cudaStream_t stream, PacketType packetType, uint32_t flag) {
310310 int nthreadblocks = context.deviceExecutionPlans .size ();
311- char * kernelScratchBufferPtr = context.scratchBuffer .get ();
312- if (context.isUsingDoubleScratchBuffer && (flag % 2 )) {
313- kernelScratchBufferPtr += context.scratchBufferSize / 2 ;
314- }
315311#if defined(ENABLE_NPKIT)
316312#if defined(__HIP_PLATFORM_AMD__)
317313 if (nthreadblocks > NPKIT_MAX_NUM_GPU_THREADBLOCKS) {
@@ -327,16 +323,14 @@ struct Executor::Impl {
327323#endif
328324 switch (packetType) {
329325 case PacketType::LL16:
330- ExecutionKernel::launchKernel<LL16Packet>(rank, nthreadblocks, context.nthreadsPerBlock , sendbuff, recvbuff,
331- (void *)kernelScratchBufferPtr, dataType,
332- (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer .get (),
333- sharedMemSize, stream, ++flag);
326+ ExecutionKernel::launchKernel<LL16Packet>(
327+ rank, nthreadblocks, context.nthreadsPerBlock , sendbuff, recvbuff, (void *)context.scratchBuffer .get (),
328+ dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer .get (), sharedMemSize, stream, flag);
334329 break ;
335330 case PacketType::LL8:
336- ExecutionKernel::launchKernel<LL8Packet>(rank, nthreadblocks, context.nthreadsPerBlock , sendbuff, recvbuff,
337- (void *)kernelScratchBufferPtr, dataType,
338- (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer .get (),
339- sharedMemSize, stream, ++flag);
331+ ExecutionKernel::launchKernel<LL8Packet>(
332+ rank, nthreadblocks, context.nthreadsPerBlock , sendbuff, recvbuff, (void *)context.scratchBuffer .get (),
333+ dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer .get (), sharedMemSize, stream, flag);
340334 break ;
341335 default :
342336 throw Error (" Invalid packet type" , ErrorCode::ExecutorError);
@@ -349,17 +343,18 @@ Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<
349343void Executor::execute (int rank, void * sendbuff, void * recvbuff, size_t sendBuffSize,
350344 [[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
351345 cudaStream_t stream, PacketType packetType) {
346+ static uint32_t flag = 1 ;
352347 size_t sendBytes, recvBytes;
353348 CUdeviceptr sendBasePtr, recvBasePtr;
354349 MSCCLPP_CUTHROW (cuMemGetAddressRange (&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
355350 MSCCLPP_CUTHROW (cuMemGetAddressRange (&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
356351 size_t offsetIn = (char *)sendbuff - (char *)sendBasePtr;
357352 size_t offsetOut = (char *)recvbuff - (char *)recvBasePtr;
358-
359353 ExecutionContext context =
360354 this ->impl_ ->setupExecutionContext (rank, (void *)sendBasePtr, (void *)recvBasePtr, sendBuffSize, recvBuffSize,
361- offsetIn, offsetOut, sendBytes, recvBytes, plan);
362- this ->impl_ ->launchKernel (context, rank, sendbuff, recvbuff, dataType, stream, packetType);
355+ offsetIn, offsetOut, sendBytes, recvBytes, plan, flag);
356+ this ->impl_ ->launchKernel (context, rank, sendbuff, recvbuff, dataType, stream, packetType, flag);
357+ flag++;
363358}
364359
365360Executor::~Executor () = default ;
0 commit comments