from functools import partial
from multiprocessing import freeze_support, get_context
from zlib import adler32
from mango.constants import c
from mango.errors import ExitQuiet, _strings
from mango.debug import DBG
def ismpi(argparse=True, opts={}):
    """Entry point function for MPI running."""
    from mango.mango import _main_run
    if c.comm.Get_size() >= 1:
        if c.comm.Get_rank() == 0:
            _main_run(argparse, opts)
        else:
            _mpi_handler()
    else:
        # presumably only reachable with a stub communicator reporting size < 1
        raise ExitQuiet(f'{_strings.ENDC}{_strings._F[0]}mpi4py not installed')
def mp_handle():
    """
    Check if we are running in MPI or normal mode.

    Returns
    -------
    runner: func
        function for running the calculation
    """
    if c.comm.Get_size() > 1:
        return _mpi_handler
    else:
        return mp_handler 
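# Dispatch sketch (worker, data and stats are hypothetical stand-ins for the
# caller's objects; both runners share the same call shape):
#
#     runner = mp_handle()
#     runner(worker, data, stats)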
def _mpi_handler(*args):
    """
    MPI handler.

    Each processor gets a statistic to run. Rank 0 passes the arguments
    below; all other ranks receive them by broadcast.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: object
        data to be passed to the object; must provide an ``Error`` reporter
    stats: list
        statistics to run, one entry per repetition
    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    run = 1
    if rank != 0:
        mp_worker = c.comm.bcast(None, root=0)
        data = c.comm.bcast(None, root=0)
        stats = c.comm.bcast(None, root=0)
        c.Error = data.Error
    else:
        mp_worker = c.comm.bcast(args[0], root=0)
        data = c.comm.bcast(args[1], root=0)
        stats = c.comm.bcast(args[2], root=0)
        if len(stats) < worldsize:
            data.Error("W creating more statistics, unused processors")
        # Simplified for now; extra processors are left idle
    nostats = len(stats)
    if rank >= nostats:
        # more processors than statistics: this rank has nothing to run
        run = 0
    elif nostats > worldsize:
        remain = nostats - worldsize
        run += remain // worldsize  # on top of the one stat every rank runs
        if rank < remain % worldsize:
            run += 1
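    # Worked example: 10 stats on 4 ranks -> remain = 6, base run = 1 + 6 // 4 = 2,
    # and ranks 0-1 take the remainder (6 % 4 = 2) for 3 + 3 + 2 + 2 = 10 runs.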
    # TODO Run number of stats expected with procs leftover if wanted
    try:
        for i in range(run):
            stat = rank + (worldsize * i)
            mp_worker(data, stats[stat])
    except Exception as e:
        raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    finally:
        c.comm.barrier()
def _mpi_handler_extracpus(*args):
    """
    MPI and OpenMP setup.

    Splits processors into node groups and local groups.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: object
        data to be passed to the object; must provide an ``Error`` reporter
    stats: int
        number of repetitions to complete
    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    # perprocessor should be changed to allow for multiple processor "blocks"
    if rank != 0:
        stats = c.comm.bcast(None, root=0)
        comms, rnr, run = getlocalprocs(c.comm, stats, perprocessor=1, Error=None)
        # bcast is collective: every rank must take part, not only group leaders
        mp_worker = c.comm.bcast(None, root=0)
        data = c.comm.bcast(None, root=0)
        c.Error = data.Error
    else:
        stats = c.comm.bcast(args[2], root=0)
        comms, rnr, run = getlocalprocs(c.comm, stats, perprocessor=1, Error=args[1].Error)
        mp_worker = c.comm.bcast(args[0], root=0)
        data = c.comm.bcast(args[1], root=0)
    try:
        if comms[1].Get_rank() == 0:
            for i in range(run):
                stat = rank + (worldsize * i)
                mp_worker(data, stat)  # ,comms) # may be useful for openmp + mpi
    except Exception as e:
        raise e if DBG else ExitQuiet(f"{type(e).__name__}: {e}")
    finally:
        # a barrier is collective too, so synchronise over the world
        # communicator instead of only the group leaders
        c.comm.barrier()
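# Note: only group leaders (rank 0 of comms[1]) run the worker above; the other
# ranks in each group are held back, presumably for a threaded (OpenMP-style)
# worker that would receive the group communicator.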
def _mpi_handler_rank1ctl(*args):
    """
    MPI function for control processor and workers.

    Rank 0 only hands out work; every other rank requests jobs until told
    to stop.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: object
        data to be passed to the object; must provide an ``Error`` reporter
    stats: int
        number of repetitions to complete
    """
    worldsize = c.comm.Get_size()
    rank = c.comm.Get_rank()
    if rank != 0:
        while True:
            c.comm.send(rank, dest=0)
            calc = c.comm.recv(source=0)
            if not calc:
                break
            mp_worker = c.comm.recv(source=0)
            data = c.comm.recv(source=0)
            c.Error = data.Error
            try:
                mp_worker(data, c.comm.Get_rank())
            except Exception as e:
                raise ExitQuiet(f"{type(e).__name__}: {e.args[0]}") if not DBG else e
    else:
        # TODO Run number of stats expected with procs leftover if wanted
        # Simplified for now extra procs not used otherwise
        if args[2] < worldsize:
            c.Error("W creating more statistics, unused processors")
        # TODO allow file saving by control (maybe give control more than 1 cpu?, Is it io or cpu limited?)
        # Work out how many jobs each proc needs, remainder means an extra loop + if rank < x
        # 2nd Loop like below with range(worldsize-1) includes: iosave/end recv (disksaveloop becomes a send function)
        # max() keeps every worker busy: when there are fewer stats than
        # workers, extra (duplicate) statistics are dispatched instead of
        # leaving processors idle
        for i in range(max(args[2], worldsize - 1)):
            dest = c.comm.recv(source=c.MPI.ANY_SOURCE)
            c.comm.send(True, dest=dest)
            c.comm.send(args[0], dest=dest)
            c.comm.send(args[1], dest=dest)
        for i in range(worldsize - 1):
            dest = c.comm.recv(source=c.MPI.ANY_SOURCE)
            c.comm.send(False, dest=dest)
    c.comm.barrier()
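# Protocol sketch for _mpi_handler_rank1ctl: each worker loops
# { send(rank) -> recv(flag) }; the control rank answers True followed by the
# worker callable and the data for a job, or False to shut the worker down.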
def mp_handler(mp_worker, data, stats, para=True):
    """
    Multiprocessing handler.

    Uses Python's built-in multiprocessing (OpenMP-like) to separate
    parallel commands into different processes.

    Parameters
    ----------
    mp_worker: object
        Object to run
    data: object
        data to be passed to the object
    stats: list
        statistics to run, one entry per repetition
    para: bool
        run in parallel processes if True, serially otherwise
    """
    rnge = min(c.processors, len(stats))
    p_func = partial(mp_worker, data)
    if para:
        ctx = get_context('forkserver')
        with ctx.Pool(rnge) as p:
            results = p.map_async(p_func, stats)
            p.close()
            p.join()
            try:
                # exceptions are re-raised by multiprocessing, should already be dealt with
                return results.get()
            except KeyboardInterrupt:
                p.terminate()
            except Exception as e:
                p.terminate()
                raise e if DBG else ExitQuiet(f"{type(e).__name__}: {e}")
    else:
        return [p_func(i) for i in stats]
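# Minimal usage sketch (hypothetical worker; real workers receive the run data
# object as their first argument):
#
#     def _square(data, stat):
#         return stat ** 2
#
#     mp_handler(_square, None, range(8), para=False)  # -> [0, 1, 4, ..., 49]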
def getlocalprocs(commworld, stats, perprocessor=1, Error=None):
    """
    Split processors into groups.

    Parameters
    ----------
    commworld: class
        MPI_COMM_WORLD communicator
    stats: int
        number of statistics to calculate
    perprocessor: int
        number of cores per group
    Error: callable or None
        error/warning reporter from rank 0 (currently unused here; warnings
        are raised through ``c.Error`` in ``splitnode``)

    Returns
    -------
    communicators: tuple
        node communicator, group communicator
    realNodeRank: int
        rank within the physical node, corrected for adler32 collisions
    run: int or None
        number of statistics to run (group rank 0 only)
    """
    name = "{}".format(c.MPI.Get_processor_name()[:10]).encode()
    # Computer adler32 to compare node names
    adname = adler32(name)
    nodecomm = commworld.Split(color=adname, key=commworld.Get_rank())  # change color (new group) and key (key maybe rank)
    # adler32 can collide, so compare the actual names to get the true node rank
    names = nodecomm.allgather(name)
    realNodeRank = 0
    for i in range(nodecomm.Get_rank()):
        if name == names[i]:
            realNodeRank += 1
    run, groupcomm = splitnode(perprocessor, stats, commworld, nodecomm, adname)
    return (nodecomm, groupcomm), realNodeRank, run 
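# With the default perprocessor=1 every rank forms its own group, so for a
# world of 8 ranks and stats=16 each group leader computes run = 16 * 1 // 8 = 2.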
def splitnode(perprocessor, stats, commworld, nodecomm, adname):
    """
    Split cores into groups.

    Avoids splitting cores across nodes.

    Parameters
    ----------
    perprocessor: int
        number of cores per group
    stats: int
        number of statistics to calculate
    commworld: class
        MPI_COMM_WORLD communicator
    nodecomm: class
        local node communicator
    adname: int
        adler32 hash of the current node name

    Returns
    -------
    run: int or None
        number of statistics for the group (group rank 0 only), otherwise None
    groupcomm: class
        local group communicator
    """
    import numpy as np
    nsize = nodecomm.Get_size()
    nrank = nodecomm.Get_rank()
    wsize = commworld.Get_size()
    wrank = commworld.Get_rank()
    numgroups = nsize // perprocessor
    # Sum the per-node group counts: only node rank 0 contributes, others add 0
    tgroups = commworld.allreduce(numgroups) if nrank == 0 else commworld.allreduce(0)
    rmpernode = nsize % perprocessor
    leftrank = nsize - rmpernode
    # Give each remaining processor to the groups in order
    if rmpernode > 0 and nrank >= leftrank:
        group = (nsize - nrank) % numgroups
    else:
        group = nrank // perprocessor
    groupcomm = nodecomm.Split(color=adler32('{}{}'.format(group, adname).encode()), key=nrank)
    if stats < tgroups:
        if wrank == 0:
            c.Error("W creating more statistics, unused processors")
        stats = tgroups
    # Split stats evenly across nodes
    if groupcomm.Get_rank() == 0:
        # Number of times to run worker
        run = stats * groupcomm.Get_size() // wsize
        # Gather (rank, run) from every group leader; non-leaders contribute None
        spread = np.array(list(filter((None).__ne__, commworld.allgather([wrank, run]))))
        spread = spread[spread[:, 1].argsort()]  # sort ascending by run count
        remain = stats - np.sum(spread[:, 1])
        if remain > 0 and wrank in spread[:remain, 0]:
            run += 1
        return run, groupcomm
    else:
        commworld.allgather(None)
        return None, groupcomm 
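# Remainder example: three single-rank groups and stats = 8 give each leader a
# base run of 8 * 1 // 3 = 2; remain = 8 - 6 = 2, so the first two leaders in
# the sorted spread each take one extra run (3 + 3 + 2 = 8).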
if __name__ == '__main__':
    freeze_support()