Using MPI in Python
Environment Setup
# For Ubuntu
# Python 3 with its headers, plus the Open MPI runtime and development files
# (the MPI compiler wrapper and headers are needed to build mpi4py).
apt-get install \
    python3 python3-dev \
    libopenmpi-dev openmpi-bin mpi-default-bin
# Python packages used by the example programs below.
pip3 install mpi4py numpy matplotlib
# For CentOS
yum-config-manager --disable source
yum install python python-devel python-pip python-matplotlib python-numpy mpich mpich-devel
# Make MPICH's compiler wrappers, headers and libraries visible in future shells.
echo 'export PATH="/usr/lib64/mpich/bin/:$PATH"' >> ~/.bashrc
echo 'export C_INCLUDE_PATH="/usr/include/mpich-x86_64/:$C_INCLUDE_PATH"' >> ~/.bashrc
echo 'export LD_LIBRARY_PATH="/usr/lib64/mpich/lib/:$LD_LIBRARY_PATH"' >> ~/.bashrc
# The example programs import mpi4py, which none of the packages above provide.
# Build it against this MPICH by pointing MPICC at its compiler wrapper (the
# ~/.bashrc changes above are not active in the current shell yet).
# On CentOS 7 and earlier `python` is Python 2, which mpi4py 4.x no longer
# supports, hence the version pin.
env MPICC=/usr/lib64/mpich/bin/mpicc pip install 'mpi4py<4'
# CentOS install matplotlib
# Install the build dependencies of python-matplotlib so pip can compile it.
yum-builddep python-matplotlib
# The python-pip package provides the `pip` command (there is no
# `python-pip` executable).
pip install matplotlib
MPI Code
# file: mandelbrot-seq.py
"""Sequential Mandelbrot renderer; the reference for mandelbrot-mpi.py."""
import numpy as np


def mandelbrot(c, maxit):
    """Return the escape iteration of the complex point *c*.

    Iterates ``z -> z*z + c`` starting from ``z = c`` and returns the first
    iteration index ``n`` at which ``|z| > 2``. Points that stay bounded for
    all *maxit* iterations are treated as members of the set and return 0
    (note: points with ``|c| > 2`` escape at ``n == 0`` and also return 0).
    """
    z = c
    for n in range(maxit):
        if abs(z) > 2:
            return n
        z = z * z + c
    return 0


def mandelbrot_grid(xlin, ylin, maxit):
    """Return escape counts for every point of the grid ``xlin x ylin``.

    The result is an int64 array indexed ``[x, y]`` (shape
    ``(len(xlin), len(ylin))``) with
    ``C[w, h] == mandelbrot(xlin[w] + 1j*ylin[h], maxit)``.
    """
    C = np.empty((len(xlin), len(ylin)), np.int64)
    for w, x in enumerate(xlin):
        for h, y in enumerate(ylin):
            C[w, h] = mandelbrot(x + 1j * y, maxit)
    return C


def main():
    """Render the default view of the Mandelbrot set and display it."""
    # Imported here so the numeric helpers above work without a plotting stack.
    from matplotlib import pyplot as plt

    xmin, xmax = -2.0, 1.0
    ymin, ymax = -1.0, 1.0
    width, height = 320, 200
    maxit = 127

    xlin = np.linspace(xmin, xmax, width)
    ylin = np.linspace(ymin, ymax, height)
    C = mandelbrot_grid(xlin, ylin, maxit)

    # C is indexed [x, y] but imshow draws array[row, col], i.e. [y, x]:
    # transpose so the picture is width x height instead of rotated, and use
    # origin='lower' + extent so the axes show complex-plane coordinates.
    # plt.spectral() was deprecated in Matplotlib 2.0 and later removed; its
    # colormap is now named 'nipy_spectral'.
    plt.imshow(C.T, aspect='equal', origin='lower',
               extent=(xmin, xmax, ymin, ymax), cmap='nipy_spectral')
    plt.show()


if __name__ == "__main__":
    main()
# file: mandelbrot-mpi.py
"""Parallel Mandelbrot renderer.

Image columns are split across MPI ranks, each rank computes its share, and
the results are gathered on rank 0 for display. Launch with e.g.
``mpiexec -n 4 python3 mandelbrot-mpi.py``.
"""
import numpy as np


def mandelbrot(c, maxit):
    """Return the escape iteration of the complex point *c*.

    Iterates ``z -> z*z + c`` starting from ``z = c`` and returns the first
    iteration index ``n`` at which ``|z| > 2``. Points that stay bounded for
    all *maxit* iterations are treated as members of the set and return 0
    (note: points with ``|c| > 2`` escape at ``n == 0`` and also return 0).
    """
    z = c
    for n in range(maxit):
        if abs(z) > 2:
            return n
        z = z * z + c
    return 0


def column_counts(width, nprocs):
    """Return the number of image columns each of *nprocs* ranks computes.

    Columns are divided as evenly as possible; the first ``width % nprocs``
    ranks receive one extra column. The counts always sum to *width*.
    """
    base, extra = divmod(width, nprocs)
    return [base + (rank < extra) for rank in range(nprocs)]


def compute_columns(xlin, ylin, col_start, ncols, maxit):
    """Return escape counts for columns ``[col_start, col_start + ncols)``.

    The result is an int64 array of shape ``(ncols, len(ylin))`` with
    ``block[w, h] == mandelbrot(xlin[col_start + w] + 1j*ylin[h], maxit)``.
    """
    block = np.empty((ncols, len(ylin)), np.int64)
    for w in range(ncols):
        x = xlin[col_start + w]
        for h, y in enumerate(ylin):
            block[w, h] = mandelbrot(x + 1j * y, maxit)
    return block


def main():
    """Compute the image in parallel and display it on rank 0."""
    # mpi4py initializes MPI when imported; importing it here rather than at
    # module level keeps the helpers above importable without MPI installed.
    from mpi4py import MPI

    xmin, xmax = -2.0, 1.0
    ymin, ymax = -1.0, 1.0
    width, height = 320, 200
    maxit = 127

    comm = MPI.COMM_WORLD
    comm_size = comm.Get_size()
    comm_rank = comm.Get_rank()

    # Every rank can derive the whole partition locally, so no scan/gather
    # collectives are needed to learn the column counts and offsets.
    counts = column_counts(width, comm_size)
    ncols = counts[comm_rank]
    col_start = sum(counts[:comm_rank])

    xlin = np.linspace(xmin, xmax, width)
    ylin = np.linspace(ymin, ymax, height)
    C_local = compute_columns(xlin, ylin, col_start, ncols, maxit)

    # Gather results on rank 0. One x-column of C is `height` contiguous
    # int64 values; counting in units of that derived type lets Gatherv stack
    # each rank's (ncols, height) block. With displacements None, mpi4py
    # places the blocks back to back in rank order.
    C = np.empty((width, height), np.int64) if comm_rank == 0 else None
    rowtype = MPI.INT64_T.Create_contiguous(height)
    rowtype.Commit()
    # The receive buffer is only significant at the root.
    recvbuf = [C, (counts, None), rowtype] if comm_rank == 0 else None
    comm.Gatherv(sendbuf=[C_local, MPI.INT64_T], recvbuf=recvbuf, root=0)
    rowtype.Free()

    if comm_rank == 0:
        # Only the root displays anything, so only it needs matplotlib.
        from matplotlib import pyplot as plt

        # C is indexed [x, y] but imshow draws array[row, col], i.e. [y, x]:
        # transpose so the picture is not rotated, and label the axes with
        # complex-plane coordinates. plt.spectral() was deprecated in
        # Matplotlib 2.0 and later removed; its colormap is 'nipy_spectral'.
        plt.imshow(C.T, aspect='equal', origin='lower',
                   extent=(xmin, xmax, ymin, ymax), cmap='nipy_spectral')
        plt.show()


if __name__ == "__main__":
    main()
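Set up your hosts (for example, in an MPI hostfile), then launch the parallel version with `mpiexec`, e.g. `mpiexec -n 4 python3 mandelbrot-mpi.py`. Have fun!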
迴響