diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index 895818228..72a2be9d9 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -17,14 +17,14 @@ __global__ void __launch_bounds__(1024, 1) const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; const size_t lid = tid % WARP_SIZE; const size_t wid = tid / WARP_SIZE; + const size_t nPeer = nRanksPerNode - 1; - // Round down to multiple of warp size - const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE * WARP_SIZE; + // Round down to multiple of peer count. + const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE / nPeer * nPeer * WARP_SIZE; if (tid >= nThread) { return; } const size_t nWarp = nThread / WARP_SIZE; - const size_t nPeer = nRanksPerNode - 1; const size_t chanOffset = nPeer * blockIdx.x; auto memChans = memoryChannels + chanOffset;