from functools import partial
from multiprocessing import freeze_support, get_context
from zlib import adler32

from mango.constants import c
from mango.errors import ExitQuiet, _strings
from mango.debug import DBG


def ismpi(argparse=True, opts={}):
    """Entry point function for MPI running."""
    from mango.mango import _main_run
    if c.comm.Get_size() >= 1:
        if c.comm.Get_rank() == 0:
            _main_run(argparse, opts)
        else:
            _mpi_handler()
    else:
        raise ExitQuiet(f'{_strings.ENDC}{_strings._F[0]}mpi4py not installed')


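# Hedged usage sketch for ismpi() above (the launch command and the module path
# are assumptions, not taken from this file): rank 0 drives the normal front end
# via _main_run() while every other rank drops into _mpi_handler() and waits for
# broadcast work, so a run might look like:
#
#   mpiexec -n 4 python -c "from mango.mp import ismpi; ismpi()"
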
def mp_handle():
    """
    Check if we are running in MPI or normal mode.

    Returns
    -------
    Runner: func
        function for running the calculation

    """
    if c.comm.Get_size() > 1:
        return _mpi_handler
    else:
        return mp_handler


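# Minimal usage sketch for mp_handle() above (my_worker, my_data and my_stats are
# hypothetical placeholders): both runners take (mp_worker, data, stats), so the
# caller can dispatch without knowing which mode is active.
#
#   runner = mp_handle()                   # _mpi_handler under MPI, else mp_handler
#   runner(my_worker, my_data, my_stats)   # only mp_handler returns the results
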
def _mpi_handler(*args):
    """
    MPI handler.

    Each processor gets a statistic to run. On rank 0, *args is
    (mp_worker, data, stats); the other ranks receive them by broadcast.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: list
        statistics to run

    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    run = 1
    if rank != 0:
        mp_worker = data = stats = None
        mp_worker = c.comm.bcast(mp_worker, root=0)
        data = c.comm.bcast(data, root=0)
        stats = c.comm.bcast(stats, root=0)
        c.Error = data.Error
    else:
        mp_worker = c.comm.bcast(args[0], root=0)
        data = c.comm.bcast(args[1], root=0)
        stats = c.comm.bcast(args[2], root=0)
        if len(stats) < worldsize:
            data.Error("W creating more statistics, unused processors")
    # Simplified for now, extra procs not used otherwise
    nostats = len(stats)
    if nostats > worldsize:
        remain = nostats - worldsize
        run += remain // worldsize
        if rank + 1 <= remain % worldsize:
            run += 1

    # TODO Run number of stats expected with procs leftover if wanted
    try:
        for i in range(run):
            stat = rank + (worldsize * i)
            mp_worker(data, stats[stat])
    except Exception as e:
        raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    finally:
        c.comm.barrier()


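# Worked example of the stat distribution in _mpi_handler() above, assuming
# 4 ranks and 10 statistics: run starts at 1, remain = 10 - 4 = 6, every rank
# adds 6 // 4 = 1, and ranks 0-1 (rank + 1 <= 6 % 4) add one more. With
# stat = rank + worldsize * i this gives
#
#   rank 0 -> stats 0, 4, 8        rank 2 -> stats 2, 6
#   rank 1 -> stats 1, 5, 9        rank 3 -> stats 3, 7
#
# so every statistic is run exactly once.
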
def _mpi_handler_extracpus(*args):
    """
    MPI and OpenMP setup.

    Splits processors into node groups and local groups.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: int
        number of repetitions to complete

    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    # perprocessor should be changed to allow for multiple processor "blocks"
    if rank != 0:
        stats = None
        stats = c.comm.bcast(stats, root=0)
        comms, rnr, run = getlocalprocs(c.comm, stats, perprocessor=1, Error=None)
    else:
        stats = c.comm.bcast(args[2], root=0)
        comms, rnr, run = getlocalprocs(c.comm, stats, perprocessor=1, Error=args[1].Error)

    if comms[1].Get_rank() == 0 and rank != 0:
        mp_worker = data = None
        mp_worker = c.comm.bcast(mp_worker, root=0)
        data = c.comm.bcast(data, root=0)
        c.Error = data.Error
    elif rank == 0:
        mp_worker = c.comm.bcast(args[0], root=0)
        data = c.comm.bcast(args[1], root=0)
    try:
        if comms[1].Get_rank() == 0:
            for i in range(run):
                stat = rank + (worldsize * i)
                mp_worker(data, stat)  # ,comms)  # may be useful for openmp + mpi
    except Exception as e:
        raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    finally:
        if comms[1].Get_rank() == 0:
            comms[1].barrier()


def _mpi_handler_rank1ctl(*args):
    """
    MPI function for control processor and workers.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: int
        number of repetitions to complete

    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    if rank != 0:
        while True:
            c.comm.send(rank, 0)
            calc = c.comm.recv(source=0)
            if not calc:
                break

            mp_worker = c.comm.recv(source=0)
            data = c.comm.recv(source=0)
            c.Error = data.Error

            try:
                mp_worker(data, c.comm.Get_rank())
            except Exception as e:
                raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    else:
        # TODO Run number of stats expected with procs leftover if wanted
        # Simplified for now, extra procs not used otherwise
        if args[2] < worldsize:
            c.Error("W creating more statistics, unused processors")

        # TODO allow file saving by control (maybe give control more than one CPU? Is it IO or CPU limited?)
        # Work out how many jobs each proc needs, a remainder means an extra loop + if rank < x
        # 2nd loop like below with range(worldsize - 1) includes: iosave/end recv (disksaveloop becomes a send function)
        for i in range(max(args[2], worldsize - 1)):
            dest = c.comm.recv(source=c.MPI.ANY_SOURCE)
            c.comm.send(True, dest=dest)
            c.comm.send(args[0], dest=dest)
            c.comm.send(args[1], dest=dest)
        for i in range(worldsize - 1):
            dest = c.comm.recv(source=c.MPI.ANY_SOURCE)
            c.comm.send(False, dest=dest)

    c.comm.barrier()


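# Sketch of the handshake in _mpi_handler_rank1ctl() above (no new calls, just the
# message order already used by the code): an idle worker sends its rank to rank 0,
# which replies with True plus the worker object and data (run another statistic),
# or with False (shut down).
#
#   worker                            control (rank 0)
#   send(rank, 0)                -->  dest = recv(source=ANY_SOURCE)
#   calc = recv(source=0)        <--  send(True or False, dest=dest)
#   mp_worker = recv(source=0)   <--  send(mp_worker, dest=dest)
#   data = recv(source=0)        <--  send(data, dest=dest)
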
def mp_handler(mp_worker, data, stats, para=True):
    """
    Multiprocessing handler.

    Uses Python's builtin (OpenMP-like) multiprocessing
    to separate parallel commands into different processes

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: dictionary
        data to be passed to the object
    stats: list
        statistics to run
    para: bool
        multiprocess or not

    """
    rnge = min(c.processors, len(stats))

    p_func = partial(mp_worker, data)

    if para:
        ctx = get_context('forkserver')
        with ctx.Pool(rnge) as p:
            results = p.map_async(p_func, stats)
            p.close()
            p.join()

            try:
                # exceptions are reraised by multiprocessing, should already be dealt with
                return results.get()
            except KeyboardInterrupt:
                p.terminate()
            except Exception as e:
                p.terminate()
                raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e

    else:
        results = []
        for i in stats:
            results += [p_func(i)]
        return results


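# Minimal usage sketch for mp_handler() above (square_stat and the plain dict are
# hypothetical stand-ins for the real worker and data objects): the 'forkserver'
# start method pickles the worker, so it must be defined at module level.
#
#   def square_stat(data, stat):
#       return data['scale'] * stat ** 2
#
#   out = mp_handler(square_stat, {'scale': 2}, [0, 1, 2, 3], para=True)
#   # out == [0, 2, 8, 18]
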
def getlocalprocs(commworld, stats, perprocessor=1, Error=None):
    """
    Split processors into groups.

    Parameters
    ----------
    commworld: class
        MPI_COMM_WORLD communicator
    stats: int
        number of statistics to calculate
    perprocessor: int
        number of cores per group
    Error: func, optional
        error reporting function

    Returns
    -------
    communicators: tuple
        node communicator, group communicator
    realNodeRank: int
        rank of this process within its node
    run: int
        number of statistics to run

    """
    name = "{}".format(c.MPI.Get_processor_name()[:10]).encode()

    # Compute adler32 to compare node names
    adname = adler32(name)
    nodecomm = commworld.Split(color=adname, key=commworld.Get_rank())  # change color (new group) and key (key maybe rank)

    # Avoid false collisions
    names = nodecomm.allgather(name)
    realNodeRank = 0
    for i in range(nodecomm.Get_rank()):
        if name == names[i]:
            realNodeRank += 1

    run, groupcomm = splitnode(perprocessor, stats, commworld, nodecomm, adname)
    return (nodecomm, groupcomm), realNodeRank, run


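# Worked example of the node grouping in getlocalprocs() above (host names are
# hypothetical): each rank hashes the first 10 bytes of its processor name with
# adler32 and passes that as the Split colour, so
#
#   ranks on 'node-a'  ->  colour adler32(b'node-a')  ->  one nodecomm
#   ranks on 'node-b'  ->  colour adler32(b'node-b')  ->  a different nodecomm
#
# and the allgather of the full names afterwards is used to compute realNodeRank
# without being fooled by two different names hashing to the same colour.
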
def splitnode(perprocessor, stats, commworld, nodecomm, adname):
    """
    Split cores into groups.

    Avoids splitting groups across nodes.

    Parameters
    ----------
    perprocessor: int
        number of cores per group
    stats: int
        number of statistics to calculate
    commworld: class
        MPI_COMM_WORLD communicator
    nodecomm: class
        local node communicator
    adname: int
        adler32 hash of the current node name

    Returns
    -------
    run: int or None
        number of statistics for the group (only on group rank 0), otherwise None
    groupcomm: class
        local group communicator

    """
    import numpy as np

    nsize = nodecomm.Get_size()
    nrank = nodecomm.Get_rank()
    wsize = commworld.Get_size()
    wrank = commworld.Get_rank()

    numgroups = nsize // perprocessor

    tgroups = commworld.allreduce(numgroups) if nrank == 0 else commworld.allreduce(0)

    rmpernode = nsize % perprocessor
    leftrank = nsize - rmpernode

    # Give each remaining processor to a group in order
    group = (nsize - nrank) % numgroups if rmpernode > 0 and nrank >= leftrank else nrank // perprocessor

    groupcomm = nodecomm.Split(color=adler32('{}{}'.format(group, adname).encode()), key=nrank)

    if stats < tgroups:
        if wrank == 0:
            c.Error("W creating more statistics, unused processors")
        stats = tgroups
    # Split stats evenly across nodes

    if groupcomm.Get_rank() == 0:
        # Number of times to run the worker
        run = stats * groupcomm.Get_size() // wsize
        spread = np.array(list(filter((None).__ne__, commworld.allgather([wrank, run]))))
        spread = spread[spread[:, 1].argsort()[::-1]][::-1]
        remain = stats - np.sum(spread[:, 1])

        if remain > 0:
            if wrank in spread[:remain, :][:, 0]:
                run += 1
        return run, groupcomm
    else:
        commworld.allgather(None)
        return None, groupcomm


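# Worked example of the stat split in splitnode() above, assuming 8 ranks on one
# node, perprocessor=1 and stats=10: there are 8 single-core groups, each group
# leader starts with run = 10 * 1 // 8 = 1, remain = 10 - 8 = 2, and two of the
# groups get one extra, giving runs of 2, 2, 1, 1, 1, 1, 1, 1 (total 10). With
# perprocessor > 1, ranks that are not group leaders just allgather None and
# return run=None.
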
if __name__ == '__main__':
    freeze_support()