@@ -8,12 +8,12 @@ namespace mscclpp {
88
// ExecutionKernel::launchKernel<PacketType>
// Switches on `dataType` and launches the `executionKernel` template instantiated for the
// matching element type (int32_t, uint32_t, half, float, __bfloat16). `src`, `dst` and
// `scratch` are cast to that element type and passed together with `rank`, `plan` and `flag`;
// every launch uses the configuration <<<nthreadblocks, nthreads, sharedMemSize, stream>>>.
// When ENABLE_NPKIT is defined, NpKit::GetGpuEventCollectContexts() and
// NpKit::GetCpuTimestamp() are appended as extra kernel arguments.
//
// NOTE(review): this text is a diff view. Rows marked '-' are the previous version (signature
// and kernel arguments that included `scratchSize`); rows marked '+' are their replacements
// without it. Rows between hunks ('@@' headers) are not visible here.
// NOTE(review): only the INT32 case names PacketType in the executionKernel template
// arguments; the other cases pass the element type alone. Presumably they fall back to a
// default template argument, so the PacketType of this function would not reach those
// launches -- TODO confirm against the executionKernel declaration.
99template <typename PacketType>
1010void ExecutionKernel::launchKernel (int rank, int nthreadblocks, int nthreads, void * src, void * dst, void * scratch,
11- size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan,
12- size_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
11+ DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize ,
12+ cudaStream_t stream, uint32_t flag) {
// One case per supported DataType; each case differs only in the element type used for the
// template argument and the pointer casts.
1313 switch (dataType) {
1414 case DataType::INT32:
1515 executionKernel<int32_t , PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>> (
16- rank, (int32_t *)src, (int32_t *)dst, (int32_t *)scratch, scratchSize, plan, flag
16+ rank, (int32_t *)src, (int32_t *)dst, (int32_t *)scratch, plan, flag
1717#if defined(ENABLE_NPKIT)
1818 ,
1919 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -23,7 +23,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
2323 break ;
2424 case DataType::UINT32:
2525 executionKernel<uint32_t ><<<nthreadblocks, nthreads, sharedMemSize, stream>>> (
26- rank, (uint32_t *)src, (uint32_t *)dst, (uint32_t *)scratch, scratchSize, plan, flag
26+ rank, (uint32_t *)src, (uint32_t *)dst, (uint32_t *)scratch, plan, flag
2727#if defined(ENABLE_NPKIT)
2828 ,
2929 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -33,7 +33,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
3333 break ;
3434 case DataType::FLOAT16:
3535 executionKernel<half><<<nthreadblocks, nthreads, sharedMemSize, stream>>> (
36- rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
36+ rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
3737#if defined(ENABLE_NPKIT)
3838 ,
3939 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -43,7 +43,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
4343 break ;
4444 case DataType::FLOAT32:
4545 executionKernel<float ><<<nthreadblocks, nthreads, sharedMemSize, stream>>> (
46- rank, (float *)src, (float *)dst, (float *)scratch, scratchSize, plan, flag
46+ rank, (float *)src, (float *)dst, (float *)scratch, plan, flag
4747#if defined(ENABLE_NPKIT)
4848 ,
4949 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -53,7 +53,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
5353 break ;
5454 case DataType::BFLOAT16:
5555 executionKernel<__bfloat16><<<nthreadblocks, nthreads, sharedMemSize, stream>>> (
56- rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
56+ rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
5757#if defined(ENABLE_NPKIT)
5858 ,
5959 NpKit::GetGpuEventCollectContexts (), NpKit::GetCpuTimestamp ());
@@ -65,12 +65,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
6565}
6666
// Explicit instantiations of ExecutionKernel::launchKernel for the two packet types named
// here, LL16Packet and LL8Packet.
// NOTE(review): this text is a diff view. Rows marked '-' show the previous parameter list
// (with `size_t scratchSize` after `scratch`); rows marked '+' show the replacement list
// without it. The parameter order in the '+' rows matches the '+' rows of the launchKernel
// definition's signature.
6767template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void * src, void * dst,
68- void * scratch, size_t scratchSize, DataType dataType,
69- DeviceExecutionPlan* plan, size_t sharedMemSize,
70- cudaStream_t stream, uint32_t flag);
68+ void * scratch, DataType dataType, DeviceExecutionPlan* plan,
69+ size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
7170template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void * src, void * dst,
72- void * scratch, size_t scratchSize, DataType dataType,
73- DeviceExecutionPlan* plan, size_t sharedMemSize,
74- cudaStream_t stream, uint32_t flag);
71+ void * scratch, DataType dataType, DeviceExecutionPlan* plan,
72+ size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
7573} // namespace mscclpp
7674#endif
0 commit comments