我建议以下方法:
- 将二维数组切割成“几乎相等”大小的块,即本地列数接近
num_columns/mpi_size。
- 使用
mpi_gatherv 收集块,它使用不同大小的块。
要获得“几乎相等”的列数,请将本地列数设置为 num_columns / mpi_size 的整数值,并且仅对第一个 mod(num_columns,mpi_size) mpi 任务递增一。
下表演示了 (10,12) 矩阵在 5 个 MPI 进程上的划分:
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
01 02 03 11 12 13 21 22 31 32 41 42
这里第一个数字是进程的一个id,第二个数字是本地列数。
如您所见,进程 0 和 1 各有 3 列,而所有其他进程各只有 2 列。
您可以在下面找到我编写的工作示例代码。
最棘手的部分是为 MPI_Gatherv 生成 rcounts 和 displs 数组。讨论的表格是代码的输出。
program mpi2d
implicit none
include 'mpif.h'
integer myid, nprocs, ierr
integer,parameter:: m = 10 ! global number of rows
integer,parameter:: n = 12 ! global number of columns
integer nloc ! local number of columns
integer array(m,n) ! global m-by-n, i.e. m rows and n columns
integer,allocatable:: loc(:,:) ! local piece of global 2d array
integer,allocatable:: rcounts(:) ! array of nloc's (for mpi_gatrherv)
integer,allocatable:: displs(:) ! array of displacements (for mpi_gatherv)
integer i,j
! Initialize
call mpi_init(ierr)
call mpi_comm_rank(MPI_COMM_WORLD, myid, ierr)
call mpi_comm_size(MPI_COMM_WORLD, nprocs, ierr)
! Partition, i.e. get local number of columns
nloc = n / nprocs
if (mod(n,nprocs)>myid) nloc = nloc + 1
! Compute partitioned array
allocate(loc(m,nloc))
do j=1,nloc
loc(:,j) = myid*10 + j
enddo
! Build arrays for mpi_gatherv:
! rcounts containes all nloc's
! displs containes displacements of partitions in terms of columns
allocate(rcounts(nprocs),displs(nprocs))
displs(1) = 0
do j=1,nprocs
rcounts(j) = n / nprocs
if(mod(n,nprocs).gt.(j-1)) rcounts(j)=rcounts(j)+1
if((j-1).ne.0)displs(j) = displs(j-1) + rcounts(j-1)
enddo
! Convert from number of columns to number of integers
nloc = m * nloc
rcounts = m * rcounts
displs = m * displs
! Gather array on root
call mpi_gatherv(loc,nloc,MPI_INT,array,
& rcounts,displs,MPI_INT,0,MPI_COMM_WORLD,ierr)
! Print array on root
if(myid==0)then
do i=1,m
do j=1,n
write(*,'(I04.2)',advance='no') array(i,j)
enddo
write(*,*)
enddo
endif
! Finish
call mpi_finalize(ierr)
end