from scipy.constants import physical_constants
from numpy import finfo, float32, float64, zeros, bincount, ndarray, unique
from collections import defaultdict, OrderedDict
from contextlib import contextmanager, suppress
from sys import stdout, stderr, executable, argv, exit
from os import getcwd, name as osname, devnull, stat, remove
from subprocess import run, PIPE
from functools import wraps
from _version import version
from mango.debug import debug, verb
class const():
    """
    Constants, version number and runtime-environment information.

    | Internal Units
    |
    | Main
    |

    * time: 1e-12 [s]
    * length: 1e-6 [cm]
    * mass: 1e-18 [g]

    | Derived
    |

    * energy unit: 1e-6 [erg]
    * temperature unit: 1e0 [K]
    * magnetic induction: 1e6 [gauss]
    * magnetisation: 1e-12 [emu]
    * force: 1 [g/cm.s^2]

    |
    """

    @debug(['const'])
    def __init__(self):
        """
        Get constants and check environment.

        The probes below run in a fixed order: PROF() and CPU() remove
        their own flags from ``argv``, and UCODE() needs the terminal
        object produced by GT().
        """
        self.version = version
        self.KB = physical_constants["Boltzmann constant"][0] * 1.0e13  # energy units /K
        self.GBARE = -physical_constants["electron gyromag. ratio"][0] * 1.0e-10   # unit: 1/(gauss *ps)
        self.EPS = finfo(float32).eps
        self.EPS2 = finfo(float64).eps
        # NOTE(review): the "Ar_" prefix suggests argon reference parameters - confirm.
        self.Ar_KB = 119.8 * self.KB
        self.Ar_sigma = 0.03405
        self.opsys = osname
        self.eijk = LCT()
        self.set_accuracy()
        self.verb = verb
        self.t, self.tinfo = GT()
        self.profile = PROF()
        self.havedisplay = MPL()
        self.using_IP = UIP()
        self.interactive = INTTERM()
        self.ustr, self.unicode = UCODE(self.t)
        self.processors = CPU()
        self.comm, self.MPI = EMPI()
        self.reseed()
        self.posfuncs = ['vv_trans', 'vv_mag', 'ff_trans', 'ff_mag', 'hh_mag', 'force_vars', 'energy_vars']
        # Unit label for every output column, grouped by output file.
        self.units = OrderedDict({"xyz": {"position": '(nm)',
                                          "magnetisation": '(p.emu)',
                                          "forces": r'($\mu$N)',
                                          "momentum": r'($\mu$g.nm/s)'},
                                  "energy": {"Etotal": r'($\mu$ erg)',
                                             "mag_pot": r'($\mu$ erg)',
                                             "trans_pot": r'($\mu$ erg)',
                                             "kinetic": r'($\mu$ erg)'},
                                  "momenta": {"Mtotal": r'($\mu$g.nm/s)',
                                              "total_mag": r'($\mu$g.nm/s)',
                                              "total_angular": r'($\mu$g.nm/s)',
                                              "CoM": '(nm)',
                                              "CoM_vel": '(nm/ps)'},
                                  "time": "(ps)"})
        # Derive the per-file column lists from self.units itself, so the
        # grouping cannot drift out of step with the table above.
        self.columns = {group: list(cols)
                        for group, cols in self.units.items()
                        if isinstance(cols, dict)}
        # Every column name in one set; "time" has no sub-dictionary.
        self.columns_flat = {col for cols in self.columns.values() for col in cols} | {'time'}
        self.files = list(self.columns.keys())

    def reseed(self, seed=12345):
        """
        (Re)create the random number generator.

        Parameters
        ----------
        seed: int
            seed handed to MTrnd
        """
        self._Gen, rng = MTrnd(seed)
        self.random = self._Gen(rng)
        self.rgen = self.random.bit_generator

    def jump(self, jumps):
        """
        Replace the generator with one built on a jumped bit generator.

        Parameters
        ----------
        jumps: int
            passed to the bit generator's ``jumped`` method
        """
        self.rgen = self.rgen.jumped(jumps)
        self.random = self._Gen(self.rgen)

    def set_accuracy(self, EPS_val=None):
        """
        Set the accuracy threshold ``EPS_val``.

        Parameters
        ----------
        EPS_val: float or None
            any falsy value (None, 0) selects float64 machine epsilon
        """
        self.EPS_val = EPS_val if EPS_val else self.EPS2

    def Error(self, msg):
        """
        Simple error printer.

        Prints the message and exits with status 1 when it starts with 'F'.
        Use managers.serverwrapper('Error') decorator for fully fledged error management
        """
        print(msg)
        if msg.startswith('F'):
            exit(1)

    def Prog(self, msg):
        """Blank function, replaced on the fly with progressbar if required."""
        pass

    def echo(self, msg, x, y):
        """Print at specified cursor location."""
        with self.t.location(x, y):
            print(msg, end='', flush=True)

    def clear(self):
        """Clear screen and move cursor."""
        print(self.t.clear + self.t.move_y(0), end='', flush=True)

    def _banner(self):
        """Print the ASCII-art banner followed by the version."""
        print(r"___  ___  ___   _   _  _____  _____ ")
        print(r"|  \/  | / _ \ | \ | ||  __ \|  _  |")
        print(r"| .  . |/ /_\ \|  \| || |  \/| | | |")
        print(r"| |\/| ||  _  || . ` || | __ | | | |")
        print(r"| |  | || | | || |\  || |_\ \\ \_/ /")
        print(r"\_|  |_/\_| |_/\_| \_/ \____/ \___/ ")
        print(f"Version: {self.version}", end="\n\n")
 
def MTrnd(seed=12345):
    """
    Random number generator.

    Parameters
    ----------
    seed: int
        seed for the Mersenne Twister bit generator

    Returns
    -------
    Generator: class
        numpy.random.Generator (the class, not an instance)
    rgen: instance
        MT19937 bit generator seeded with `seed`
    """
    # Imported here so numpy.random is only loaded when a generator is requested.
    from numpy.random import Generator, MT19937
    return Generator, MT19937(seed)
def GT():
    """
    Get terminal information.

    Falls back to a do-nothing stand-in when blessings cannot be imported
    or its Terminal cannot be created.

    Returns
    -------
    Terminal: instance
        Blessings terminal class instance (or the local fallback)
    Terminal info: dict
        Column ranges 'c0'/'c1' for the two screen halves and the tty
        status of stdout ('otty') and stderr ('etty')
    """
    class Term():
        """Fallback exposing the attributes used on a blessings Terminal."""

        width = None
        height = None
        does_styling = False
        clear = ''
        subscript = ''
        no_subscript = ''
        is_a_tty = False

        @contextmanager
        def location(*args):
            yield

        def move_y(*args):
            return ''

    t = None
    problem = None
    try:
        from blessings import Terminal, curses
    except ImportError as e:
        problem = e
    else:
        # curses is only known once the import succeeded, so the
        # terminal is created in a separate try block.
        try:
            t = Terminal()
        except (ImportError, curses.error) as e:
            problem = e

    if t is None:
        if not stdout.isatty():
            verb.print(f"Blessings: {problem}", file=stderr)
        t = Term()

    width = t.width if t.is_a_tty else 0
    tinfo = {'c0': [0, width // 2], 'c1': [width // 2, width],
             'otty': stdout.isatty(), 'etty': stderr.isatty()}
    return t, tinfo
def LCT():
    """
    Construct Levi-Civita Tensor.

    Returns
    -------
    eijk: array
        Levi-Civita tensor, shape (3, 3, 3): +1 on cyclic index
        permutations, -1 on anti-cyclic ones, 0 elsewhere
    """
    eijk = zeros((3, 3, 3))
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        eijk[i, j, k] = 1
        # Swapping the last two indices of a cyclic permutation gives an anti-cyclic one.
        eijk[i, k, j] = -1
    return eijk
def UIP():
    """
    Test for Jupyter Notebooks.

    Returns
    -------
    using_IP: bool
        Are we using Jupyter notebooks?
    """
    try:
        # get_ipython only exists as a name when running under IPython.
        shell = get_ipython()
    except NameError:
        return False
    # Compare the class name directly rather than parsing str(shell).
    return type(shell).__name__ == 'ZMQInteractiveShell'
def MPL():
    """
    Test matplotlib is functioning.

    A child interpreter imports matplotlib.pyplot and creates a figure;
    its stderr is discarded.

    Returns
    -------
    mpl: bool
        Can we use matplotlib? (child exited with status 0)
    """
    from subprocess import DEVNULL
    probe = run([executable, "-c", "import matplotlib.pyplot as plt; plt.figure()"],
                stderr=DEVNULL)
    return probe.returncode == 0
def EMPI():
    """
    Test whether MPI entry point is used.

    mpi4py is only imported when the program name (last path component of
    argv[0]) ends with "mpi"; otherwise, or if the import fails, a stub is
    returned.

    Returns
    -------
    MPI.COMM_WORLD: comm
        All processors communicator
    MPI: module
        MPI module from mpi4py or class
    """
    from os.path import basename

    def mpifuns():
        class CW():
            """Stub communicator; the class itself is used, never an instance."""

            @staticmethod
            def Get_size():
                return 0

            @staticmethod
            def Get_rank():
                return 0

        class MPI():
            COMM_WORLD = CW

        return MPI

    MPI = None
    if basename(argv[0]).endswith("mpi"):
        try:
            from mpi4py import MPI
        except ImportError:
            MPI = None
    if MPI is None:
        MPI = mpifuns()
    return MPI.COMM_WORLD, MPI
def CPU():
    """
    Get number of cores available.

    ``--cores N`` on the command line takes precedence and is removed from
    argv. Without it the count is detected via, in order,
    os.sched_getaffinity, numexpr and multiprocessing.

    Returns
    -------
    processors: int
        number of processors

    Raises
    ------
    ValueError
        if the value following ``--cores`` is not an integer
    """
    if "--cores" in argv:
        pind = argv.index("--cores")
        # Slicing yields an empty list when the flag is the last argument,
        # where plain indexing would raise IndexError.
        given = argv[pind + 1:pind + 2]
        del argv[pind:pind + 2]
        if given:
            return int(given[0])
        print("--cores given without a value; detecting available cores instead", file=stderr)

    try:
        from os import sched_getaffinity
        processors = len(sched_getaffinity(0))
    except ImportError:
        try:
            from numexpr import detect_number_of_cores
            processors = detect_number_of_cores()
        except ImportError:
            # I tried
            from multiprocessing import cpu_count
            processors = cpu_count()
    return int(processors)
def UCODE(t):
    """
    Set up unicode strings.

    The character map is read from the external ``locale charmap``
    command. If that command cannot be started the environment is treated
    as non-unicode.

    Parameters
    ----------
    t: instance
        terminal object providing `subscript` and `no_subscript` strings

    Returns
    -------
    ustr: class
        Unicode chars store
    uc: bool
        Is environment unicode?
    """
    try:
        charmap = run(['locale', 'charmap'], stdout=PIPE, stderr=PIPE).stdout
    except OSError:
        # `locale` executable not available on this system
        charmap = b''
    # strip() copes with any line ending and with no trailing newline at all
    uc = (charmap.decode(errors='replace').strip() == 'UTF-8')
    ustr = _variables()
    if uc:
        ustr.mu = "\u03BC"
        ustr.chi = "\u03C7"
        ustr.tau = "\u03C4"
        ustr.taub = ustr.tau + t.subscript + 'b' + t.no_subscript
        ustr.taun = ustr.tau + t.subscript + 'n' + t.no_subscript
    else:
        verb.print("Character encoding: ", charmap)
        ustr.mu = "mu"
        ustr.chi = "chi"
        ustr.tau = "tau"
        ustr.taub = "tau_b"
        ustr.taun = "tau_n"
    return ustr, uc
def INTTERM():
    """
    Check for interactive terminal.

    Returns
    -------
    interactive: bool
        Is environment interactive? (True when ``sys.ps1`` exists)
    """
    import sys
    return hasattr(sys, "ps1")
def PROF():
    """
    Test for code profiling.

    Removes the first ``--profile`` entry from argv when present.

    Returns
    -------
    profile: bool
        Is profile in argv?
    """
    if "--profile" not in argv:
        return False
    argv.remove("--profile")
    return True
def nestedDict():
    """
    Nested Dictionaries.

    In principle this creates an infinitely deep nested dictionary
    eg. dict_name["key"]["key"]["key"]= value
    each key can have sub keys and because defaultdict creates
    the key if it doesn't exist this is all that is needed

    Returns
    -------
    defaultdict
        dictionary whose missing keys are filled with further nestedDicts
    """
    return defaultdict(nestedDict)
class getkeywords():
    """
    Get keywords and organise them.

    Each entry of `words` maps a keyword name to a pair:

    * ``(variable name, type, default, "var" | "flg")``
    * ``(command line flag(s), dict of argparse keyword arguments)``
    """

    def __init__(self):
        """Get defaults and keywords."""
        self.words = self._getkeywords()
        self.defaults = self._getdefaults()
        self.explanation = self._gethelp()

    def _gethelp(self):
        """
        Get help text for all arguments.

        Returns
        -------
        dict
            help strings keyed by the position of the keyword in `words`
        """
        return {no: entry[1][1]['help'] for no, entry in enumerate(self.words.values())}

    def _getdefaults(self):
        """
        Get default values for all arguments.

        Returns
        -------
        dict
            default values keyed by variable name
        """
        return {entry[0][0]: entry[0][2] for entry in self.words.values()}

    def flgorvar(self):
        """
        Sort arguments into list of flags or variables.

        Returns
        -------
        var_list: list
            list of vars
        flg_list: list
            list of flags
        """
        var_list = ['extra_iter', 'RandNoState', 'defaults',
                    'time', 'ms', 'sigma', 'limit']  # 1e-12 [s], 1e6[emu/cm^3],  1e-6 [cm],1/1.e-6 cm (inverted)
        flag_list = ['files', 'field', 'lastframe']

        def renamed(lis, key):
            # A five element tuple carries its name in the last slot,
            # otherwise the first element is used.
            lis.append(key[4] if len(key) == 5 else key[0])

        for entry in self.words.values():
            kind = entry[0][3]
            if kind == "var":
                renamed(var_list, entry[0])
            elif kind == "flg":
                renamed(flag_list, entry[0])
        return var_list, flag_list

    @staticmethod
    def _range(val, min=0, max=1):
        """
        Argparse type: float restricted to the closed interval [min, max].

        Raises
        ------
        ArgumentTypeError
            if the converted value lies outside the interval
        """
        from argparse import ArgumentTypeError
        value = float(val)
        if min <= value <= max:
            return value
        raise ArgumentTypeError('value not in range {}-{}'.format(min, max))

    @debug(["keys"])
    def _getkeywords(self):
        """
        Definition of all keywords.

        Includes argparse calls, default name, types, values and container.
        Reads the module-level ``c`` for the version and unicode strings.

        Returns
        -------
        OrderedDict
            All current keywords
        """
        return OrderedDict(
            # Input files
            {"InputFile": (("inputfile", str, None, "var"),
                           ('-I',
                            dict(dest='inputfile', help="R|Input file location"))),
             "RestartFile": (("restart", str, None, "flg"),
                             ('-R',
                              dict(dest="restart", help="R|Restart file location"))),
             "PositionFile": (("location", str, None, "var"),
                              ('-P',
                               dict(dest="location", help="R|Positions file location"))),
             # Properties
             "ExternalFieldFreq": (("nu", float, 1e-6, "var"),  # [Hz]  267.0e3 was katherine
                                   ('-nu',
                                    dict(help='R|External Field Frequency\n(default: {:3.2e}[Hz])'))),
             "ExternalFieldStrength": (("H_0", float, 167.0, "var"),  # 1e-6 gauss 167.0 * 1e6 was katherine
                                       (['-f', '--field'],
                                        dict(dest="H_0", help='R|External Field Strength\n(default: {}[1e-6 Gauss])'))),
             "MediumViscosity": (("eta", float, 1.002e-2, "var"),  # [g/(cm *s)]
                                 (['-e', '--eta'],
                                  dict(help='R|Solvent Viscosity\n(default: {} [g/(cm *s)])'))),
             "MagneticDensity": (("Mdens", float, 85, "var"),  # [emu/g]
                                 ('-Mdens',
                                  dict(help='R|Magnetic Density of particles\n(default {} [emu/g])'))),
             "ParticleDensity": (("dens", float, 6.99, "var"),  # [g /cm^3]
                                 (['-dn', '--dens'],
                                  dict(nargs='+', help='R|Density of each particle\n(default: {}[g /cm^3])'))),
             "ParticleRadius": (("radius", float, 5.0, "var"),  # [1e-6cm]
                                (['-rad', '--radius'],
                                 dict(nargs='+', help='R|Radius of each particle\n(default: {}[1e-6cm])'))),
             "Temperature": (("temp", float, 298.0, "var"),  # [K]
                             (['-T', '--temp'],
                              dict(help='R|Temperature\n(default: {}[K])'))),
             # Calculation Options
             "BackgroundNoise": (("noise", bool, True, "flg"),
                                 ('--no-noise', dict(action="store_false", dest='noise',
                                                     help='R|Background Noise switch\n(default: {})'))),
             "Boxsize": (("boxsize", float, 50, "var"),  # [1e-6cm]
                         (['-bx', '--boxsize'],
                          dict(nargs='+', help='R|Periodic boundary box size\n(default: max of (1.1 * Number of Molecules * Radius) or {}[1e-6cm])'))),
             "Epsilon": (("epsilon", float, None, "var"),  # e.u.
                         ('--epsilon',
                          dict(help='R|Set depth of Potential Well\n(default: Scales wrt Ar (isothermal compresibility))'))),
             "Iterations": (("nmax", float, 4.0e2, "var"),
                            (['-n', '--nmax'],
                             dict(help='R|Number of Iterations\n(default: {})'))),
             "Labview": (("labview", bool, True, "flg"),
                         ('--no-labview',
                          dict(action="store_false",
                               help='R|Remove initial cluster momentum\n(default: {})'))),
             "MagnetisationSW": (("mag_sw", bool, True, "flg"),
                                 ('--no-mag',
                                  dict(action="store_false", dest='mag_sw', help='R|Magnetisation calculation switch\n(default: {})'))),
             "NumberParticles": (("no_molecules", int, 5, "var"),
                                 (['-np', '--no_particles'],
                                  dict(dest='no_molecules', help='R|Number of Particles\n(default: {})'))),
             "Optimisation": (("opt", bool, False, "flg"),
                              ('--opt',
                               dict(action="store_true", dest='opt', help='R|Auto-minimise energy of the system\n(default: {})'))),
             "PeriodicBoundaries": (("pbc", bool, True, "flg"),
                                    ('--no-pbc',
                                     dict(action="store_false", dest='pbc', help='R|Periodic Boundary Conditions\n(default: {})'))),
             "Potential": (('potential', str, False, "var"),
                           ('--potential',
                            dict(help='R|Change potential used'))),
             "Repetitions": (("stats", int, 10, "var"),
                             (['-st', '--stats'],
                              dict(help='R|Number of statistical repetitions\n(default: {})'))),
             "SkipSave": (("skip", int, 1, "var"),
                          (['-sk', '--skip'],
                           dict(help='R|Save every # iterations\n(default: {})'))),
             "SusceptibilityCalc": (("suscep", list, False, "flg"),
                                    ('--suscep',
                                     dict(nargs='*', type=str, choices=['mag', 'vel', 'angular', 'inertia', 'rotation', ''],
                                          help='R|Zero field Calculation\n(default: {})'))),
             "ThermalOptim": (('op_thermal', bool, False, "flg"),
                              ('--op_thermal',
                               dict(action='store_true', help='R|Add thermal noise to the momentum after the optimisation\n(default: {})'))),
             "Timestep": (("dt", float, 20, "var"),  # [ps]
                          ('-dt',
                           dict(help='R|Timestep\n(default: {}[ps])'))),
             # Postprocessing Options
             "Align": (('align', bool, False, "flg"),
                       ('--align',
                        dict(action="store_true", help='R|Align with Kabsch algorithm\n(default: {})'))),
             "Blocking": (('block', list, [5, 10, 20, 50, 100], "var"),
                          ('--bl',
                           dict(nargs='+', type=int, dest='block',
                                help='R|Block averaging bins\n(default: {})'))),
             "CreateFile": (("cfile", list, False, "flg"),
                            ('--cfile',
                             dict(nargs='+', type=str, choices=['xyz', 'energy', 'momenta', 'lf'],
                                  help='R|Create output files\n(default: {})'))),
             "Equilibrate": (('eq', float, 0.1, "flg"),
                             ('-eq',
                              dict(type=self._range, dest='eq',
                                   help='R|Strip fraction of trajectory for equilibration\n(default: {})'))),
             "PlotColumns": (("column", list, False, "flg"),
                             ('--column',
                              dict(nargs='+', type=str, help='R|Plot simple datasets'))),
             "Run": (("run", int, False, "flg"),
                     ('--run',
                      dict(help='R|Select run number\n(default: {})'))),
             "KineticTemp": (('kinetic', bool, False, "flg"),
                             ('--kinetic',
                              dict(action="store_true", help='R|Calculate kinetic temperature\n(default: {})'))),
             "LengthCheck": (('lengthcheck', bool, False, "flg"),
                             ('--lengthcheck',
                              dict(action="store_true", help='R|Calculate average particle separation\n(default: {})'))),
             "ShowGraphs": (("showg", bool, False, "flg"),
                            ('--showg',
                             dict(action="store_true", help='R|Show graphs during run\n(default: {})'))),
             "SaveGraphs": (("saveg", bool, False, "flg"),
                            ('--saveg',
                             dict(action="store_true", help='{}{}'.format('R|Save graphs\n',
                                                                          'WARNING: slow for large datasets\n(default: {})')))),
             # Three placeholders are needed: a single '{}' silently dropped
             # the second and third line of this help text.
             "UseFile": (("ufile", list, False, "flg"),
                         ('--ufile',
                          dict(type=str, help='{}{}{}'.format('R|Post process from a specific file\n',
                                                              'Possible file types: [xyz, hdf5, pkl, txt]\n',
                                                              'for xyz see help(mango.io.read_data.xyzppread)')))),
             # Run style changes
             "LogDirectory": (("logs", str, "{}/Mango_{}_Logs".format(getcwd(), c.version), "var"),
                              ('--logs',
                               dict(help='R|Log save directory\n(default:./{})'))),
             "Parallel": (("parallel", bool, True, "flg"),
                          ('--no-parallel',
                           dict(action="store_false", dest='parallel', help='R|Parallel Computation\n(default: {})'))),
             "SaveChunk": (("savechunk", float, 1e4, "var"),
                           (['-sv', '--savechunk'],
                            dict(help='{}{}'.format('R|Save to file every # iterations\n',
                                                    '(default: Smallest of {:3.0e} and Nmax)')))),
             "SaveFiletype": (("save_type", str, 'hdf5', "flg"),
                              ('--save_type',
                               dict(choices=['hdf5', 'pkl', 'txt'], help='R|Save filetype (hdf5, pkl, txt)\n(default: {})'))),
             "Scfiter": (("scfiter", int, 20, "var"),
                         ('--scfiter',
                          dict(help='R|Magnetic self consistancy loop before warning\n(default: {})'))),
             "Walltime": (("walltime", int, 1e7, "var"),
                          ('--walltime',
                           dict(help='R|Run time\n(default: {:3.0e}[s])'))),
             "Verbosity": (("nout", int, 2, "flg"),
                           (['-v', '--verbosity'],
                            dict(dest='nout', action='count', help='R|Verbosity level\n(default: {}, Max implemented: 5)'))),
             # Data for Neel Relaxations from: Journal of Magnetism and Magnetic Materials 321 (2009) 3126-3131
             "NeelRelaxations": (("neel", bool, False, "flg"),
                                 ('--neel',
                                  dict(action="store_true", help='R|Enable Neel relaxations\n(default: {})'))),
             "NeelAnisotropicFactor": (("keff", float, 2.25e-8, "var"),  # Unit [1.e12 erg/cm^3]
                                       (['-k', '--keff'],
                                        dict(help='R|Set anisotropic factor\n(default: {:3.2e}[erg*cm3])'))),
             "NeelTau0": (("tauN_0", float, 1.7, "var"),  # Unit [1.e-12 s]
                          ('--tauN_0',
                           dict(help='R|Zero value for {}\n'.format(c.ustr.taun) + '(default: {:3.2e}[1.e-12 s])'))),
             "NeelTemp0": (("temp0", float, 262.0, "var"),  # Unit: ([K]
                           ('--temp0',
                            dict(help='R|Neel zero temperature\n(default: {:3.2e}[K])'))),
             })
class _variables():
    """
    Container whose attributes are created dynamically.

    Keyword arguments given at construction become attributes; more can
    be attached afterwards in the usual way:
    class_instance.variable_name = value
    """

    def __init__(self, **argd):
        vars(self).update(argd)

    def print_contents(self):
        """Print every attribute; arrays are summarised by their unique values."""
        for name, value in vars(self).items():
            shown = value
            if isinstance(value, ndarray):
                uniq, first_seen, n_each = unique(value, return_index=True, return_counts=True)
                if first_seen.size == 1:
                    # A constant array collapses to "value x count".
                    shown = "{} x {}".format(uniq[0], n_each[0])
                else:
                    shown = (uniq, first_seen, n_each)
            print(name, shown)
def switch(switch_name):
    """
    Decorate to switch off a function.

    The decorated callable runs only when ``args[0].flg.<switch_name>`` is
    truthy; otherwise it is skipped and None is returned.

    Improvements- TODO

    * Still return expected return values but with no changes to variables or new variables set to 0 or none
    * only call if switched off

    Parameters
    ----------
    switch_name: str
        name of the attribute on the first argument's `flg` object
    """
    def sw_decorate(func_to_decorate):
        @wraps(func_to_decorate)
        def wrapper(*args, **kw):
            if getattr(args[0].flg, switch_name):
                return func_to_decorate(*args, **kw)
            return None
        return wrapper
    return sw_decorate
def rmemptyfile(filename):
    """
    Remove files of Zero size.

    A missing file is silently ignored.

    Parameters
    ----------
    filename: str
        Full filename
    """
    with suppress(FileNotFoundError):
        if stat(filename).st_size == 0:
            remove(filename)
def boolconvert(boolean):
    """
    Convert strings to bool correctly.

    Parameters
    ----------
    boolean: str
        "true", "t" or "1" in any letter case give True, anything else
        False. Non-string input is converted with str() first, so real
        bools pass through unchanged.

    Returns
    -------
    bool
    """
    return str(boolean).lower() in ("true", "t", "1")
def asciistring(length):
    """
    Get a list of characters of desired length.

    Current max is 26 characters

    Parameters
    ----------
    length: int
        number of characters to return

    Returns
    -------
    list of str
        the first `length` lowercase letters 'a', 'b', ... (at most 26)
    """
    from string import ascii_lowercase
    return list(ascii_lowercase[:length])
class bin_count():
    """
    Fast ufunc.at.

    bincount OR ufunc.at
    but ufunc.at is slower: https://github.com/numpy/numpy/issues/5922
    Much faster even with the manipulation needed for bincount
    pass flattened arrays in eg arr.ravel()
    """

    __slots__ = ['leng', 'shape']

    def __init__(self, **argd):
        """
        Initialise.

        dictionary needs leng and shape
        """
        for k, v in argd.items():
            setattr(self, k, v)

    def addat(self, to, at, fro):
        """
        Add at.

        Sums the weights `fro` per index in `at` and adds the result to
        `to` in place.
        """
        to += bincount(at, fro, minlength=self.leng)

    def subtractat(self, to, at, fro):
        """
        Subtract at.

        Sums the weights `fro` per index in `at` and subtracts the result
        from `to` in place.
        """
        to -= bincount(at, fro, minlength=self.leng)

    def reshape(self, to):
        """Reshape to given shape."""
        return to.reshape(self.shape)
# Module-level singletons, created when this module is imported.
# `c` has to exist before getkeywords() is instantiated because
# getkeywords._getkeywords reads c.version and c.ustr.
c = const()
keys = getkeywords()
# The full keyword table is kept only when 'helpout.py' is an element of
# argv; otherwise it is replaced by the placeholder set {'keywords'}.
keys.words = keys.words if 'helpout.py' in argv else {'keywords'}
if __name__ == "__main__":
    pass