@@ -148,12 +148,19 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
148148 }
149149 return std::vector<BufferType>(bufferTypes.begin (), bufferTypes.end ());
150150}
// Returns the scratch buffer size for `rank`.
//
// Diff view: lines marked `-` are the previous implementation (derived from
// inputSize only); lines marked `+` are the new implementation, which also
// accepts outputSize.
//
// New behavior:
//   - If the rank's input chunk count is non-zero, the per-chunk size is
//     inputSize / inputChunks[rank].
//   - Otherwise, if the rank's output chunk count is non-zero, it is
//     outputSize / outputChunks[rank].
//   - If both counts are zero, an mscclpp::Error with
//     ErrorCode::ExecutorError is thrown.
//   The quotient is stored in a size_t, so any remainder is discarded.
//
// The result is that per-chunk size multiplied by scratchChunks[rank], and
// additionally by 2 * 2 when isUsingPacket is set (see the inline comments).
//
// NOTE(review): when both chunk counts are non-zero this always uses the
// input-derived size; unlike calcSizePerRank, it does not check that the
// input- and output-derived sizes agree. Confirm this is intended.
151- size_t ExecutionPlan::Impl::getScratchBufferSize (int rank, size_t inputSize) const {
151+ size_t ExecutionPlan::Impl::getScratchBufferSize (int rank, size_t inputSize, size_t outputSize) const {
152+ size_t sizePerRank;
153+ if (this ->inputChunks .at (rank) != 0 )
154+ sizePerRank = inputSize / this ->inputChunks .at (rank);
155+ else if (this ->outputChunks .at (rank) != 0 )
156+ sizePerRank = outputSize / this ->outputChunks .at (rank);
157+ else
158+ throw mscclpp::Error (" Output or Input chunks must be greater than 0" , mscclpp::ErrorCode::ExecutorError);
159+
// Packet mode needs four times the plain size: one factor of 2 for data plus
// flag, another factor of 2 for double buffering.
152160 if (this ->isUsingPacket ) {
153- return inputSize / this ->inputChunks .at (rank) * this ->scratchChunks .at (rank) * 2 /* data + flag*/ *
154- 2 /* double buffer*/ ;
161+ return sizePerRank * this ->scratchChunks .at (rank) * 2 /* data + flag*/ * 2 /* double buffer*/ ;
155162 }
156- return inputSize / this -> inputChunks . at (rank) * this ->scratchChunks .at (rank);
163+ return sizePerRank * this ->scratchChunks .at (rank);
157164}
158165std::vector<Operation> ExecutionPlan::Impl::getOperations (int rank, int threadblock) const {
159166 return this ->operations .at (rank)[threadblock];
@@ -163,7 +170,8 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper
163170
// Accessor: returns the stored nThreadsPerBlock member unchanged.
164171int ExecutionPlan::Impl::getNThreadsPerBlock () const { return this ->nThreadsPerBlock ; }
165172
// Loads the execution plan from the JSON file at planPath.
//
// Diff view: the `-` line is the previous signature; the `+` lines add an
// outputSize parameter. Part of the body is elided by the diff hunk boundary
// below, so only the visible statements are described here.
//
// Visible steps:
//   1. Open planPath and parse its contents as JSON.
//   2. Compare this->name with the plan's name field (the handling of a
//      mismatch is in the elided lines).
//   3. Call setupChannels(gpus); `gpus` is defined in the elided lines.
//   4. Store inputSize and outputSize on the object.
//   5. Call setupOperations(gpus, contsSrcOffset, constDstOffset).
//
// NOTE(review): `contsSrcOffset` looks like a typo for `constSrcOffset`; it
// is spelled the same way at every visible use, so it is left as-is.
166- void ExecutionPlan::Impl::loadExecutionPlan (size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
173+ void ExecutionPlan::Impl::loadExecutionPlan (size_t inputSize, size_t outputSize, size_t contsSrcOffset,
174+ size_t constDstOffset) {
167175 std::ifstream file (this ->planPath );
168176 json obj = json::parse (file);
169177 if (this ->name != obj[" name" ]) {
@@ -186,10 +194,12 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOff
186194 this ->setupChannels (gpus);
187195
// Both sizes are stored before setupOperations runs; setupOperations reads
// this->inputSize and this->outputSize when computing offsets.
188196 this ->inputSize = inputSize;
197+ this ->outputSize = outputSize;
189198 this ->setupOperations (gpus, contsSrcOffset, constDstOffset);
190199}
191200
// Reloads the execution plan from the JSON file at planPath for new sizes.
//
// Diff view: the `-` line is the previous signature; the `+` lines add an
// outputSize parameter. Part of the body is elided by the diff hunk boundary
// below, so only the visible statements are described here.
//
// Visible steps:
//   1. Open planPath and parse its contents as JSON.
//   2. Compare this->name with the plan's name field (the handling of a
//      mismatch is in the elided lines).
//   3. Store inputSize and outputSize on the object.
//   4. Call setupOperations(gpus, contsSrcOffset, constDstOffset); `gpus` is
//      defined in the elided lines.
//
// NOTE(review): no setupChannels call is visible here, unlike in
// loadExecutionPlan; presumably that is what makes this the "light" variant.
// Confirm against the elided lines.
192- void ExecutionPlan::Impl::lightLoadExecutionPlan (size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
201+ void ExecutionPlan::Impl::lightLoadExecutionPlan (size_t inputSize, size_t outputSize, size_t contsSrcOffset,
202+ size_t constDstOffset) {
193203 std::ifstream file (this ->planPath );
194204 json obj = json::parse (file);
195205 if (this ->name != obj[" name" ]) {
@@ -210,6 +220,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsS
210220 }
211221
212222 this ->inputSize = inputSize;
223+ this ->outputSize = outputSize;
213224 this ->setupOperations (gpus, contsSrcOffset, constDstOffset);
214225}
215226
@@ -313,8 +324,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
313324 // Get the relevant channel index in rank channelInfos
314325 operation.inputChannelIndexes [i] =
315326 channelIndexes[{srcBufferType, dstBufferType, operation.channelType }][op[" i_cids" ][i][" id" ]];
316- operation.inputOffsets [i] = this ->getOffset (rank, this ->inputSize , (uint32_t )op[" i_cids" ][i][" off" ]) +
317- (srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0 );
327+ operation.inputOffsets [i] =
328+ this ->getOffset (rank, this ->inputSize , this ->outputSize , (uint32_t )op[" i_cids" ][i][" off" ]) +
329+ (srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0 );
318330 chunkIndexes.push_back ((uint32_t )op[" i_cids" ][i][" off" ]);
319331 }
320332 }
@@ -323,8 +335,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
323335 operation.nInputs = op[" srcs" ].size ();
324336 operation.inputBufferType = convertToBufferType (op[" srcs" ][0 ][" buff" ]);
325337 for (int i = 0 ; i < operation.nInputs ; i++) {
326- operation.inputOffsets [i] = this ->getOffset (rank, this ->inputSize , (uint32_t )op[" srcs" ][i][" off" ]) +
327- (operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0 );
338+ operation.inputOffsets [i] =
339+ this ->getOffset (rank, this ->inputSize , this ->outputSize , (uint32_t )op[" srcs" ][i][" off" ]) +
340+ (operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0 );
328341 chunkIndexes.push_back ((uint32_t )op[" srcs" ][i][" off" ]);
329342 }
330343 }
@@ -335,8 +348,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
335348 BufferType dstBufferType = convertToBufferType (op[" o_buff" ][" dst" ]);
336349 operation.outputChannelIndexes [i] =
337350 channelIndexes[{srcBufferType, dstBufferType, operation.channelType }][op[" o_cids" ][i][" id" ]];
338- operation.outputOffsets [i] = this ->getOffset (rank, this ->inputSize , (uint32_t )op[" o_cids" ][i][" off" ]) +
339- (dstBufferType != BufferType::SCRATCH ? constDstOffset : 0 );
351+ operation.outputOffsets [i] =
352+ this ->getOffset (rank, this ->inputSize , this ->outputSize , (uint32_t )op[" o_cids" ][i][" off" ]) +
353+ (dstBufferType != BufferType::SCRATCH ? constDstOffset : 0 );
340354 chunkIndexes.push_back ((uint32_t )op[" o_cids" ][i][" off" ]);
341355 }
342356 }
@@ -345,27 +359,29 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
345359 operation.nOutputs = op[" dsts" ].size ();
346360 operation.outputBufferType = convertToBufferType (op[" dsts" ][0 ][" buff" ]);
347361 for (int i = 0 ; i < operation.nOutputs ; i++) {
348- operation.outputOffsets [i] = this ->getOffset (rank, this ->inputSize , (uint32_t )op[" dsts" ][i][" off" ]) +
349- (operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0 );
362+ operation.outputOffsets [i] =
363+ this ->getOffset (rank, this ->inputSize , this ->outputSize , (uint32_t )op[" dsts" ][i][" off" ]) +
364+ (operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0 );
350365 chunkIndexes.push_back ((uint32_t )op[" dsts" ][i][" off" ]);
351366 }
352367 }
353368 if (op.contains (" srcbuff" )) {
354369 operation.srcBufferType = convertToBufferType (op[" srcbuff" ]);
355370 }
356371 if (op.contains (" srcoff" )) {
357- operation.srcOffset = this ->getOffset (rank, this ->inputSize , (uint32_t )op[" srcoff" ]);
372+ operation.srcOffset = this ->getOffset (rank, this ->inputSize , this -> outputSize , (uint32_t )op[" srcoff" ]);
358373 chunkIndexes.push_back ((uint32_t )op[" srcoff" ]);
359374 }
360375 if (op.contains (" dstbuff" )) {
361376 operation.dstBufferType = convertToBufferType (op[" dstbuff" ]);
362377 }
363378 if (op.contains (" dstoff" )) {
364- operation.dstOffset = this ->getOffset (rank, this ->inputSize , (uint32_t )op[" dstoff" ]);
379+ operation.dstOffset = this ->getOffset (rank, this ->inputSize , this -> outputSize , (uint32_t )op[" dstoff" ]);
365380 chunkIndexes.push_back ((uint32_t )op[" dstoff" ]);
366381 }
367382 if (op.contains (" cnt" )) {
368- operation.size = this ->getNChunkSize (rank, this ->inputSize , (uint32_t )op[" cnt" ], chunkIndexes);
383+ operation.size =
384+ this ->getNChunkSize (rank, this ->inputSize , this ->outputSize , (uint32_t )op[" cnt" ], chunkIndexes);
369385 }
370386 ops.push_back (operation);
371387 }
@@ -374,14 +390,33 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
374390 }
375391}
376392
377- size_t ExecutionPlan::Impl::getOffset (int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment) const {
// Selects which buffer size and chunk count describe the chunk layout of
// `rank`.
//
// Returns a pair (buffer size, chunk count). Despite the name, the size is
// NOT divided by the chunk count here; getOffset reads `.first` as the size
// and `.second` as the number of chunks.
//
// Cases:
//   - input and output chunk counts are both 0: throws mscclpp::Error with
//     ErrorCode::ExecutorError.
//   - both are non-zero: inputSize / inputChunks must equal
//     outputSize / outputChunks, otherwise mscclpp::Error
//     (ErrorCode::ExecutorError) is thrown; on success returns
//     (inputSize, inputChunks).
//   - only the input count is non-zero: returns (inputSize, inputChunks).
//   - only the output count is non-zero: returns (outputSize, outputChunks).
//
// NOTE(review): `u_int32_t` is a non-standard typedef; the rest of the file
// uses `uint32_t`. Consider aligning this definition and its declaration.
// NOTE(review): the consistency check compares quotients; if these are
// integer divisions (the chunk counts appear to be integral), sizes that
// differ only in the discarded remainder still pass.
393+ std::pair<size_t , u_int32_t > ExecutionPlan::Impl::calcSizePerRank (int rank, size_t inputSize, size_t outputSize) const {
394+ std::pair<size_t , u_int32_t > sizePerRank;
395+ if (this ->inputChunks .at (rank) == 0 && this ->outputChunks .at (rank) == 0 ) {
396+ throw mscclpp::Error (" Output or Input chunks must be greater than 0" , mscclpp::ErrorCode::ExecutorError);
397+ } else if (this ->inputChunks .at (rank) != 0 && this ->outputChunks .at (rank) != 0 ) {
398+ if (inputSize / this ->inputChunks .at (rank) != outputSize / this ->outputChunks .at (rank))
399+ throw mscclpp::Error (" Size per chunks inconsistent" , mscclpp::ErrorCode::ExecutorError);
400+ else
401+ sizePerRank = std::make_pair (inputSize, this ->inputChunks .at (rank));
402+ } else if (this ->inputChunks .at (rank) != 0 ) {
403+ sizePerRank = std::make_pair (inputSize, this ->inputChunks .at (rank));
// This condition always holds when reached: the earlier branches have already
// handled every case in which the output chunk count is zero.
404+ } else if (this ->outputChunks .at (rank) != 0 ) {
405+ sizePerRank = std::make_pair (outputSize, this ->outputChunks .at (rank));
406+ }
407+ return sizePerRank;
408+ }
409+
// Converts a chunk index into an offset for `rank`.
//
// Diff view: the new signature adds outputSize. The middle of the body is
// elided by the diff hunk boundary below, so the computation of `offset` is
// not visible here.
//
// Visible behavior:
//   - Throws Error (ErrorCode::ExecutorError) if inputSize is not a multiple
//     of alignment.
//   - Takes the (size, chunk count) pair from calcSizePerRank and derives
//     nelems = size / (alignment * sizeof(uint8_t)).
//   - Throws Error (ErrorCode::ExecutorError) if nelems is not a multiple of
//     chunkGroups[rank].
//   - Returns offset * alignment as a size_t.
//
// `alignment` presumably has a default in the declaration, because
// getNChunkSize calls this function with four arguments.
//
// NOTE(review): the alignment check tests inputSize, but nelems is derived
// from the size chosen by calcSizePerRank, which is outputSize when the rank
// has no input chunks. In that case outputSize's alignment is never checked.
// NOTE(review): `nInputChunks` can hold the output chunk count in the same
// case; the name predates this change.
410+ size_t ExecutionPlan::Impl::getOffset (int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex,
411+ uint32_t alignment) const {
378412 if (inputSize % alignment != 0 ) {
379413 throw Error (" inputSize must be a multiple of alignment" , ErrorCode::ExecutorError);
380414 }
381415
382416 const int nGroups = this ->chunkGroups .at (rank);
383- uint32_t nInputChunks = this ->inputChunks .at (rank);
384- uint32_t nelems = inputSize / (alignment * sizeof (uint8_t ));
417+ auto sizePerRank = calcSizePerRank (rank, inputSize, outputSize);
418+ uint32_t nInputChunks = sizePerRank.second ;
419+ uint32_t nelems = sizePerRank.first / (alignment * sizeof (uint8_t ));
385420 if (nelems % nGroups != 0 ) {
386421 throw Error (" Input size must be a multiple of nGroups" , ErrorCode::ExecutorError);
387422 }
@@ -397,12 +432,12 @@ size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunk
397432 return static_cast <size_t >(offset) * alignment;
398433}
399434
400- size_t ExecutionPlan::Impl::getNChunkSize (int rank, size_t inputSize, uint32_t nChunks,
435+ size_t ExecutionPlan::Impl::getNChunkSize (int rank, size_t inputSize, size_t outputSize, uint32_t nChunks,
401436 const std::vector<uint32_t > chunkIndexes) const {
402437 size_t nChunkSize = 0 ;
403438 for (uint32_t index : chunkIndexes) {
404- uint32_t beginOff = getOffset (rank, inputSize, index);
405- uint32_t endOff = getOffset (rank, inputSize, index + nChunks);
439+ uint32_t beginOff = getOffset (rank, inputSize, outputSize, index);
440+ uint32_t endOff = getOffset (rank, inputSize, outputSize, index + nChunks);
406441 if (nChunkSize == 0 ) {
407442 nChunkSize = endOff - beginOff;
408443 } else if (nChunkSize != endOff - beginOff) {
0 commit comments