from numpy import (arange, array, mean, histogram,
                   sqrt, einsum, exp, float64, ones_like,
                   var, savez_compressed)
from scipy.optimize import curve_fit
from scipy.integrate import simps
from math import gamma
from os import listdir
# from numexpr import evaluate
# from sys import platform as _platform
# from mayavi_xyz import setup as mayavi
# from multiprocessing import cpu_count
# import subprocess
from mango.constants import c
from mango.io import read_data, file_write  # notfinished
import mango.imports as imports
from mango.pp.util import (Dataarrays, getvars, onevent, showgraph, loader)
from mango.pp.acf import get_suscep, save_suscep
from mango.pp.com import CoM_motion_rm
from mango.pp.plot import plot_col, plotter, lengthcheckplot
from mango.debug import debug
@debug(['io'])
def readin(fname, run, directory, block, flg, fig_event=None):
    """
    Manipulate Data.

    This function reads in data from file and then produces visualisation or files as required.
    Possible options from input flags:

    * Extended XYZ file (Positions, Magnetisations, Forces)
    * Susceptibility (Data needs to be collected specifically for the purpose)
    * Data comparison plots, plot set(s) of data against time etc.
    * Kinetic Temperature
    * Chain length check

    Parameters
    ----------
    fname: string
        filename (no extension)
    run: int
        run number
    directory: string
        file directory (concatenated directly with the data file name)
    block: iterable
        passed whole to ``get_suscep``; each entry is also used as the block
        argument of the chain length check
    flg: class
        Required flags: save_type, ufile, eq, prog, column, suscep, files,
        kinetic, lengthcheck, align, showg, saveg
    fig_event: optional
        figure event handler; replaced by ``onevent(mpl)`` when susceptibility
        or column plots are to be shown
    """
    if flg.save_type == "txt":
        c.Error("F Postprocessing is not implemented for this filetype and process")
    mpl = imports.matplotlib()
    if not flg.ufile:
        prefix = fname[len(directory):-2]
        names = [name for name in listdir(directory)
                 if name.startswith(prefix) and not name.endswith(("svg", "dat"))]
        if not names:
            # Report the missing input explicitly instead of failing on names[0].
            c.Error("F No data file starting with '{}' found in {}".format(prefix, directory))
        reader = read_data("{}{}".format(directory, names[0]), flg.save_type)
    else:
        reader = read_data(*flg.ufile[0].rsplit(".", 1), xyzpp=flg.ufile[0].endswith('xyz'))
    name = reader.fname
    var, incomplete_run, tauN_V = get_ppvars(flg, reader, extras=True)
    (timescale, datalen, xyz_data,
     energy_data, momenta_data, sus_data, stats) = collectdata(reader, name, flg, var, flg.eq)
    if (flg.suscep or flg.column) and flg.showg:
        fig_event = onevent(mpl)
    # create datafiles
    # Explicit lookup table for the arrays a ``flg.files`` key can refer to
    # (replaces the fragile ``locals()[f'{key}_data']`` lookup).
    datasets = {'xyz': xyz_data, 'energy': energy_data,
                'momenta': momenta_data, 'sus': sus_data}
    suffix = "nofin" if incomplete_run else ""
    for key in flg.files:
        if not flg.files[key]:
            continue
        data = datasets[key]
        kind = key[0].upper()
        # paralisable?
        for no, stat in enumerate(stats):
            file_write(data[no, ...], kind, timescale, directory,
                       "{}.{}{}".format(run, stat, suffix), flg,
                       mmag=var.ms * var.vol, boxsize=var.boxsize)
    # plot raw data graphs
    if flg.column:
        location = {"xyz": {"position": xyz_data[..., 0:3],
                            "momentum": xyz_data[..., 3:6],
                            "magnetisation": xyz_data[..., 6:9],
                            "forces": xyz_data[..., 9:12]} if flg.files['xyz'] else None,
                    "energy": {"Etotal": energy_data[..., 3],
                               "mag_pot": energy_data[..., 2],
                               "trans_pot": energy_data[..., 1],
                               "kinetic": energy_data[..., 0]} if flg.files['energy'] else None,
                    "momenta": {"Mtotal": momenta_data[..., 0:3],
                                "total_mag": momenta_data[..., 3:6],
                                "total_angular": momenta_data[..., 6:9],
                                "CoM": momenta_data[..., 9:12],
                                "CoM_vel": momenta_data[..., 12:15]} if flg.files['momenta'] else None,
                    "time": timescale}
        plot_col(mpl, fig_event, flg, location, directory, run, stats, var.temp, var.no_mol)
    # compute the atc functions (only when suscep holds a selection, not a plain bool)
    if flg.suscep and not isinstance(flg.suscep, bool):
        atc_data = get_suscep(flg, var.skip, var.dt, sus_data, block,
                              stats, var.mass, var.vol, var.ms)
        for bl in atc_data.keys():
            for entry in atc_data[bl].keys():
                save_suscep(flg, directory, run, var.no_mol, bl,
                            **atc_data[bl][entry])
    # kinetic temperature from the momentum columns (3:6) of the xyz data
    if flg.kinetic:
        baseline, histo, err = kinetic_temperature(flg.align, name, var.mass, xyz_data[..., 3:6])
        xy = [[baseline, 'units'], [histo, 'units', err]]
        xyname = [array([['x'], ['']]), array([['y'], ['']])]
        plotter(mpl, fig_event, flg, xy, xyname, 2, directory, run)
    # Distance between pairs; reuse a saved npz when the loader finds one
    if flg.lengthcheck:
        ld = loader(flg, run, directory)
        for bl in block:
            try:
                out = ld.lengthcheck(bl)
                c.Error("M Loaded length data from npz file")
            except FileNotFoundError:
                out = lengthcheck(flg, run, directory, xyz_data[..., 0:3], bl)
            print(out['overallmean'], out['overallerr'])
            lengthcheckplot(mpl, fig_event, flg, directory, var.no_mol, run, timescale['X'], var.radius, **out)
    showgraph(mpl, flg.showg)
def get_ppflags(reader):
    """
    Get post processing flags if stored in save file.

    Parameters
    ----------
    reader: instance
        file reader; the result of ``reader.read("flags")`` is handed to ``getvars``

    Returns
    -------
    Whatever ``getvars`` builds for the flag names
    files, kinetic, lengthcheck, save_type, ufile, suscep and prog.
    """
    wanted = ['files', 'kinetic', 'lengthcheck', 'save_type',
              'ufile', 'suscep', 'prog']
    stored_flags = reader.read("flags")
    return getvars(stored_flags, wanted)
def get_ppvars(flg, reader, extras=False):
    """
    Get post processing variables.

    Reads the stored simulation variables, records on ``flg.neel`` whether a
    'tauN' entry is present, and reconciles the expected number of written
    steps with the number actually written.

    Parameters
    ----------
    flg: instance
        flag storage; ``flg.neel`` is set (True/False) by this function
    reader: instance
        file reader providing ``read("vars")``
    extras: bool
        also return the incomplete-run flag and the Neel variables

    Returns
    -------
    var, or (var, incomplete_run, tauN_V) when ``extras`` is True.
    ``tauN_V`` is the empty string when 'tauN' is not among the stored variables.
    """
    variables = reader.read("vars")
    wanted = ['boxsize', 'no_molecules', 'stats', 'dt',
              'nmax', 'vol', 'radius',
              'ms', 'temp', 'mass', 'skip', 'written',
              'skip_iters', 'epsilon', 'sigma', 'limit']
    var = getvars(variables, wanted)

    has_neel = 'tauN' in variables
    tauN_V = getvars(variables, ['tauN_0', 'keff', 'temp0']) if has_neel else ''
    flg.neel = has_neel

    # NOTE(review): presumably the stored skip_iters is one less than the
    # expected number of writes -- confirm against the writer.
    var.skip_iters += 1
    incomplete_run = bool(var.written < var.skip_iters)
    if incomplete_run:
        c.Error(">W Incomplete run only {}/{} steps completed".format(var.written, var.skip_iters))
    if var.written != var.skip_iters:
        # Trust the number of steps actually written and rescale nmax to match.
        var.skip_iters = var.written
        var.nmax = var.skip_iters * var.skip
    return (var, incomplete_run, tauN_V) if extras else var
def collectdata(reader, name, flg, var, eq=0.1):
    """
    Collect data from files as required.

    Parameters
    ----------
    reader: instance
        file reader
    name: str
        file name
    flg: instance
        flg storage instance
        requires - files kinetic lengthcheck save_type ufile suscep align prog
    var: instance
        requires - dt, no_mol, stats, skip, skip_iters, epsilon, sigma, limit, mass
    eq: float
        equilibration fraction

    Returns
    -------
    timescale
        dict: 'F' holds the sample times of every written step, 'X' the
        same times with the first ``int(skip_iters * eq)`` entries dropped
    datalen
    xyz_data
    energy_data
    momenta_data
    sus_data
    stats
    """
    da = Dataarrays(flg, reader, var, name, eq)
    steps = da.var.skip_iters

    def sample_times():
        # A fresh array per call so 'X' and 'F' never share memory.
        return arange(0, steps) * var.dt * var.skip

    timescale = {'X': sample_times()[int(steps * eq):],
                 'F': sample_times()}

    for idx, (stat, fname) in enumerate(da.names.items()):
        if flg.prog:
            c.progress.sections = da.datalen
            # NOTE(review): this parses as ``stat + (1 / da.datalen)``; if a
            # fraction or a 1-based count was meant it should read
            # ``(stat + 1) / da.datalen`` or ``stat + 1`` -- confirm against
            # c.progress.bar.
            c.progress.bar("Reading In Data", stat + 1 / da.datalen)
        else:
            print(fname + " ...", end='', flush=True)
        for load in (da.xyz, da.energy, da.momenta, da.sus):
            load(idx, fname)
        if not flg.prog:
            print("Done")
    print("\n")

    com = CoM_motion_rm(var.mass, var.no_mol)
    if flg.align:
        com.align(da.xyz_data, da.sus_data)
    elif not isinstance(flg.suscep, bool) and any(
            key in flg.suscep for key in ('angular', 'inertia')):
        da.sus_data.angular, da.sus_data.inertia = com.inertia_angular(da.xyz_data, da.sus_data)
    return (timescale, da.datalen, da.xyz_data, da.energy_data,
            da.momenta_data, da.sus_data, da.var.stats)
def kinetic_temperature(align, basename, mass, mom, equi=0):
    """
    Kinetic Temperature.

    Build a normalised histogram of mass-weighted speeds, fit a
    multidimensional Maxwell-Boltzmann distribution to it and print the fitted
    temperature and dimension. Bin centres, histogram, Poisson errors and the
    fitted curve are written to ``basename + "_HISTO.dat"``
    (``"_HISTO_ALIGNED.dat"`` when ``align`` is truthy).

    Parameters
    ----------
    align: bool
        selects the output file suffix
    basename: str
        output file name prefix
    mass: array
        per-particle masses; momenta are divided by sqrt(mass) along axis 2 of ``mom``
    mom: array
        4-D momentum array; its last axis is reduced to a magnitude
    equi: int
        number of leading entries along axis 0 of ``mom`` to discard

    Returns
    -------
    baseline: histogram bin centres
    histo: histogram normalised to unit (Simpson) area
    err: Poisson errors with the same normalisation
    """
    # Divide each particle's momentum by sqrt(mass) (mass indexes axis 2)
    scaled = einsum("ijk...,k->ijk...", mom[equi:], 1 / sqrt(mass))
    # Magnitude over the last axis
    speed = sqrt(einsum("ijkl,ijkl->ijk", scaled, scaled))

    # histogram flattens its input; 'fd' = Freedman-Diaconis bin width
    counts, edges = histogram(speed, bins='fd', density=False)
    centres = 0.5 * (edges[1:] + edges[:-1])
    counts = counts.astype(float64)
    err = sqrt(counts)  # Poisson estimate
    area = simps(counts, centres)
    counts /= area
    err /= area

    def maxwell_boltzmann(x, amb, dim):
        # Multidimensional Maxwell-Boltzmann distribution with scale amb:
        # v^(n-1)*exp(-v^2/2)/2^(n/2-1)/gamma(n/2)
        return x**(dim - 1) * exp(-0.5 * (x / amb)**2) / (amb**dim) / (2**(0.5 * dim - 1)) / gamma(0.5 * dim)

    dim_guess = 3.
    # Initial scale from the distribution mean: mean = amb*sqrt(2)*G((n+1)/2)/G(n/2)
    amb_guess = (gamma(0.5 * dim_guess) / gamma(0.5 * (dim_guess + 1.)) / sqrt(2.)) * mean(speed)  # amb = sqrt(kB T)
    fitted, _ = curve_fit(maxwell_boltzmann, centres, counts, p0=[amb_guess, dim_guess])
    amb, dim = fitted[0], fitted[1]

    temp_mb = amb**2 / c.KB
    print("# Kinetic temperature (histogram): %12.5f [K]\n" % (temp_mb))
    print("# Kinetic dimension (histogram): %12.5f []\n" % (dim))

    outname = basename + ("_HISTO_ALIGNED.dat" if align else "_HISTO.dat")
    rows = ("  %12.5e  %12.5e  %12.5e  %12.5e\n" % (b, h, e, maxwell_boltzmann(b, amb, dim))
            for b, h, e in zip(centres, counts, err))
    with open(outname, 'w') as handle:
        handle.writelines(rows)
    return centres, counts, err
def lengthcheck(flg, run, directory, pos, block):
    """
    Length of chain check.

    Compute the full matrix of pairwise particle distances, average it over
    axis 0 of ``pos`` and save the results to
    ``directory + "[S_]Run<run>_distance_arrays_<block>[_ALIGNED].npz"``.
    Reading the result as a chain length assumes the particles furthest away
    from each other are particle 0 and N.

    Improvements:
    generalise for systems with multiple chains
    generalise for systems without chains eg furthest distance between particles in a cluster

    Parameters
    ----------
    flg: instance
        requires - suscep, align (only used to build the output file name)
    run: int
        run number
    directory: str
        output directory; concatenated directly with the file name, so it
        should end with a path separator
    pos: array
        4-D positions indexed as (k, l, particle, coordinate). Axis 0 is
        averaged over; axis 1 is additionally averaged for the ``overall*`` values
    block: int
        divisor applied to the sample variance over axis 0 (also part of the file name)

    Returns
    -------
    dict
        dmeandm: mean distance over axis 0, shape (l, N, N)
        errdm: sqrt(var(ddof=1, axis 0) / block), shape (l, N, N)
        overallmean: dmeandm averaged over axis 1, shape (N, N)
        overallerr: sqrt of the variance averaged over axis 1, shape (N, N)
        With a single entry along axis 0 the ddof=1 variance is NaN.
    """
    # Pairwise displacement vectors dist[k, l, i, j] = pos[k, l, i] - pos[k, l, j].
    # Broadcasting builds a single (k, l, N, N, 3) array instead of a
    # ones_like(pos) helper plus two full-size einsum intermediates.
    dist = pos[..., :, None, :] - pos[..., None, :, :]
    # If attempted manually (dx**2+dy**2+dz**2) the squared distance sometimes
    # differs by 8*EPS2, so the sum over coordinates is done with einsum.
    dm = sqrt(einsum('klijz,klijz->klij', dist, dist))
    dmeandm = mean(dm, axis=0)
    vardm = var(dm, ddof=1, axis=0) / block
    errdm = sqrt(vardm)
    overallmean = mean(dmeandm, axis=0)
    overallerr = sqrt(mean(vardm, axis=0))
    run_tag = "{}Run{}".format('S_' if flg.suscep else '', run)
    name = "{}{}_distance_arrays_{}{}".format(directory, run_tag, block, "_ALIGNED" if flg.align else '')
    results = {"dmeandm": dmeandm, "errdm": errdm, "overallmean": overallmean, "overallerr": overallerr}
    savez_compressed(name, **results)
    return results
if __name__ == "__main__":
    # Nothing is executed when this module is run directly.
    ...