from numpy import (all as nall, array, array_str, reshape, where, einsum, concatenate,
                   ndarray, block, arange, savetxt, atleast_1d, atleast_2d, genfromtxt)
from re import search
from sys import stdout, argv
from os import get_terminal_size, path, getcwd, fstat
from shutil import copyfile
from datetime import datetime
from asyncio import new_event_loop
from contextlib import suppress
from mango.constants import c, _variables, keys, rmemptyfile, boolconvert
from mango.debug import debug
from mango.managers import serverwrapper
import mango.imports as imports
@debug(['io'])
def _input(func):
    """
    Read a line of user input, then blank part of the terminal.

    Intended for prompts shown together with error messages.

    Parameters
    ----------
    func: str
        Prompt passed to ``input``.

    Returns
    -------
    str
        The line entered by the user.
    """
    answer = input(func)
    # ANSI "cursor up" sequence, twice.
    stdout.write("\x1b[A\x1b[A")
    width = get_terminal_size().columns
    # Overwrite from the cursor position with a terminal-width run of spaces.
    stdout.write(" " * width)
    return answer
def separater(response):
    """
    Split an array of 3-component vectors into up to three column groups.

    The last axis of ``response`` is treated as consecutive xyz triplets:
    columns ``0:3``, ``3:6`` and ``6:9`` become the first, second and third
    element of the returned tuple. Groups that are not present are ``None``;
    columns beyond the ninth are ignored.

    Parameters
    ----------
    response: array_like
        Data whose last axis holds 3, 6 or at least 9 columns.
        Promoted to at least 2-D with ``atleast_2d``.

    Returns
    -------
    tuple
        ``(cols 0:3, cols 3:6 or None, cols 6:9 or None)``

    Raises
    ------
    ValueError
        If the last axis has any other width. Returning the bare array in
        that case would let callers index rows instead of column groups.
    """
    response = atleast_2d(response)
    cols = response.shape[-1]
    if cols == 3:
        return (response, None, None)
    if cols == 6:
        return (response[..., :3], response[..., 3:6], None)
    if cols >= 9:
        return (response[..., :3], response[..., 3:6], response[..., 6:9])
    raise ValueError(
        "separater expects 3, 6 or at least 9 columns on the last axis, got {}".format(cols))
class Data:
    """Attributes shared by the ``write_data`` and ``read_data`` helpers."""

    # Dataset names passed to the hdf5 writer as its ``append`` argument and
    # concatenated along axis 0 when read back from pickle files.
    append = {
        "position",
        "magnetisation",
        "momentum",
        "iter_time",
        "magnetic_field",
        "forces",
        "neel_relaxation",
    }
class write_data(Data):
    """
    Write calculated data to file and keep the run's restart file up to date.

    Parameters
    ----------
    var: object
        Run variables; ``name``, ``skip_iters`` and ``no_molecules`` are read.
    flg: object
        Run flags; ``save_type`` and ``restart`` are read.
    """

    @debug(["write"])
    def __init__(self, var, flg):
        # imports.write returns the dump callable and the save type to use for
        # flg.save_type (NOTE(review): confirm any normalisation in mango.imports).
        self.dump, self.save_type = imports.write(flg.save_type)
        name = var.name
        # Everything up to and including the last "/" ("" when there is none).
        self.directory = name[:name.rfind("/") + 1]
        # <directory> + the part of the name from the last "Run" up to the last
        # "_" + "restart", e.g. "dir/Run1_abc" -> "dir/Run1restart".
        self.restartfile = "{}{}{}".format(self.directory, name[name.rfind("Run"):name.rfind("_")], "restart")
        self.total = var.skip_iters + 1
        # Passed to the hdf5 writer as ``datalength``.
        self.nmax = self.total * var.no_molecules
        self.restarted = flg.restart
        self.restart()

    def setup(self, name):
        """
        Prepare to write file ``name``.

        Creates the event loop used by ``write`` and records ``name`` in the
        restart file.

        Returns
        -------
        str
            The save type.
        """
        previous = getattr(self, "loop", None)
        # Close the loop left by an earlier setup() call; otherwise every call
        # leaks that loop's resources (selector / self-pipe).
        if previous is not None and not previous.is_running():
            previous.close()
        self.loop = new_event_loop()
        self.addtorestart(name)
        return self.save_type

    def write(self, *args, **kw):
        """Run ``_write`` to completion on the loop created by ``setup``."""
        self.loop.run_until_complete(self._write(*args, **kw))

    @debug(['io'])
    async def _write(self, moldata_dict):
        """
        Write data to specific file type.

        ``moldata_dict['name']`` is the target file. hdf5 data is handed to
        the dump function, pickle data is appended as one more pickled record,
        and anything else is appended as text via ``txt_data``.
        """
        # NOTE: the body performs blocking I/O; nothing here is awaited.
        if self.save_type == "hdf5":
            self.dump(moldata_dict['name'], moldata_dict,
                      compression=('blosc', 9), datalength=self.nmax, append=self.append)
        elif self.save_type == "pkl":
            with open(moldata_dict['name'], "ab") as w:
                self.dump(moldata_dict, w, protocol=-1)
        else:
            no_mol = moldata_dict["position"].shape[1]
            txt_data(moldata_dict, no_mol)

    def restart(self):
        """
        Start a fresh restart file.

        An existing restart file is first copied to
        ``<restartfile>R<current time to the minute>``; the file is then
        overwritten with the save type, directory and total-to-save header.
        """
        if path.isfile(self.restartfile):
            copyfile(self.restartfile, self.restartfile + "R" + datetime.now().isoformat(timespec='minutes'))
        with open(self.restartfile, 'w') as rf:
            rf.write(f"save_type {self.save_type}\ndirectory {self.directory}\ntotal_tosave {self.total}\n")

    def addtorestart(self, fname):
        """Append the base name (text after the last "/") of ``fname`` to the restart file."""
        with open(self.restartfile, "a") as rf:
            rf.write(f"{fname.rsplit('/')[-1]}\n")
class read_data(Data):
    """
    Reads data in for post processing.

    Supports hdf5, pickle, (extended) xyz and text outputs.

    Improvements
    * Only read in data required rather than all data
    - Works for hdf5 files

    Parameters
    ----------
    fname: str
        File name; when ``rtype`` is None the text after the last "." is
        used as the file type.
    rtype: str, optional
        File type.
    xyzpp: bool
        Post-process an xyz file using its associated ``.var`` file.
    """

    def __init__(self, fname, rtype=None, xyzpp=False):
        if rtype is None:
            fname, rtype = fname.rsplit('.', 1)
        self.xyzpp = xyzpp
        self.load, self.rtype = imports.read(rtype)
        self._newname(fname, xyzpp)

    def _readtype(self):
        """Select the reader method for the current file type."""
        if self.rtype == "hdf5":
            self._rd = self._hdf5read
        elif self.rtype == "pkl":
            self._rd = self._pickleread
        elif self.xyzpp:
            self._rd = self.xyzppread
            # The variables file sits next to the data file and is named after
            # the part of the file name before its first ".", e.g.
            # data/Run1.1.xyz -> data/Run1.var. The directory is split off
            # first so dots in directory names (e.g. "./") are not mistaken
            # for the start of the extension.
            head, tail = path.split(self.fname)
            self.xvsloc = path.join(head, "{}.var".format(tail.split('.')[0]))
        elif self.rtype == 'xyz':
            self._rd = self.xyzread
        else:
            self._rd = self.txtread

    def _hdf5read(self, obj, *args):
        """Read ``obj`` from the hdf5 file via the loader from ``imports.read``."""
        return self.load(self.fname, obj, *args)

    def _pickleread(self, obj, *args):
        """
        Read ``obj`` from every pickled record in the file.

        ``obj`` may be a "/"-separated path: the first component selects the
        top-level key of each record, later components index into it.
        Records of keys listed in ``self.append`` are concatenated along
        axis 0; for "vars" and "flags" only the last record is returned.

        NOTE(review): ``self.load`` is presumably ``pickle.load``; unpickling
        can execute arbitrary code, so only read trusted output files.
        """
        if obj is None:
            c.Error('F pkl files require key name')
        lspl = list(filter(None, obj.split('/')))
        # Use the first path component even for single-component paths so a
        # leading "/" (e.g. "/position") still selects the right key.
        l_obj = lspl[0] if lspl else obj
        data = []
        with open(self.fname, 'rb') as self.v:
            # Records are stored back to back; read until end of file.
            while self.v.tell() < fstat(self.v.fileno()).st_size:
                data.append(self.load(self.v)[l_obj])
        for n, record in enumerate(data):
            for key in lspl[1:]:
                record = record[key]
            data[n] = record
        if l_obj in self.append:
            data = concatenate(data, axis=0)
        return data[-1] if l_obj in ['vars', 'flags'] else data

    def _newname(self, fname, xyzpp):
        """Switch to file ``fname`` and drop data cached from the previous file."""
        self.xyzpp = xyzpp
        self.fname = fname if fname.endswith(self.rtype) else "{}.{}".format(fname, self.rtype)
        # xyzppread caches the parsed variables and data; without clearing
        # them a new file would silently return the old file's data.
        for cached in ('xvars', 'xyz'):
            self.__dict__.pop(cached, None)
        self._readtype()

    @debug(["io"])
    def read(self, obj=None, fname=None, restart=False, lengthcheck=False, keylist=False, chunk=None, xyzpp=False):
        """
        Read in pickle and hdf5 outputs.

        Parameters
        ----------
        obj: string
            object to be read in
        fname: string
            file to be read in
        restart: bool
            Is read in to restart run
        lengthcheck: bool
            Get length of position data and last iteration
        keylist: bool
            Passed through to the underlying reader
        chunk: optional
            Passed through to the underlying reader
        xyzpp: bool
            Used when ``fname`` differs from the current file; switching files
            also resets ``self.xyzpp`` to this value.
        """
        if fname is None:
            fname = self.fname
        if self.fname != fname:
            self._newname(fname, xyzpp)
        if restart:
            if self.rtype in ["xyz", "txt"]:
                c.Error("F Restarting is not implemented for this filetype")
            return self.getrestart(*self.lengthcheck())
        elif lengthcheck:
            return self.lengthcheck()
        else:
            return self._rd(obj, chunk, keylist)

    def getrestart(self, position, last_iter):
        """
        Get restart data.

        Builds variables and flags objects from the stored "vars" and "flags"
        and attaches position, magnetisation and momentum of ``last_iter``.
        """
        restart = _variables(**self._rd('vars'))
        restart.pos = position[last_iter, :, :]
        restart.mag = self._rd("magnetisation")[last_iter, :, :]
        restart.mom = self._rd("momentum")[last_iter, :, :]
        restartflags = _variables(**self._rd('flags'))
        return restart, restartflags, last_iter

    def lengthcheck(self):
        """
        Check which row number is the last iteration.

        Redundancy check to make sure the last used row is the last row:
        returns the position data and the index of the last row that is not
        entirely zero (0 if every row is zero).
        """
        position = self._rd('position')
        l_it = nall(position == 0.0, axis=(-2, -1))
        if not nall(l_it):
            # Element-wise comparison is required here: `l_it == False`
            # works on the array, `l_it is False` does not.
            last_iter = where(l_it == False)[0][-1]  # noqa: E712
        else:
            last_iter = 0
        return position, last_iter

    def xyzppread(self, obj=None, *args):
        """
        Post process from xyz file.

        Requires an associated ``.var`` file (see ``_readtype`` for its name)
        with the variables needed for post processing. Each line has the
        form ``name = value``; ``columns`` lists the xyz column groups by
        their internal variable names, and ``dens`` and ``vol`` must be set.
        e.g.

            mass = 1.111
            radius = 1.23e-4
            columns = position,velocity,force

        Only internal variable names currently supported.

        Returns
        -------
        The variables dict for ``obj == 'vars'``, otherwise the requested
        column group. "momentum" is derived as velocity * mass when it is not
        stored but velocity is. Returns None if ``obj`` is not available.
        """
        if 'xvars' not in self.__dict__:
            self.xvars = self._xyzvars()
        if obj == 'vars':
            return self.xvars
        if 'xyz' not in self.__dict__:
            _, xyz = self.xyzread()
            self.xyz = separater(xyz)
        columns = self.xvars['columns']
        if obj in columns:
            return self.xyz[columns.index(obj)]
        # Only momentum can be reconstructed; any other missing quantity must
        # not silently be replaced by velocity * mass.
        if obj == "momentum" and "velocity" in columns:
            return self.xyz[columns.index("velocity")] * self.xvars['mass']
        return None

    def xyzread(self, *args):
        """
        Read all data from (extended) xyz file.

        Returns
        -------
        molecules: ndarray
            First (name) column of every data row, as strings.
        response: ndarray
            Remaining columns as floats, shaped (frames, molecules, columns).
        """
        with open(self.fname, "rb") as loc:
            # The first non-blank, non-"#" line holds the molecule count.
            for count_line in loc:
                if count_line[:-1] not in [b"", b"#"]:
                    break
        response = atleast_2d(genfromtxt(self._skip_lines(self.fname, count_line), comments="#", dtype=str))
        molecules = response[:, 0]
        lastaxis = response.shape[-1] - 1
        no_mol = int(count_line.decode("utf-8"))
        if (response.shape[0] % no_mol != 0 or lastaxis % 3 != 0):
            c.Error("F Input array shape not consistant with number of molecules specified")
        # Rows are stored frame by frame (all molecules of frame 0, then
        # frame 1, ...), as xyz_write produces them, so a row-major reshape
        # maps row r to frame r // no_mol, molecule r % no_mol. Fortran order
        # would interleave frames and molecules for multi-frame files.
        response = response[:, 1:].astype(float).reshape((-1, no_mol, lastaxis))
        return molecules, response

    @staticmethod
    def _skip_lines(fname, lines):
        """
        Yield the lines of ``fname`` with xyz header lines commented out.

        Lines starting with the molecule-count line ``lines`` (bytes, as read
        from the file) or with ``b"i"``, and the line directly after each
        count line, are prefixed with ``b"#"`` so ``genfromtxt`` skips them.
        """
        with open(fname, "rb") as file:
            skip = -1
            for lno, line in enumerate(file):
                if line.startswith(lines):
                    skip = lno + 1
                yield b"#" + line if line.startswith((b"i", lines)) or lno == skip else line

    def _xyzvars(self, vnames=None):
        """
        Parse the ``name = value`` lines of the ``.var`` file at ``self.xvsloc``.

        The key is the last word before "=", the value the first word after
        it, converted to int or float where possible. ``columns`` is split
        into a list and ``dens`` / ``vol`` are made at least 1-D.
        """
        # A fresh dict per call: a shared default would leak values between
        # files and reader instances.
        vnames = {} if vnames is None else vnames
        with open(self.xvsloc) as xyzvars:
            for line in xyzvars:
                if "=" not in line:
                    continue
                parts = [x.strip() for x in line.split("=")]
                val = parts[1].split(' ')[0]
                vnames[parts[0].split(" ")[-1]] = self._parse_value(val)
        vnames['columns'] = self._col_splitter(vnames['columns'])
        for i in ['dens', 'vol']:
            vnames[i] = atleast_1d(vnames[i])
        return vnames

    @staticmethod
    def _parse_value(val):
        """Return ``val`` as int, else float (handles "1e-3"), else unchanged."""
        for convert in (int, float):
            with suppress(ValueError):
                return convert(val)
        return val

    @staticmethod
    def _col_splitter(columns):
        """Split a column list on commas if present, otherwise on spaces."""
        if "," in columns:
            return columns.split(',')
        else:
            return columns.split(' ')

    def txtread(self, obj, *args):
        """
        Read in text file output (as written by ``txt_data``).

        * Not restartable

        Parameters
        ----------
        obj: string
            type of data to be read in, Variables "vars", Data "data"

        Returns
        -------
        dict or None
            "vars": header variable name -> first value as float.
            "data": arrays keyed by dataset name. None for any other ``obj``.
        """
        if obj == "vars":
            self.variables = {}
            with open(self.fname, "r") as fp:
                for line in fp:
                    # The column-description line (the only header line with
                    # commas) ends the variable header.
                    if ',' in line:
                        break
                    if '#' in line:
                        # "# name = value [...]" -> name, first value
                        values = line.strip().split()
                        self.variables[values[1]] = float(values[3])
            # FIX from here some variables are not saved
            return self.variables
        elif obj == "data":
            # FIX data shape and return values
            if not hasattr(self, "variables"):
                # The reshapes below need nmax / no_mol / stats from the header.
                self.txtread("vars")
            with open(self.fname, "r") as fp:
                # Skip header lines and the blank lines written between
                # iterations; empty rows would make the array ragged.
                data_temp = array([line.split() for line in fp if line.strip() and '#' not in line])
            mol_sv = set(data_temp[:, 0])
            data = array(data_temp[:, 1:], dtype=float)
            nmax = int(self.variables["nmax"])
            mol = int(self.variables["no_mol"])
            stats = int(self.variables["stats"])
            # Columns after the molecule name, in the order txt_data writes
            # them: 0-1 iteration/time, 2 neel, 3-5 magnetisation,
            # 6-8 position, 9-11 forces, 12-14 momentum.
            iterations = data[:nmax, :2]
            neelr = {'mol_{}'.format(i): reshape(data[:, 2], (nmax, mol), order='F')[:, i] for i in range(mol)}
            magnetisation = reshape(data[:, 3:6], (stats, nmax, mol, 3), order='F')
            position = reshape(data[:, 6:9], (stats, nmax, mol, 3), order='F')
            forces = reshape(data[:, 9:12], (stats, nmax, mol, 3), order='F')
            momentum = reshape(data[:, 12:15], (stats, nmax, mol, 3), order='F')
            return {"molecule names": mol_sv,
                    "iter_time": iterations,
                    "neel_relaxation": neelr,
                    "magnetisation": magnetisation,
                    "position": position,
                    "forces": forces,
                    "momentum": momentum}
@serverwrapper("Error")
def getvar():
    """
    Get variables from existing run.

    Run with
    mango_vars -f [Files...] -v [VARS....]

    Prints each file name followed by the requested variables found in it;
    unknown names are reported through ``c.Error``.
    """
    from argparse import ArgumentParser
    parser = ArgumentParser(prog='MangoVars',
                            description="MangoVars reader")
    parser.add_argument('-f', action='append', nargs='*', type=str, help='Files')
    parser.add_argument('-v', action='append', nargs='*', type=str, help='Variables')
    p = parser.parse_args()
    # With action='append' each option is None when omitted, otherwise a list
    # of lists (one per occurrence); flatten and treat "omitted" as empty.
    files = [item for group in (p.f or []) for item in group]
    names = [item for group in (p.v or []) for item in group]
    for f in files:
        print(f)
        name, save_type = f.rsplit('.', 1)
        reader = read_data(name, save_type)
        variables = reader.read("vars")
        for i in names:
            if i in variables:
                print(i, variables[i])
            else:
                # Get variable names
                c.Error('{}{}'.format('>W Only internal variable names currently supported\n',
                                      'Can\'t find {}'.format(i)))
def restart_gen(file=None, reps=None, outfile=None):
    """
    Restart generator.

    Run with
    mango_restartregen [file] [reps] [outfile <optional>]

    Writes a restart file (in the format read by ``restartfile_read``)
    listing ``<name>.1.<type>`` ... ``<name>.<reps>.<type>``.

    Parameters
    ----------
    file: str, optional
        One of the data files, e.g. ``dir/Run1_x.1.hdf5``; defaults to ``argv[1]``.
    reps: int, optional
        Number of statistics files; defaults to ``argv[2]``.
    outfile: str, optional
        Restart file to write; defaults to ``argv[3]`` when four command line
        arguments are given, else ``Run<N>restart``.
    """
    if '-h' in argv:
        print(restart_gen.__doc__)
        return
    lenarg = len(argv)
    file = (argv[1] if file is None else file).rsplit(".", maxsplit=2)
    reps = int(argv[2] if reps is None else reps)
    if outfile is None:
        outfile = argv[3] if lenarg == 4 else "Run{}restart".format(file[0].split("Run")[1].split("_")[0])
    save_type = file[-1]
    f_name = file[0]
    dir_part, base_name = path.split(f_name)
    # No directory or "." means the current working directory.
    directory = getcwd() if dir_part in ("", ".") else dir_part
    # Keep the trailing separator (as write_data does) so readers can append
    # file names directly to the directory.
    directory = path.join(directory, "")
    name = ["{}.{:g}.{}".format(base_name, i + 1, save_type) for i in range(reps)]
    # Read from the data directory, not relative to the current one.
    total_tosave = read_data(path.join(directory, name[0]), save_type).read(restart=True)[0].skip_iters + 1
    with open(outfile, "w") as rgen:
        rgen.write("save_type {}\n".format(save_type))
        rgen.write("directory {}\n".format(directory))
        rgen.write("total_tosave {}\n".format(total_tosave))
        for i in name:
            rgen.write("{} \n".format(i))
    rmemptyfile(outfile)
def restartfile_read(r_file):
    """
    Read restart file.

    The file holds ``save_type``, ``directory`` and ``total_tosave`` lines
    plus one data-file name per remaining non-blank line. If the recorded
    directory does not exist (or is empty), the directory containing the
    restart file is used instead.

    Returns
    -------
    list of str, str, str, int
        Full paths of the data files, save type, directory (always ending in
        a separator) and total number of iterations to save.

    Raises
    ------
    ValueError
        If the ``save_type`` or ``total_tosave`` line is missing.
    """
    save_type = total = None
    directory = ""
    filename = []
    with open(r_file) as rf:
        for line in rf:
            line_data = line.split()
            if not line_data:
                continue
            key = line_data[0]
            if key == "save_type":
                save_type = line_data[1]
            elif key == 'directory':
                # write_data writes an empty value when the run name has no "/".
                directory = line_data[1] if len(line_data) > 1 else ""
            elif key == 'total_tosave':
                total = int(line_data[1])
            else:
                filename.append(key)
    if save_type is None or total is None:
        raise ValueError("Restart file {} is missing save_type or total_tosave".format(r_file))
    if not path.isdir(directory):
        rfile_loc = path.dirname(r_file)
        directory = rfile_loc if rfile_loc else "."
    # Ensure a trailing separator so file names can be appended directly.
    directory = path.join(directory, "")
    return [directory + file for file in filename], save_type, directory, total
def get_restartdata(restart_data, filenames):
    """
    Collect all restart data from all files.

    Parameters
    ----------
    restart_data: read_data
        Reader used to load each file via ``read(fname=..., restart=True)``.
    filenames: iterable of str
        Files named ``<name>.<stat>.<ext>``; data of each file is stored
        under the key ``int(<stat>) - 1``.

    Returns
    -------
    restart, restartflags, collectiters
        ``restart`` is the variables object of the last file read, with
        ``mag``/``mom``/``pos`` replaced by ``{stat: array}`` dicts and
        ``RandNoState`` merged from every file; ``restartflags`` are the last
        file's flags; ``collectiters`` maps stat to its last iteration.

    Raises
    ------
    ValueError
        If ``filenames`` is empty.
    """
    collectmag = {}
    collectmom = {}
    collectpos = {}
    collectiters = {}
    RandNoState = {}
    restart = restartflags = None
    for name in filenames:
        stat = int(name.rsplit(".", 2)[-2]) - 1
        restart, restartflags, last_iter = restart_data.read(fname=name, restart=True)
        RandNoState.update(restart.RandNoState)
        collectmag[stat] = array(restart.mag)
        collectmom[stat] = array(restart.mom)
        collectpos[stat] = array(restart.pos)
        collectiters[stat] = last_iter
    if restart is None:
        raise ValueError("No restart files given")
    restart.RandNoState = RandNoState
    restart.mag = collectmag
    restart.mom = collectmom
    restart.pos = collectpos
    return restart, restartflags, collectiters
def xyz_write(xyz_mol_data, timescale, directory, run, flg, mmag, boxsize):
    """
    Write an extended xyz file.

    ::

      [n, number of atoms]
      [timestep] [current time] [flags...]
      mol[0] [position x y z] [momentum x y z] [magnetisation x y z] [forces x y z]
      .
      .
      mol[n -1]...

    Parameters
    ----------
    xyz_mol_data: np.array
        array of shape (frames, molecules, columns) in the form
        [pos(xyz) mom(xyz) mag(xyz) force(xyz)] for each molecule.
        Not modified.
    timescale: np.array
        timescale list
    directory: string
        save directory
    run: str
        run number; a value ending in 'nofin' rewrites an existing file
    flg: class
        class of flags
        flags used - suscep, align, lastframe, pbc, mag_sw
    mmag: np.array
        Magnetic saturation * volume, one value per molecule; magnetisation
        columns are divided by it
    boxsize: float
        periodic boxsize
    """
    xyz_dshape = xyz_mol_data.shape
    filename = "{}{}Run{}{}{}.m.xyz".format(directory, "S_" if flg.suscep else "",
                                            str(run), "ALIGN" if flg.align else "",
                                            'LF' if flg.lastframe else "")
    no_mol = xyz_dshape[-2]
    if not path.isfile(filename) or run.endswith('nofin'):
        with open(filename, "wb") as xyz_w:
            # Put force and velocity flags if needed
            axis = boxsize if flg.pbc else False
            commentline = bytes(" momentum={} mag={} force={} boxsize={} \n".format(True, flg.mag_sw, True, axis), 'utf-8')
            # One format per data column (12 for the layout documented above).
            xyz_string = 'MNP {}'.format(xyz_dshape[-1] * ' % 15.7e')
            # Work on a copy so the caller's magnetisation is not rescaled in place.
            out = xyz_mol_data.copy()
            out[..., 6:9] = einsum("ijk, j -> ijk", out[..., 6:9], 1 / mmag)
            frames = [xyz_dshape[0] - 1] if flg.lastframe else range(xyz_dshape[0])
            for i in frames:
                xyz_w.write(bytes("{} \ni={} time={:e}[ps]".format(no_mol, i, timescale[i]), 'utf-8'))
                xyz_w.write(commentline)
                savetxt(xyz_w, out[i, :, :], fmt=xyz_string)
    else:
        c.Error(">M Complete XYZ file exists, skipping")
    rmemptyfile(filename)
def file_write(data_array, f, timescale, directory, run, flg, mmag, boxsize):
    """
    Write a file of energy changes or momenta changes, or an xyz file.

    Parameters
    ----------
    data_array: np.array
        2-D array with one row per entry of ``timescale['F']``; 15 columns
        for momenta ("M") or 4 for energy ("E"). For "X" it is passed on to
        ``xyz_write``.
    f: string
        Momenta "M", Energy "E" or xyz file "X"
    timescale: dict
        ``'F'`` holds the times for "E"/"M" files, ``'X'`` those for xyz files.
    directory: string
        save directory
    run: str
        run number; a value ending in 'nofin' rewrites an existing file
    flg: class
        flags used - suscep, lastframe (plus those used by ``xyz_write``)
    mmag, boxsize:
        Only used for "X"; passed to ``xyz_write``.
    """
    def _write(kind, titlestring, format_string):
        """Write iteration number, time and ``data_array`` to the conservation file."""
        filename = "{}{}Run{}Conservation_{}{}.txt".format(directory, "S_" if flg.suscep else "", run, kind,
                                                           'LF' if flg.lastframe else "")
        if not path.isfile(filename) or run.endswith('nofin'):
            itern = arange(0, timescale['F'].shape[0])
            it_ti = block([[itern], [timescale['F']]])
            data = block([it_ti.T, data_array])
            savetxt(filename, data[-1][None, :] if flg.lastframe else data, fmt=format_string, header=titlestring)
        else:
            c.Error(">M Complete {} file exists, skipping".format("Momenta" if kind == "M" else "Energy"))

    if f == "E":
        _write("E", "itern. time kinetic_e pot_e1 pot_e2 total_e",
               "%.0d" + " % 15.7e" * 5)
    if f == "M":
        _write("M", "itern. time total_m(x,y,z) mag_m(x,y,z) angular_m(x,y,z) CoM(x,y,z) CoM_vel(x,y,z)",
               "%.0d" + " % 15.7e" * 16)
    if f == 'X':
        xyz_write(data_array, timescale['X'], directory, run, flg, mmag, boxsize)
def txt_data(moldata_dict, no_mol):
    """
    Write output as textfile.

    Specify "txt"

    When the file does not exist yet, a "#"-prefixed header of run
    variables is written first. Every call then appends, for each iteration,
    one line per molecule:
    mol, iteration/time, neel, magnetisation (3), position (3), forces (3),
    momentum (3), followed by a blank line.

    Parameters
    ----------
    moldata_dict: dict
        dictionary of data to be written
    no_mol: int
        Number of molecules
    """
    name = moldata_dict['name']
    out = []
    if not path.isfile(name):
        var = moldata_dict['vars']
        str_g = no_mol * "{:12.5g} "
        str_e = no_mol * "{:12.5e} "
        str_f = no_mol * "{:12.5f} "
        out += [
            # Exactly no_mol rows are written per iteration below; txtread
            # relies on this value to reshape the data.
            "# {:8s} = {:12} []\n".format("no_mol", no_mol),
            "# {:8s} = {:12.5g} [s]\n".format('t0', var['t0']),
            "# {:8s} = {:12.5g} [s]\n".format('dt', var['dt']),
            "# {:8s} = {:12} []\n".format('nmax', var['nmax']),
            "# {:8s} = {:12} []\n".format('skip', var['skip']),
            "# {:8s} = {:12.5g} []\n".format('stats', var['stats']),
            "\n",
            "# {:8s} = {:12.5g} [K]\n".format('temp', var['temp']),
            "# {:8s} = {} [cm]\n".format('radius', str_g.format(*var['radius'])),
            "# {:8s} = {} [g/cm^3]\n".format('dens', str_g.format(*var['dens'])),
            "# {:8s} = {} [emu/cm^3]\n".format('ms', str_g.format(*var['ms'])),
            "# {:8s} = {:12.5g} [oe]\n".format('H_0', var['H_0']),
            "# {:8s} = {:12.5g} [kHz]\n".format('nu', var['nu']),
            "# {:8s} = {:12.5g} [g/cm/s]\n".format('eta', var['eta']),
            "# {:8s} = {} []\n".format('chi0', str_g.format(*var['chi0'])),
            "# {:8s} = {} [s]\n".format('tauB', str_g.format(*var['tauB'])),
            "\n",
            "# {:8s} = {} [cm^3]\n".format('volume', str_e.format(*var['vol'])),
            "# {:8s} = {} []\n".format('alpha', str_f.format(*var['alpha'])),
            "# {:8s} = {} [1/oe/s]\n".format('geff', str_f.format(*var['geff'])),
            "# {:8s} = {} [oe]\n".format('c2', str_f.format(*var['c2'])),
            "\n",
        ]
        # Neel relaxation variables are optional; stop at the first missing one.
        with suppress(KeyError):
            out.append("# {:8s} = {:12.5g} [s]\n".format('tauN', var['tauN']))
            out.append("# {:8s} = {:12.2e} [erg/cm^3]\n".format('keff', var['keff']))
            out.append("# {:8s} = {:12.2e} [Hz]\n".format('nu_0', var['nu_0']))
        out += [
            "\n",
            # Only the first element of theta/phi fits the single format field.
            "# {:8s} = {:12.5f} []\n".format('theta', *var['theta']),
            "# {:8s} = {:12.5f} []\n".format('phi', *var['phi']),
            "# {:8s} = {:12.5f} []\n".format('boxsize', var['boxsize']),
            "\n",
            "# mol, iteration, time, neel, magnetisation (3 comp.), position, forces, momentum\n",
        ]
    for a, j in enumerate(array(moldata_dict['iter_time'])):
        iter_str = str(j)[1:-1]
        for i in range(no_mol):
            mol = "mol_{}".format(i)
            k = int(moldata_dict['neel_relaxation'][mol][a]) if 'tauN' in moldata_dict['vars'] else 0
            L = array_str(moldata_dict['magnetisation'][a, i])[1:-1]
            n = array_str(moldata_dict['position'][a, i])[1:-1]
            o = array_str(moldata_dict['forces'][a, i])[1:-1]
            p = array_str(moldata_dict['momentum'][a, i])[1:-1]
            out.append("{} {} {} {} {} {} {}\n".format(mol, iter_str, k, L, n, o, p))
        out.append('\n')
    with open(name, "a") as w:
        w.write("".join(out))
    rmemptyfile(name)