from numpy import (arange, array, conjugate, where, block,
                   mean, amax, zeros_like, zeros, var, sum as nsum, linalg,
                   sqrt, savetxt, einsum, exp, log, log2, diag)
from scipy.optimize import fminbound, curve_fit
from scipy.interpolate import interp1d
from math import pi
from os import path
from contextlib import suppress
from mango.constants import c, nestedDict, asciistring
from mango.pp.util import at_least2d_end # debye,
import mango.imports as imports
from mango.debug import debug
# from numpy import seterr
# seterr(all='raise')
def get_suscep(flg,
               skip, dt, sus_data, blocks,
               stats, mass, vol, ms):
    """
    Calculate the ACF of data for different properties.
    Parameters
    ----------
    Returns
    -------
    atc_data: dict
        A dictionary of all the calculated ACF
    """
    autocorr_labels = {"mag": {"x": r'$\omega$ [$10^{{{}}}$]',
                               "y": r"MDOS [$10^{{{}}}$]",
                               "type": "MDOS"},
                       "vel": {"x": r'$\omega$ [$10^{{{}}}$]',
                               "y": r"VDOS [$10^{{{}}}$]",
                               "type": "VDOS"},
                       "angular": {"x": r'$\omega$ [$10^{{{}}}$]',
                                   "y": r"ADOS [$10^{{{}}}$]",
                                   "type": "ADOS"},
                       "inertia": {"x": r'$\omega$ [$10^{{{}}}$]',
                                   "y": r"IDOS [$10^{{{}}}$]",
                                   "type": "IDOS"},
                       "rotation": {"x": r'$\omega$ [$10^{{{}}}$]',
                                    "y": r"RDOS [$10^{{{}}}$]",
                                    "type": "RDOS"},
                       "mag_ifft": {"x": r"Time [$10^{{{}}}$s]",
                                    "y": r"MACF [$10^{{{}}}$]",
                                    "type": "MACF"},
                       "vel_ifft": {"x": r"Time [$10^{{{}}}$s]",
                                    "y": r"VACF [$10^{{{}}}$]",
                                    "type": "VACF"},
                       "angular_ifft": {"x": r"Time [$10^{{{}}}$s]",
                                        "y": r"AACF [$10^{{{}}}$]",
                                        "type": "AACF"},
                       "inertia_ifft": {"x": r"Time [$10^{{{}}}$s]",
                                        "y": r"IACF [$10^{{{}}}$]",
                                        "type": "IACF"},
                       "rotation_ifft": {"x": r"Time [$10^{{{}}}$s]",
                                         "y": r"RACF [$10^{{{}}}$]",
                                         "type": "RACF"}}
    if 'vel' in flg.suscep:
        sus_data["vel"] = einsum("hijk, j -> hijk", sus_data["mom"], 1 / mass)
    atc_data = nestedDict()
    autocorr = Autocorr(flg, sus_data, dt, skip, ms * vol)
    autocorr.conv(len(stats))
    for bl in blocks:
        autocorr.analysis(blocks=bl)
        for i in flg.suscep:
            atc_data[bl][i] = {"xaxis": autocorr.freq, "yaxis": autocorr.ret[i]['data'], "stddev_y": autocorr.ret[i]['error'],
                               # "popt": tau, "popt_fitted": ad_tau.popt_fit,
                               "autocorr_labels": autocorr_labels[i], 'scale': autocorr.ret[i]['scale']}
            iatc = i + '_ifft'
            atc_data[bl][iatc] = {"xaxis": autocorr.time, "yaxis": autocorr.ret[i]['ifft'], "stddev_y": autocorr.ret[i]['ifft_error'],
                                  # "popt": tau, "popt_fitted": ad_tau.popt_fit,
                                  "autocorr_labels": autocorr_labels[iatc], 'scale': autocorr.ret[i]['scale']}
    return atc_data 
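# A minimal usage sketch, commented out so the module stays importable. The
# objects ``flg``, ``sus_data``, ``stats``, ``mass``, ``vol`` and ``ms`` are
# hypothetical placeholders for whatever the calling post-processing code
# provides; the sketch only shows how the returned nested dictionary is keyed
# by block count, property and the ``*_ifft`` ACF counterpart.
#
#   atc = get_suscep(flg, skip=10, dt=1e-15, sus_data=sus_data,
#                    blocks=[1, 2, 4], stats=stats,
#                    mass=mass, vol=vol, ms=ms)
#   freq = atc[2]['mag']['xaxis']        # angular frequencies for 2 blocks
#   mdos = atc[2]['mag']['yaxis']        # magnetic DOS
#   macf = atc[2]['mag_ifft']['yaxis']   # corresponding ACF in time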
class Autocorr():
    """Calculates the ACF and MDOS of provided data."""
    @debug(['acf'])
    def __init__(self, flg, sus_data, dt, skip, mmag):
        """
        Store data to be processed.
        Parameters
        ----------
        flg: instance
            flags class instance
        sus_data: dict
            data dictionary
        dt: float
            timestep
        skip: int
            number of iterations skipped
        mmag: float
            Mmag
        """
        self.flg = flg
        self.dt = dt
        self.skip = skip
        self.mmag = mmag
        self.fft, self.ffts, self.ifft, self.byte_align = imports.fftw()
        self.shape = sus_data[list(sus_data.keys())[0]].shape
        self.ret = {}
        self.sus_collect = {}
        self._shape_data(sus_data)
        if hasattr(self, 'DBG'):
            self.sus_data = sus_data.copy()
    def _calc(self, func, *args):
        """Apply ``func`` to each stored autocorrelation type."""
        # Parallelisable loop: c.processors / number of keys
        for autocorr_type in self.sus_collect.keys():
            func(autocorr_type, *args)
        if hasattr(self, 'DBG'):
            from mango.pp.acf_checker import checker, autoold
            ret2, freq2, time2 = autoold(c.Error, self.sus_data, self.dt, self.skip, self.mmag, blocks=self.blocks)
            checker(self.ret, ret2, self.freq, freq2, self.time, time2)
    def conv(self, no_stats, RMS=True):
        """
        Carry out blocking convergence tests.

        Parameters
        ----------
        no_stats: int
            number of statistics entries; the smallest block sample size
        RMS: bool
            use the root mean square rather than the mean
        """
        self._calc(self._blockingconvergence, no_stats, RMS)
    def analysis(self, blocks):
        """
        Carry out DOS and ACF calculation for provided data and blocks.

        Parameters
        ----------
        blocks: int
            number of blocks
        """
        self.sample = self.shape[0] * self.shape[1] // blocks
        self.sample_pad = self.ffts.next_fast_len(2 * self.sample - 1)
        self.freq = 2.0 * pi * self.ffts.fftshift(self.ffts.fftfreq(self.sample_pad, d=(self.dt * self.skip)))
        self.norm_val = self.freq[1] - self.freq[0]
        # arange(-self.sample + 1, self.sample) * self.dt * self.skip
        self.time = 2. * pi * self.ffts.fftshift(self.ffts.fftfreq(self.sample_pad, d=self.norm_val))
        self.blocks = blocks
        self._calc(self._blockinganalysis) 
    def _shape_data(self, sus_data):
        """
        Shape arrays dependent on type.
        Parameters
        ----------
        sus_data: dict
            dict of data arrays
        """
        reshape = []
        for autocorr_type in self.flg.suscep:
            # if autocorr_type in ['chi']:
            #     continue
            # else:
            value = sus_data[autocorr_type]
            if autocorr_type in ['mag']:
                # Sum over the number of atoms
                # Flatten and refold -> Blocking will only trim from
                # the first trajectory for correct dimensions
                # data =
                # stddev = self._get_av_err(data, av=False) # - (stddev if self.flg.align else 0)
                #  - (mean(einsum("il, il ->il", data[:, 0], data[:, 0]), axis=1) if self.flg.align else 0
                data = self._remove_av(einsum('ijkl->ijl', value)).reshape(-1, 3)
            elif autocorr_type in ["rotation"]:
                if self.flg.align:
                    data = self._remove_av(einsum('ijkl, il -> ijk', value,
                                                  mean(einsum("ijkz ->ijz", sus_data["mag"]), axis=1))).reshape(-1, 3)
                else:
                    c.Error("W rotation autocorrelation only works with align=True.")
                    # Without align there is no rotation data to analyse, so skip
                    # this type rather than fall through with ``data`` undefined.
                    continue
            else:
                data = self._remove_av(value).reshape([-1, *value.shape[2:]])
            if autocorr_type in ['inertia', "vel"]:
                reshape = [-1, 3]
            else:
                reshape = [3]
            print(autocorr_type, data.shape)
            e_string = "{0}, {0} -> ".format(''.join(asciistring(data.ndim)))
            self.sus_collect[autocorr_type] = {'data': data, 'scale': 2 * pi * einsum(e_string, data, data) / data.size,
                                               'reshapeto': reshape}
    @staticmethod
    def _remove_av(value):
        """
        Remove the average over the iterations before any modifcation.
        Trajectory may not be fully relaxed
        Removing the mean after blocking leads to spurious errors
        """
        return value - mean(value, axis=1)[:, None]
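    # A small worked illustration (assumed shapes, not part of the pipeline):
    # for data of shape (trajectories, iterations, components) the per-trajectory
    # mean over the iteration axis is broadcast back before subtraction.
    #
    #   x = zeros((2, 5, 3))           # 2 trajectories, 5 iterations, 3 components
    #   x[0, :, 0] = arange(5)         # drifting component
    #   y = Autocorr._remove_av(x)     # y[0, :, 0] == arange(5) - 2.0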
    def _fft_av_err(self, dummy, scale, blocks):
        """Scale fft errors."""
        ave, err = self._get_av_err(dummy, blocks)
        print(ave[0])
        return scale * self.ffts.fftshift(ave), scale * self.ffts.fftshift(err)
    @staticmethod
    def _get_av_err(dummy, blocks=1, axis=0, av=True):
        """
        Average data over blocks and calculate errors.
        Parameters
        ----------
        dummy: array
            data array
        blocks: int
            number of blocks
        Returns
        -------
        ave: array
            averaged data array
        err: array
            standard deviation of averaged data
        """
        if dummy.shape[0] == 1:
            err = zeros(dummy.shape[1:])
        else:
            err = sqrt(var(dummy, ddof=1, axis=axis) / blocks)
        return (mean(dummy, axis=axis), err) if av else err
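    # Hedged numerical example of the error estimate: the returned ``err`` is the
    # standard error of the block means, sqrt(var / blocks) with the unbiased
    # (ddof=1) variance; ``blocks_data`` below is a made-up set of block averages.
    #
    #   blocks_data = array([[1.0], [3.0], [2.0], [4.0]])
    #   ave, err = Autocorr._get_av_err(blocks_data, blocks=4)
    #   # ave -> [2.5], err -> sqrt(var([1, 3, 2, 4], ddof=1) / 4) ~ [0.645]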
    def _blockingconvergence(self, autocorr_type, no_stats, RMS=True):
        """
        Convergence test for blocking.
        Parameters
        ----------
        autocorr_type: str
            autocorr name
        RMS: bool
            Root mean squared or mean
        """
        print('# Block averages {}'.format('magnetic moment [emu]' if autocorr_type == 'mag' else autocorr_type))
        print('# Block, sample, average, error')
        data = self.sus_collect[autocorr_type]['data']
        if RMS:
            def RMSwrap(data, **kw):
                e_string = "{0}, {0} -> ab".format(''.join(asciistring(data.ndim)))
                return sqrt(mean(einsum(e_string, data, data), **kw))
        else:
            def RMSwrap(data, **kw):
                e_string = "{} -> ab".format(''.join(asciistring(data.ndim)))
                return mean(einsum(e_string, data), **kw)
        for local_sample in (no_stats * 2**x for x in range(int(log2(self.shape[0] * self.shape[1]) - log2(self.shape[0])))):
            no_blocks = self.shape[0] * self.shape[1] // local_sample
            blocked = data[:local_sample * no_blocks].reshape([no_blocks, local_sample, *self.sus_collect[autocorr_type]['reshapeto']])
            reblock = RMSwrap(blocked, axis=1)
            ave, err = self._get_av_err(reblock, no_blocks)
            print(local_sample, no_blocks, ave, err)  # if local_sample > 1 else 0)
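    # Hedged reading of the convergence output (numbers are hypothetical): the
    # block sample size doubles while the number of blocks halves, e.g. with
    # 8 trajectories of 1024 iterations and no_stats=8 the loop prints
    #
    #   8     1024  <ave>  <err>
    #   16    512   <ave>  <err>
    #   32    256   <ave>  <err>
    #   ...
    #
    # and the blocking is taken as converged once the error estimate plateaus.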
    @staticmethod
    def _get_estring(ndim):
        """
        Return the 3 character strings for einsum of DOS.
        eg 3 dimensions would return:
        dos = 'abc'
        cdos = 'abd'
        fstr = 'abcd'
        Parameters
        ----------
        ndim: int
            number of dimensions of array
        Returns
        -------
        dos_str: str
            string of char length ndim
        cdos_str: str
            string of char length ndim
        fstr: str
            string of char length of ndim + 1
        """
        chr_list = asciistring(ndim + 1)
        dos_str = ''.join(chr_list[:-1])
        cdos_str = ''.join(chr_list[:-2] + [chr_list[-1]])
        return dos_str, cdos_str, ''.join(chr_list)
    @staticmethod
    def _pmsum(data):
        # TODO per-particle correlation
        # if shape[1] == no_mol and shape[0] == iter (?)
        # if data.ndim == 3:
        #     return einsum("ijk -> ik", data)
        return data
    def _blockinganalysis(self, autocorr_type):
        """
        DOS and ACF calculation.
        creates storage dictionary 'ret'
        Parameters
        ----------
        autocorr_type: str
            autocorr name
        """
        scale = self.sus_collect[autocorr_type]['scale']
        reshapeto = [self.blocks, self.sample, *self.sus_collect[autocorr_type]['reshapeto']]
        # Scale
        # if autocorr_type is 'mag':  # 2*pi*<M*M>/3
        #     # scale = 2. * pi * einsum("i, i -> ", self.mmag, self.mmag) / (3. * self.mmag.size)
        #     scale = einsum("ijk, ijk -> ", mxyz, mxyz) / mxyz.size  # This is equal to <M*M>/3 for independent particles
        # else:
        # 2*pi*<v*v>/3, 2*pi*<M*M>/3
        # Blocking
        print(autocorr_type, self.sus_collect[autocorr_type]['data'].shape)
        mxyz = self.sus_collect[autocorr_type]['data'][:self.sample * self.blocks].reshape(reshapeto)
        self.ret[autocorr_type] = {'scale': scale}
        print('DOS')
        # DOS
        dos = self.fft(self.byte_align(mxyz), n=self.sample_pad, axis=1, threads=c.processors)
        print("DOne")
        # modulus of the data gives separate real and imaginary parts
        dos = self._pmsum(dos)
        dos_str, cdos_str, f_str = self._get_estring(dos.ndim)
        mod_dos = einsum("{}, {} -> {}".format(dos_str, cdos_str, f_str), dos, conjugate(dos))
        # Calculate eigenvalues
        eig_dos = linalg.eigvalsh(mod_dos)
        # Normalisation: Integral of DOS equal 1
        norm = self.norm_val * einsum('{} -> '.format(dos_str), eig_dos).real / self.blocks
        eig_dos /= norm
        print('IDOS')
        # Normalisation: ACF(0) = 1
        ifftx = 2 * pi * self.ifft(self.byte_align(eig_dos), axis=1, threads=c.processors) / (self.dt * self.skip)
        print("DOne")
        # Error and averaging
        (self.ret[autocorr_type]['data'],
         self.ret[autocorr_type]['error']) = self._fft_av_err(eig_dos, scale, self.blocks)
        (self.ret[autocorr_type]['ifft'],
         self.ret[autocorr_type]['ifft_error']) = self._fft_av_err(ifftx, scale, self.blocks) 
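    # Hedged consistency sketch (not called anywhere in the module): after
    # ``analysis`` the stored DOS should integrate to its ``scale`` factor, so a
    # rough check on a hypothetical instance ``acf`` with 'mag' enabled could be
    #
    #   acf.analysis(blocks=2)
    #   dos = acf.ret['mag']['data']
    #   print(acf.norm_val * dos.sum() / acf.ret['mag']['scale'])   # ~ 1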
@debug(['save'])
def save_suscep(flg, directory, run, no_mol, cblock, xaxis, yaxis, stddev_y,
                autocorr_labels, scale, line_d={}, **kwargs):
    """
    Save the susceptibility calculation.
    Saves autocorrelation data in raw text,
    The columns are shown below with yaxis and std_y will always be 3 columns:
    xaxis yaxis... std_y...
    Parameters
    ----------
    directory: str
        save location
    run: int
        run number
    no_mol: int
        number of particles
    cblock: int
        current block length
    xaxis: ndarray
        x data
    yaxis: ndarray
        y data
    stddev_y: ndarray
        y error data
    autocorr_labels: dict
        dictionary of x and y labels
    scale: float
        scale value for data
    """
    # if (path.isfile(figname) or path.isfile("{}{}.dat".format(figname, "ALIGN" if flg.align else ''))) and flg.saveg:
    #     c.Error('W File {} exists, skipping'.format(figname.rsplit('/', 1)[-1]))
    #     return None
    xaxis, yaxis, stddev_y = at_least2d_end(xaxis, yaxis, stddev_y)
    header = ''
    if ('DOS' in autocorr_labels["type"]) or ('ACF' in autocorr_labels["type"]):
        header = f'scale factor: {scale}'  # Remove 2 pi which comes from the Fourier transform definition
    data = block([xaxis, yaxis, stddev_y])
    print(data.shape, xaxis.shape, yaxis.shape, stddev_y.shape)
    name = "{}S_Run{}_mol-{:g}_autocorr_{}_blk{}".format(directory, run, no_mol, autocorr_labels["type"], cblock)
    savetxt("{}{}.dat".format(name, "ALIGN" if flg.align else ''), data, header=header)