from functools import partial
from multiprocessing import freeze_support, get_context
from zlib import adler32

from mango.constants import c
from mango.errors import ExitQuiet, _strings
from mango.debug import DBG


def ismpi(argparse=True, opts={}):
    """Entry point function for MPI running."""
    from mango.mango import _main_run
    if c.comm.Get_size() >= 1:
        if c.comm.Get_rank() == 0:
            _main_run(argparse, opts)
        else:
            _mpi_handler()
    else:
        raise ExitQuiet(f'{_strings.ENDC}{_strings._F[0]}mpi4py not installed')
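

# How the entry point above behaves under MPI (descriptive sketch only): every
# rank calls ismpi(); rank 0 parses arguments and runs _main_run, while the
# remaining ranks drop straight into _mpi_handler and block on the broadcasts
# of the worker, data and stats issued from rank 0.  A hypothetical launch
# (script name is a placeholder, not part of this module):
#
#     $ mpirun -n 4 python <your entry script>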


def mp_handle():
    """
    Check if we are running in MPI or normal mode.

    Returns
    -------
    runner: func
        function for running the calculation
    """
    if c.comm.Get_size() > 1:
        return _mpi_handler
    else:
        return mp_handler
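

# Sketch of how a runner is typically obtained and invoked (illustrative; the
# worker, data and stats objects are placeholders): both runners share the
# (mp_worker, data, stats) calling convention, so the caller does not need to
# know whether it is running under MPI.
#
#     runner = mp_handle()             # _mpi_handler if comm size > 1, else mp_handler
#     runner(mp_worker, data, stats)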


def _mpi_handler(*args):
    """
    MPI handler.

    Each processor gets a statistic to run.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: list
        statistics (repetitions) to complete
    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    run = 1
    if rank != 0:
        mp_worker = data = stats = None
        mp_worker = c.comm.bcast(mp_worker, root=0)
        data = c.comm.bcast(data, root=0)
        stats = c.comm.bcast(stats, root=0)
        c.Error = data.Error
    else:
        mp_worker = c.comm.bcast(args[0], root=0)
        data = c.comm.bcast(args[1], root=0)
        stats = c.comm.bcast(args[2], root=0)
        if len(stats) < worldsize:
            data.Error("W creating more statistics, unused processors")

    # Simplified for now extra procs not used otherwise
    nostats = len(stats)
    if nostats > worldsize:
        remain = nostats - worldsize
        run += remain // worldsize
        if rank + 1 <= remain % worldsize:
            run += 1

    # TODO Run number of stats expected with procs leftover if wanted
    try:
        for i in range(run):
            stat = rank + (worldsize * i)
            mp_worker(data, stats[stat])
    except Exception as e:
        raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    finally:
        c.comm.barrier()
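

# Worked example of the round-robin distribution above (an illustration, not
# executed): with worldsize = 4 and 10 statistics, remain = 6, so every rank
# starts with run = 1 + 6 // 4 = 2 and ranks 0 and 1 (rank + 1 <= 6 % 4) get
# one extra.  Rank r then processes stats[r], stats[r + 4], stats[r + 8], ...:
#
#     rank 0 -> stats 0, 4, 8      rank 1 -> stats 1, 5, 9
#     rank 2 -> stats 2, 6         rank 3 -> stats 3, 7
#
# which covers all 10 statistics exactly once.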


def _mpi_handler_extracpus(*args):
    """
    MPI and OpenMP setup.

    Splits processors into node groups and local groups.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: int
        number of repetitions to complete
    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    # perprocessor should be changed to allow for multiple processor "blocks"
    if rank != 0:
        stats = None
        stats = c.comm.bcast(stats, root=0)
        comms, rnr, run = getlocalprocs(c.comm, stats, perprocessor=1, Error=None)
    else:
        stats = c.comm.bcast(args[2], root=0)
        comms, rnr, run = getlocalprocs(c.comm, stats, perprocessor=1, Error=args[1].Error)

    if comms[1].Get_rank() == 0 and rank != 0:
        mp_worker = data = None
        mp_worker = c.comm.bcast(mp_worker, root=0)
        data = c.comm.bcast(data, root=0)
        c.Error = data.Error
    elif rank == 0:
        mp_worker = c.comm.bcast(args[0], root=0)
        data = c.comm.bcast(args[1], root=0)
    try:
        if comms[1].Get_rank() == 0:
            for i in range(run):
                stat = rank + (worldsize * i)
                mp_worker(data, stat)  # ,comms) # may be useful for openmp + mpi
    except Exception as e:
        raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    finally:
        if comms[1].Get_rank() == 0:
            comms[1].barrier()
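

# A note on the communicator layout used above (descriptive only): comms is
# the (nodecomm, groupcomm) pair returned by getlocalprocs, so comms[0] groups
# ranks that share a physical node and comms[1] is the per-group communicator.
# Only group leaders (comms[1].Get_rank() == 0) receive the worker and data and
# execute statistics; the remaining ranks in each group are left idle here (the
# commented-out comms argument hints they could later be handed to the worker
# for OpenMP-style shared-memory work).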


def _mpi_handler_rank1ctl(*args):
    """
    MPI function for control processor and workers.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: int
        number of repetitions to complete
    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    if rank != 0:
        while True:
            c.comm.send(rank, dest=0)
            calc = c.comm.recv(source=0)
            if not calc:
                break

            mp_worker = c.comm.recv(source=0)
            data = c.comm.recv(source=0)
            c.Error = data.Error

            try:
                mp_worker(data, c.comm.Get_rank())
            except Exception as e:
                raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    else:
        # TODO Run number of stats expected with procs leftover if wanted
        # Simplified for now extra procs not used otherwise
        if args[2] < worldsize:
            c.Error("W creating more statistics, unused processors")

        # TODO allow file saving by control (maybe give control more than 1 cpu?, Is it io or cpu limited?)
        # Work out how many jobs each proc needs, remainder means an extra loop + if rank < x
        # 2nd Loop like below with range(worldsize-1) includes: iosave/end recv (disksaveloop becomes a send function)
        for i in range(max(args[2], worldsize - 1)):
            dest = c.comm.recv(source=c.MPI.ANY_SOURCE)
            c.comm.send(True, dest=dest)
            c.comm.send(args[0], dest=dest)
            c.comm.send(args[1], dest=dest)
        for i in range(worldsize - 1):
            dest = c.comm.recv(source=c.MPI.ANY_SOURCE)
            c.comm.send(False, dest=dest)

    c.comm.barrier()
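

# Sketch of the control/worker handshake implemented above (illustrative; the
# rank and statistic counts are made up): with 4 ranks and args[2] = 5
# statistics, each worker repeatedly sends its rank to rank 0 and blocks on a
# reply.  Rank 0 answers the first 5 requests with True followed by mp_worker
# and data, so whichever worker asked runs another statistic, and then answers
# the next 3 requests with False, which breaks each worker's loop before the
# final barrier.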


def mp_handler(mp_worker, data, stats, para=True):
    """
    Multiprocessing handler.

    Uses Python's builtin multiprocessing (OpenMP-like) to separate
    parallel commands into different processes.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: list
        statistics (repetitions) to complete
    para: bool
        multiprocess or not
    """
    rnge = min(c.processors, len(stats))

    p_func = partial(mp_worker, data)

    if para:
        ctx = get_context('forkserver')
        with ctx.Pool(rnge) as p:
            results = p.map_async(p_func, stats)
            p.close()
            p.join()

            try:
                # exceptions are reraised by multiprocessing, should already be dealt with
                return results.get()
            except KeyboardInterrupt:
                p.terminate()
            except Exception as e:
                p.terminate()
                raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    else:
        results = []
        for i in stats:
            results += [p_func(i)]
        return results
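

# Minimal usage sketch for the multiprocessing path (the worker and data object
# below are hypothetical placeholders, not part of the module):
#
#     def worker(data, stat):          # same (data, stat) signature the pool expects
#         return stat ** 2             # stand-in for a real calculation
#
#     results = mp_handler(worker, data, stats=list(range(8)), para=True)
#
# With para=True the worker runs in a 'forkserver' pool of
# min(c.processors, len(stats)) processes; with para=False the same calls are
# made serially in the current process and the results list is returned directly.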


def getlocalprocs(commworld, stats, perprocessor=1, Error=None):
    """
    Split processors into groups.

    Parameters
    ----------
    commworld: class
        MPI_COMM_WORLD communicator
    stats: int
        number of statistics to calculate
    perprocessor: int
        number of cores per group
    Error: func, optional
        warning/error handler (currently unused)

    Returns
    -------
    communicators: tuple
        node communicator, group communicator
    realNodeRank: int
        rank within the node
    run: int
        number of statistics to run
    """
    name = "{}".format(c.MPI.Get_processor_name()[:10]).encode()

    # Compute adler32 to compare node names
    adname = adler32(name)
    nodecomm = commworld.Split(color=adname, key=commworld.Get_rank())  # change color (new group) and key (key maybe rank)

    # Avoid false collisions
    names = nodecomm.allgather(name)
    realNodeRank = 0
    for i in range(nodecomm.Get_rank()):
        if name == names[i]:
            realNodeRank += 1

    run, groupcomm = splitnode(perprocessor, stats, commworld, nodecomm, adname)
    return (nodecomm, groupcomm), realNodeRank, run
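

# Illustration of the node grouping above (hostnames are made up): ranks whose
# MPI.Get_processor_name() starts with "node-a" all compute the same color
# adler32(b"node-a") and so land in one nodecomm, while "node-b" ranks land in
# another.  Because two different names could in principle hash to the same
# color, realNodeRank is rebuilt by counting earlier ranks in nodecomm that
# report an identical name rather than trusting nodecomm.Get_rank() directly.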


def splitnode(perprocessor, stats, commworld, nodecomm, adname):
    """
    Split cores into groups.

    Avoids splitting cores across nodes.

    Parameters
    ----------
    perprocessor: int
        number of cores per group
    stats: int
        number of statistics to calculate
    commworld: class
        MPI_COMM_WORLD communicator
    nodecomm: class
        local node communicator
    adname: int
        adler32 checksum of the current node name

    Returns
    -------
    run: int or None
        number of statistics for the group (only for group rank 0) otherwise None
    groupcomm: class
        local group communicator
    """
    import numpy as np

    nsize = nodecomm.Get_size()
    nrank = nodecomm.Get_rank()
    wsize = commworld.Get_size()
    wrank = commworld.Get_rank()

    numgroups = nsize // perprocessor

    tgroups = commworld.allreduce(numgroups) if nrank == 0 else commworld.allreduce(0)

    rmpernode = nsize % perprocessor
    leftrank = nsize - rmpernode

    # Give each remaining processor to groups in order
    group = (nsize - nrank) % numgroups if rmpernode > 0 and nrank >= leftrank else nrank // perprocessor

    groupcomm = nodecomm.Split(color=adler32('{}{}'.format(group, adname).encode()), key=nrank)

    if stats < tgroups:
        if wrank == 0:
            c.Error("W creating more statistics, unused processors")
        stats = tgroups

    # Split stats evenly across nodes
    if groupcomm.Get_rank() == 0:
        # Number of times to run worker
        run = stats * groupcomm.Get_size() // wsize
        spread = np.array(list(filter((None).__ne__, commworld.allgather([wrank, run]))))
        spread = spread[spread[:, 1].argsort()[::-1]][::-1]
        remain = stats - np.sum(spread[:, 1])

        if remain > 0:
            if wrank in spread[:remain, :][:, 0]:
                run += 1
        return run, groupcomm
    else:
        commworld.allgather(None)
        return None, groupcomm
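

# Worked example of the split above (numbers chosen for illustration): with
# perprocessor = 2 on a node holding 5 ranks, numgroups = 2, rmpernode = 1 and
# leftrank = 4, so node ranks 0-3 map to groups 0, 0, 1, 1 via
# nrank // perprocessor and the leftover rank 4 joins group (5 - 4) % 2 = 1.
# If that node is the whole job (wsize = 5, tgroups = 2) and stats = 4, the
# group-0 leader starts with run = 4 * 2 // 5 = 1 and the group-1 leader with
# run = 4 * 3 // 5 = 2; remain = 1, and since spread is ordered with the
# smallest run first, the group-0 leader picks up the extra statistic, giving
# 2 + 2 = 4 runs in total.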


if __name__ == '__main__':
    freeze_support()