"""
The Main engine.
Iteration loop module for updating properties of particles.
"""
from numpy import (zeros, zeros_like, ravel, ones, errstate,
                   exp, sqrt, array,
                   where, append, maximum, allclose,
                   add, subtract, multiply, einsum)
from numpy.core.multiarray import c_einsum
from math import pi
from scipy.special import comb
from scipy.optimize import minimize
from functools import partial
from mango.constants import c, bin_count
from mango.boundaries import periodic
from mango.imports import getpotentials
from mango.debug import debug
class Communal():
    """
    Reused variables and methods for propagation.

    Holds the pair-interaction constants and a set of preallocated work
    arrays, one entry per particle pair, shared by the potential classes.

    Parameters
    ----------
    epsilon, sigma:
        Interaction constants, stored unchanged.
    tri: sequence
        Pair indices; ``tri[0]`` and ``tri[1]`` are stored as ``ia`` and
        ``ja`` and the whole object as ``triag_ind``.
    comb: int or tuple
        Shape passed to ``zeros`` for every per-pair work array.
    limit:
        Cut-off value compared against ``rij`` by the subclasses.
    """

    def __init__(self, epsilon, sigma, tri, comb, limit):
        """Store the constants and allocate the per-pair work arrays."""
        self.epsilon = epsilon
        self.sigma = sigma
        self.triag_ind = tri
        self.ia = tri[0]
        self.ja = tri[1]
        self.limit = limit
        self.comb = comb
        self.rij = zeros(self.comb)
        # Each name receives its own independent zero array shaped like rij.
        self.set_attr(['lj', 'lj6', 'rij3', 'rij4', 'magi', 'magj', 'dmag'], zeros_like(self.rij))
        # Boolean masks with one entry per pair.
        self.lind = zeros_like(self.rij, dtype=bool)
        self.hind = zeros_like(self.rij, dtype=bool)  # self.hind[:] = ~lind

    def set_attr(self, var_list, val):
        """
        Initialise multiple arrays with the same shape.

        Parameters
        ----------
        var_list: iterable of str
            Attribute names to create on this instance.
        val: ndarray
            Template value; every attribute gets a separate ``val.copy()``
            so that later in-place writes do not alias each other.
        """
        for name in var_list:
            setattr(self, name, val.copy())
class countneel(object):
    """
    Counting and storing Neel relaxations.

    Initialise with the timestep (dt) and the simulation variables from
    which the Neel relaxation time (tauN) is derived, then call ``run``
    with each magnetisation step.
    """

    def __init__(self, var, dt):
        """
        Compute the Neel relaxation time and set up the flip counter.

        Parameters
        ----------
        var:
            Simulation variables. Reads ``keff``, ``vol``, ``temp``,
            ``temp0``, ``tauN_0`` and ``no_molecules``; ``var.tauN`` is
            written as a side effect.
        dt:
            Timestep used to precompute the comparison threshold ``A``.
        """
        # Turn every floating point warning into FloatingPointError so a
        # non-finite exponential can be replaced by an infinite tauN.
        with errstate(all='raise'):
            try:
                sigmaN = exp((var.keff * var.vol) / (c.KB * (var.temp - var.temp0)))
                var.tauN = var.tauN_0 * sigmaN
            except FloatingPointError:
                c.Error(">W OverflowError {} is infinite".format(c.ustr.taun))
                var.tauN = float("inf") * ones(var.no_molecules)
        self.count = zeros(var.no_molecules)
        self.tauN = var.tauN
        self.no_mol = var.no_molecules
        # Sets both self.dt and self.A.
        self._modifydt(dt)

    def _modifydt(self, dt):
        """Store ``dt`` and recompute the threshold ``A = exp(-dt / tauN)``."""
        self.A = exp(-dt / self.tauN)
        self.dt = dt

    def run(self, dt, mag):
        """
        Counter.

        Draws one random number per molecule, reverses the sign of the
        entries of ``mag`` whose draw exceeds ``A`` and increments their
        count. ``mag`` is modified in place and also kept as ``self.mag``.

        Parameters
        ----------
        dt:
            Timestep; ``A`` is recomputed only when it differs from the
            stored one.
        mag: ndarray
            Magnetisation array indexed by molecule along its first axis.
        """
        if dt != self.dt:
            self._modifydt(dt)
        self.mag = mag
        self.x = c.random.random(size=self.no_mol)
        flipped = where(self.x > self.A)
        self.mag[flipped] = -self.mag[flipped]
        self.count[flipped] += 1
class position(Communal):
    """
    Main calculation class.

    Calling propagate will calculate 1 iteration.
    """

    # Shared boundary handler; also applied as a decorator to position_step.
    boundary = periodic()

    @debug(['init'])
    def __init__(self, **argd):
        """
        Initialise propagation routine.

        Sets up potentials and variables required.
        Optimises the positional configuration of the particles if required.

        dictionary keys: var, flg, mom, mag, pos, op_noise

        Parameters
        ----------
        argd: dict
            dictionary of data; every key becomes an instance attribute.
            NOTE(review): scfloop also reads ``self.time.iter``, so a
            ``time`` entry appears to be expected as well -- confirm.
        """
        self.__dict__.update(argd)
        self.count = 0
        self.mnorm = c.EPS_val + 1
        self.halfdt = self.var.dt * 0.5
        # Hard ceiling for the SCF iteration budget, which scfloop doubles.
        self.maxscf = 1 + self.var.scfiter * 2**5
        self.Mmag = self.var.ms * self.var.vol
        self.inv_Mmag = 1 / self.Mmag
        self.inv_Mmag2 = 1 / c_einsum('i,i -> i', self.Mmag, self.Mmag)
        self.invmass = 1 / self.var.mass
        self.invnatom = 1 / self.var.no_molecules
        self.inv_Mmag2_atm = self.inv_Mmag2 * self.invnatom
        self.sqEPS = c.EPS_val ** 2
        self.abslimit = c.EPS_val * 1.2
        self.prerange = range(0, 6)
        self.scfrange = range(0, self.var.scfiter)
        self.new_norm = [0, 0]
        self.sqmnorm = zeros(1)
        self.stoc_constant = 6 * pi * self.var.eta * self.var.hradius * self.invmass  # gammat
        self.estrings()
        self.noise_setup = zeros((3, self.var.no_molecules, 3))
        # Preallocated work arrays, reused through ``out=`` every step.
        self.set_attr(['c1t', 'c2t'], zeros(self.var.no_molecules))
        self.set_attr(['tot', 'tmag', 'angular'], zeros(3))
        self.set_attr(['force', 'Hext', 'dm_mag_old', 'dm_umag',
                       'dm_diff', 'dm_hfield', 'dm_Heff',
                       'pos_tmp', 'noise_tmp'], zeros((self.var.no_molecules, 3)))
        self.tri = c.tri
        # Number of unique particle pairs.
        self.comb = comb(self.var.no_molecules, 2, exact=True)
        self.set_pot()
        self.init_pot()
        if self.flg.opt:
            self.dist, self.dist_matrix = self.boundary.setup(self.var.boxsize, self.flg.pbc, self.pos, self.var.sigma, self.tri)
            self.optimise(toll=c.EPS, iprint=2 if self.flg.nout >= 5 else 0)
        if self.flg.neel:
            self.neel_count = countneel(self.var, self.halfdt)
            # Wrap the magnetisation step so a Neel update runs before and after it.
            self.magnetisation_step = partial(self.neel, self.magnetisation_step)
        else:
            self.neel_count = None

    def _noise(self, noise_row):
        """Add row ``noise_row`` of ``noise_setup``, scaled by ``c2t``, to the momenta."""
        sqrt(c.KB * self.var.temp * self.var.mass * (1. - self.c1t**2), out=self.c2t)  # [A/fs]
        c_einsum("i, ij -> ij", self.c2t, self.noise_setup[noise_row, ...], out=self.noise_tmp)
        add(self.mom, self.noise_tmp, out=self.mom)

    def estrings(self):
        """Store the einsum subscript strings used by the update steps."""
        self.dm1 = "iz,i->iz"
        self.dm2 = "iz,iz->i"
        self.dm3 = 'i,xyz,ix,iy->iz'
        self.scf1 = "iz,iz,i->"
        self.p_s = "ij, i -> ij"

    @staticmethod
    def _noop(*args):
        """Do nothing; stands in for steps that are switched off."""
        pass

    def set_pot(self):
        """
        Set new potentials if required.

        When ``var.potential`` is set, the names exported by the loaded
        module replace the matching attributes on the ``energy`` and
        ``force`` classes and on this instance. The previous attribute is
        kept under the same name with a leading underscore.
        """
        if self.var.potential:
            new_pot = getpotentials(self.var.potential)
            functionlist = [f for f in dir(self) if not f.startswith('_')]
            for i in new_pot.__all__:
                # The hasattr guards stop a second call from overwriting the saved original.
                if ("vv" in i or i == "energy_vars") and not hasattr(energy, "_" + i):
                    setattr(energy, "_" + i, getattr(energy, i))
                    setattr(energy, i, getattr(new_pot, i))
                if ("ff" in i or "hh" in i or i == "force_vars") and not hasattr(force, "_" + i):
                    setattr(force, "_" + i, getattr(force, i))
                    setattr(force, i, getattr(new_pot, i))
                if i in functionlist:
                    setattr(self, "_" + i, getattr(self, i))
                    setattr(self, i, getattr(new_pot, i))

    def init_pot(self):
        """Initialise potentials."""
        com = Communal(self.var.epsilon, self.var.sigma, self.tri, self.comb, self.var.limit)
        self.e = energy(com)
        self.f = force(com, self.var.no_molecules)

    def initialconditions(self):
        """
        Set initial conditions of particles and boundary conditions.

        If MPI is used reinitialise potentials.
        """
        if c.comm.Get_size() > 1 and c.comm.Get_rank() != 0:
            self.set_pot()
            self.init_pot()
        if not hasattr(self.boundary, 'wrap'):
            self.boundary.setup(self.var.boxsize, self.flg.pbc, self.pos, self.var.sigma, self.tri)
        else:
            self.boundary.reset_wrap()
        # Disabled steps are replaced by no-ops so propagate needs no branches.
        if not self.flg.mag_sw:
            self.magnetisation_step = self._noop
        self.noise = self._noise if self.flg.noise else self._noop
        self.pos[:] = self.boundary.wrapping(self.pos)
        self.dist, self.dist_matrix = self.boundary.distance(self.pos)
        self.f.force_vars(self.dist, self.dist_matrix)
        self.force_step()
        self.prop = {"pos": self.pos, "mom": self.mom, "mag": self.mag,
                     "neel_count": self.neel_count, "forces": self.force}

    def neel(self, magstep, dt):
        """
        Wrap magnetisation if neel relaxation is used.

        Runs a half-step Neel update, then ``magstep(dt)``, then a second
        half-step Neel update.
        """
        self.neel_count.run(self.halfdt, self.mag)
        self.mag[:] = self.neel_count.mag
        magstep(dt)
        self.neel_count.run(self.halfdt, self.mag)
        self.mag[:] = self.neel_count.mag

    @debug(['prop'])
    def propagate(self, Hext, noise_setup):
        """
        Calculate propagation loop.

        Parameters
        ----------
        Hext:
            External field, stored and added to the dipolar field in deltamag.
        noise_setup:
            Noise array; rows 1 and 2 are used by the two stochastic half steps.
        """
        self.Hext = Hext
        self.noise_setup = noise_setup
        # half position_step
        self.position_step(self.halfdt)
        # half momentum_step
        self.f.force_vars(self.dist, self.dist_matrix)
        self.momentum_step(self.halfdt)
        # half stochastic_step
        self.stochastic_step(self.halfdt, 1)
        # full magnetisation_step
        self.magnetisation_step(self.var.dt)
        # half stochastic_step
        self.stochastic_step(self.halfdt, 2)
        # half momentum_step
        self.momentum_step(self.halfdt)
        # half position_step
        self.position_step(self.halfdt)

    @debug(['pos'])
    @boundary
    def position_step(self, dt):
        """Calculate the new position of the particle."""
        c_einsum(self.p_s, self.mom, self.invmass, out=self.pos_tmp)
        add(self.pos, self.pos_tmp * dt, out=self.pos)

    def force_step(self):
        """Calculate the forces between atoms."""
        add(self.f._ff_trans_return(), self.f.ff_mag(self.mag), out=self.force)

    @debug(["mom"])
    def momentum_step(self, dt):
        """Calculate the momentum of all the particles."""
        self.force_step()
        add(self.mom, self.force * dt, out=self.mom)

    def stochastic_step(self, dt, noise_row):
        """
        Apply noise to system.

        Damps the momenta by ``exp(-stoc_constant * dt)`` and then calls
        ``self.noise`` (either ``_noise`` or ``_noop``) with ``noise_row``.
        """
        exp(- self.stoc_constant * dt, out=self.c1t)  # dimensionless
        c_einsum(self.p_s, self.mom, self.c1t, out=self.mom)
        self.noise(noise_row)

    def optimise(self, toll=1.e-6, iprint=2):
        """
        Optimise particle position based on magnetisation.

        Minimises the total potential energy over positions and
        magnetisations with SLSQP, holding each magnetisation norm fixed
        through an equality constraint.

        Parameters
        ----------
        toll: float
            Passed to the minimiser as ``ftol``.
        iprint: int
            Verbosity passed to the minimiser.
        """
        def func(x):
            pos = x[:self.pos.size].reshape(self.pos.shape)
            mag = x[self.pos.size:self.pos.size + self.mag.size].reshape(self.mag.shape)
            d, dm = self.boundary.distance(pos)
            self.e.energy_vars(dm)
            self.e.vv_trans()
            # Possible bug with 1 np, vv_mag returns []
            return einsum('i ->', self.e.epot_tr + self.e.vv_mag(mag, d))

        def dfunc(x):
            pos = x[:self.pos.size].reshape(self.pos.shape)
            mag = x[self.pos.size:self.pos.size + self.mag.size].reshape(self.mag.shape)
            d, dm = self.boundary.distance(pos)
            # Note that the gradient with respect to the magnetisation
            # gives the magnetic field. This cannot go to zero because
            # of the constrained normalisation (see below).
            self.f.force_vars(d, dm)
            return -append(self.f._ff_trans_return(), self.f.hh_mag(mag))

        def normalisation(x):
            mag = x[self.pos.size:self.pos.size + self.mag.size].reshape(self.mag.shape)
            return c_einsum("iz,iz->i", mag, mag) - c_einsum("iz,iz->i", self.mag, self.mag)

        # Flat optimisation vector: all positions followed by all magnetisations.
        x0 = append(self.pos, self.mag)
        # x = fmin_slsqp(func, x0, fprime=dfunc, f_eqcons=normalisation, acc=toll, iter=2000, iprint=iprint)
        res = minimize(func, x0, jac=dfunc, method='slsqp',
                       constraints=dict(type='eq', fun=normalisation),
                       options=dict(maxiter=2000, iprint=iprint, ftol=toll))
        x = res.x
        self.pos[:] = x[:self.pos.size].reshape(self.pos.shape)  # unit: 1.e-6 cm
        if self.flg.op_thermal:
            c_einsum("i, ij -> ij", sqrt(self.var.mass * c.KB * self.var.temp), self.op_noise, out=self.mom)
        if all(self.Mmag > c.EPS_val):
            self.mag = x[self.pos.size:self.pos.size + self.mag.size].reshape(self.mag.shape)  # unit: 1.e-12 emu
            # Rescale every moment to length Mmag, guarding against a zero norm.
            scale = self.Mmag / maximum(sqrt(c_einsum("iz,iz->i", self.mag, self.mag)), c.EPS_val)
            self.mag = c_einsum("iz,i->iz", self.mag, scale)  # unit: 1.e-12 emu

    def deltamag(self, mag, dt, umag, hfield, _Heff):
        """
        Update magnetisation.

        All array arguments are work buffers written in place. On return
        ``hfield`` holds the magnetisation increment that the callers add
        to ``self.mag``.
        """
        add(mag, self.mag, out=mag)
        c_einsum(self.dm1, mag, 1 / sqrt(c_einsum(self.dm2, mag, mag)), out=umag)  # dimensionless
        c_einsum(self.dm1, umag, self.Mmag, out=mag)  # unit: 1.e-12 emu
        # WARNING: The new mag is used just to compute hfield!
        add(self.f.hh_mag(mag), self.Hext, out=hfield)   # unit: 1.e6 oersted
        # sLLG is remaining lines
        # WARNING: alpha is negative! Heff = hfield - _Heff
        c_einsum(self.dm3, self.var.alpha, c.eijk, umag, hfield, out=_Heff)  # unit: 1.e6 oersted
        subtract(hfield, _Heff, out=hfield)
        c_einsum(self.dm3, self.var.geff, c.eijk, mag, dt * hfield, out=hfield)  # unit: 1.e-12 emu

    def scfloop(self, mag, dt, umag, hfield, Heff, mag_old, diff, new_norm, count=0):
        """
        SCF loop to update magnetisation and check the magnetisation norm isn't changed.

        Most (more than 50%) of the simulation is spent here
        improvements to the for loop have the largest impact

        Returns the loop index at which the change between successive
        iterates dropped to ``sqEPS`` or below. If the budget runs out,
        ``var.scfiter`` is doubled and the method calls itself again.
        ``new_norm`` is a two item list ``[last_norm, repeat_count]``
        shared along the recursion to detect a stuck norm.
        """
        # NOTE: the ``count`` argument is immediately rebound by the loop variable.
        for count in self.scfrange:
            self.deltamag(mag, dt, umag, hfield, Heff)
            add(self.mag, hfield, out=mag)  # unit: 1.e-12 emu
            subtract(mag, mag_old, out=diff)
            if c_einsum(self.scf1, diff, diff, self.inv_Mmag2_atm) <= self.sqEPS:
                return count
            mag_old[:] = mag  # unit: 1.e-12 emu
        # From here assumption function is slower, hopefully rarely used
        self.var.scfiter *= 2
        self.scfrange = range(count, self.var.scfiter)
        mnorm = sqrt(c_einsum(self.scf1, diff, diff, self.inv_Mmag2_atm))
        if allclose(new_norm[0], mnorm):
            new_norm[1] += 1
        else:
            new_norm[:] = mnorm, 0
        if self.var.scfiter >= self.maxscf:
            # NOTE(review): mnorm is a square root but is compared with the
            # squared tolerance here -- confirm this is intended.
            if new_norm[1] >= 2 and mnorm < 1.3 * self.sqEPS:
                c.Error(f"W Norm seems stuck {mnorm}, continuing")
                self.var.scfiter //= 32
                # Reset in place; the list is a throwaway copy made by
                # magnetisation_step, so nothing observes it afterwards.
                new_norm[:] = 0, 0
                return count
            else:
                c.Error(f"F Exceeded self consistant iteration hard limit {new_norm}, {mnorm}")
        c.Error(
            "W Maximum number of iterations exceeded doubling to {} [MD step {}, mnorm={} ]".format(
                self.var.scfiter, self.time.iter, mnorm))
        return self.scfloop(mag, dt, umag, hfield, Heff, mag_old, diff, new_norm, count)

    def magnetisation_step(self, dt):
        """
        Update magnetisation.

        Runs the SCF loop on a copy of the magnetisation, then applies one
        final ``deltamag`` and adds the increment to ``self.mag`` in place.
        ``self.count`` accumulates the iterations used.
        """
        mag = self.mag.copy()  # unit: 1.e-12 emu
        umag = self.dm_umag
        Heff = self.dm_Heff
        hfield = self.dm_hfield
        # iteration here
        self.count += self.scfloop(mag, dt, umag, hfield, Heff, self.dm_mag_old, self.dm_diff, self.new_norm.copy())
        #  whole step
        self.deltamag(mag, dt, umag, hfield, Heff)
        add(self.mag, hfield, out=self.mag)  # unit 1.e-12 emu
        self.count += 1
class energy(Communal):
    """Energy potentials."""

    def __init__(self, communal, singleframe=True):
        """
        Initialise energy constants and arrays.

        Parameters
        ----------
        communal: Communal
            Source of the constants and pair indices, which are copied
            into a fresh set of work arrays.
        singleframe: bool
            When True the arrays describe one configuration. When False a
            leading axis is added to every index and einsum string, and
            ``communal.comb`` must be a shape tuple because it is unpacked.
        """
        Communal.__init__(self, communal.epsilon, communal.sigma, communal.triag_ind, communal.comb, communal.limit)
        self.tr_estr = ['{0}, {0}, {0}, {0}, {0}, {0} -> {0}', '{0}, {0} -> {0}']
        self.mag_estr = ['{0}, {0}, {0} -> {0}', '{0}, {1} -> {0}', '{0}, {0} -> {1}']
        if singleframe:
            estr = ['i', 'ij']
            self.unitv = zeros((self.comb, 3))
            # Keep at least one element so the energy array is never empty.
            self.epot_tr = zeros(self.comb if self.comb > 1 else 1)
        else:
            estr = ['ai', 'aij']
            # Prepend a full slice so the pair indices apply to every leading entry.
            self.triag_ind = (slice(None), *self.triag_ind)
            self.ia = (slice(None), self.ia)
            self.ja = (slice(None), self.ja)
            self.unitv = zeros((*self.comb, 3))
            self.epot_tr = zeros(self.comb)
        # Fill the subscript templates: {0} is the per-pair index, {1} the other one.
        self.tr_estr[0] = self.tr_estr[0].format(estr[0])
        self.tr_estr[1] = self.tr_estr[1].format(estr[0])
        self.mag_estr[0] = self.mag_estr[0].format(estr[0])
        self.mag_estr[1] = self.mag_estr[1].format(estr[1], estr[0])
        self.mag_estr[2] = self.mag_estr[2].format(estr[1], estr[0])

    def energy_vars(self, dm):
        """Set reused energy variables."""
        self.rij[:] = dm[self.triag_ind]  # unit: cm

    def vv_trans(self):
        """
        Calculate translation potential energy.

        The result is written to ``self.epot_tr``. Pairs whose ``rij`` does
        not exceed ``limit`` contribute zero. When no pair exceeds it the
        array is zeroed, so energies from an earlier call cannot leak into
        this one.
        """
        # unit: erg
        self.lind[:] = self.rij > self.limit
        if self.lind.any():
            self.lj[:] = self.sigma * self.rij
            c_einsum(self.tr_estr[0], self.lj, self.lj, self.lj, self.lj, self.lj, self.lj, out=self.lj6)
            c_einsum(self.tr_estr[1], self.lind, self.epsilon * (1. + 4. * self.lj6 * (self.lj6 - 1.)), out=self.epot_tr)  # unit: erg
        else:
            self.epot_tr[:] = 0.

    def vv_mag(self, mag, d):
        """
        Calculate magnetic potential energy.

        Parameters
        ----------
        mag: ndarray
            Magnetisation array, indexed with the pair indices ``ia`` and ``ja``.
        d: ndarray
            Array indexed with ``triag_ind`` to get the per-pair vectors.

        Returns
        -------
        ndarray
            One energy per pair. Relies on ``rij`` having been set by
            ``energy_vars``.
        """
        magia = mag[self.ia]
        magja = mag[self.ja]
        c_einsum(self.mag_estr[0], self.rij, self.rij, self.rij, out=self.rij3)  # = self.rij * self.rij * self.rij
        c_einsum(self.mag_estr[1], d[self.triag_ind], self.rij, out=self.unitv)  # dimensionless
        c_einsum(self.mag_estr[2], magia, self.unitv, out=self.magi)
        c_einsum(self.mag_estr[2], magja, self.unitv, out=self.magj)
        c_einsum(self.mag_estr[2], magia, magja, out=self.dmag)
        return -(3. * self.magi * self.magj - self.dmag) * self.rij3  # potential energy unit: erg
class force(Communal, bin_count):
    """
    Force potentials.

    NOTE(review): ``addat`` and ``subtractat`` are not defined in this
    class or in Communal; presumably they come from ``bin_count`` -- confirm.
    """

    # Every name is listed once. 'fstr6' and 'ia' must be separate entries:
    # adjacent string literals without a comma fuse into a single name.
    __slots__ = ['epsilon', 'sigma', 'triag_ind', 'comb', 'limit',
                 'zeros', 'epsilon24', 'rijz', 'ftrans', 'fmag', 'unitv',
                 'hfia', 'hfja', 'hfia_temp', 'hfja_temp', 'fmag_temp',
                 'fstr1', 'fstr2', 'fstr3', 'fstr4', 'fstr5', 'fstr6',
                 'ia', 'ja', 'rij',
                 'lj', 'lj6', 'rij3', 'rij4', 'magi', 'magj', 'dmag',
                 'lind', 'hind',
                 'long_ia', 'long_ja', 'leng',
                 'force_trans', 'force_transR']

    def __init__(self, communal, nm):
        """
        Initialise force constants and arrays.

        Parameters
        ----------
        communal: Communal
            Source of the constants and pair indices.
        nm: int
            Number of rows of the ``(nm, 3)`` zero template copied for
            every force and field result.
        """
        Communal.__init__(self, communal.epsilon, communal.sigma, communal.triag_ind, communal.comb, communal.limit)
        self.zeros = zeros((nm, 3))
        self.set_attr(['rijz', 'ftrans', 'fmag',
                       'unitv', 'hfia', 'hfja',
                       'hfia_temp', 'hfja_temp', 'fmag_temp'], zeros((self.comb, 3)))
        self.epsilon24 = -24 * self.epsilon  # avoid unneeded operation repetition
        self.estrings()
        self.bin_count_setup()

    def estrings(self):
        """Store the einsum subscript strings used by the force routines."""
        self.fstr1 = "i, i, i -> i"
        self.fstr2 = "ij, i -> ij"
        self.fstr3 = "i, i, i, i, i, i -> i"
        self.fstr4 = "iz,iz -> i"
        self.fstr5 = "i, ij -> ij"
        self.fstr6 = "i, i, ij, i-> ij"

    def bin_count_setup(self):
        """
        bin_count setup.

        Create variables needed for bincount: for every pair, the three
        flat indices ``3 * i``, ``3 * i + 1``, ``3 * i + 2`` into a
        ravelled ``(nm, 3)`` array, laid out pair by pair.
        """
        self.long_ia = ravel(array([self.ia * 3, self.ia * 3 + 1, self.ia * 3 + 2]), 'F')
        self.long_ja = ravel(array([self.ja * 3, self.ja * 3 + 1, self.ja * 3 + 2]), 'F')
        self.leng = self.zeros.size

    def force_vars(self, d, dm):
        """
        Set reused force variables.

        Extracts the per-pair values from ``dm`` and ``d``, precomputes the
        third and fourth powers of ``rij`` and the unit vectors, then
        refreshes the translational force.
        """
        self.rij = dm[self.triag_ind]  # unit: cm
        self.rijz = d[self.triag_ind]
        c_einsum(self.fstr1, self.rij, self.rij, self.rij, out=self.rij3)
        self.rij4[:] = self.rij3 * self.rij
        c_einsum(self.fstr2, self.rijz, self.rij, out=self.unitv)  # dimensionless
        self.ff_trans()

    def ff_trans(self):
        """
        Create variables for all forces.

        Forces are updated twice without positional changes, therefore no
        change in translational force. The result is kept in
        ``self.force_trans`` and is all zeros when no pair exceeds ``limit``.
        """
        # translational force
        self.force_trans = self.zeros.copy()
        # Flat view of the same memory, used for the scatter updates below.
        self.force_transR = self.force_trans.ravel()
        self.lind[:] = self.rij > self.limit
        if self.lind.any():
            self.lj[:] = self.sigma * self.rij
            c_einsum(self.fstr3, self.lj, self.lj, self.lj, self.lj, self.lj, self.lj, out=self.lj6)
            c_einsum(self.fstr6, self.lind, self.epsilon24 * (2. * self.lj6 - 1.) * self.lj6,
                     self.rijz, self.rij * self.rij, out=self.ftrans)
            ftransR = self.ftrans.ravel()
            # Equal and opposite contributions on the two particles of each pair.
            self.subtractat(self.force_transR, self.long_ia, ftransR)
            self.addat(self.force_transR, self.long_ja, ftransR)

    def _ff_trans_return(self):
        """Calculate Translational Force (returns the array built by ff_trans)."""
        return self.force_trans

    def ff_mag(self, mag):
        """
        Calculate Magnetic force.

        Returns a new ``(nm, 3)`` array. Relies on ``unitv`` and ``rij4``
        having been set by ``force_vars``.
        """
        force = self.zeros.copy()   # unit: dyne
        forceR = force.ravel()
        magia = mag[self.ia]
        magja = mag[self.ja]
        c_einsum(self.fstr4, magia, self.unitv, out=self.magi)
        c_einsum(self.fstr4, magja, self.unitv, out=self.magj)
        c_einsum(self.fstr4, magia, magja, out=self.dmag)
        subtract(5. * self.magi * self.magj, self.dmag, out=self.dmag)
        subtract(c_einsum(self.fstr5, self.dmag, self.unitv),
                 c_einsum(self.fstr5, self.magi, magja), out=self.fmag_temp)
        subtract(self.fmag_temp, c_einsum(self.fstr5, self.magj, magia), out=self.fmag_temp)
        c_einsum(self.fstr2, self.fmag_temp, self.rij4, out=self.fmag)  # unit: dyne
        self.fmag *= 3.
        fmagR = self.fmag.ravel()
        self.subtractat(forceR, self.long_ia, fmagR)
        self.addat(forceR, self.long_ja, fmagR)
        return force

    def hh_mag(self, mag):
        """
        Calculate Applied field.

        Returns a new ``(nm, 3)`` array. Relies on ``unitv`` and ``rij3``
        having been set by ``force_vars``.
        """
        # Possible to split this function into two concurrent streams
        # Most calculation time is spent here (up to 50%)
        hfield = self.zeros.copy()  # unit: oersted
        hfieldR = hfield.ravel()
        magia = mag[self.ia]
        magja = mag[self.ja]
        c_einsum(self.fstr4, magia, self.unitv, out=self.magi)
        c_einsum(self.fstr4, magja, self.unitv, out=self.magj)
        subtract(3. * c_einsum(self.fstr5, self.magj, self.unitv), magja, out=self.hfia_temp)
        subtract(3. * c_einsum(self.fstr5, self.magi, self.unitv), magia, out=self.hfja_temp)
        c_einsum(self.fstr2, self.hfia_temp, self.rij3, out=self.hfia)  # unit: oersted
        c_einsum(self.fstr2, self.hfja_temp, self.rij3, out=self.hfja)  # unit: oersted
        self.addat(hfieldR, self.long_ia, self.hfia.ravel())
        self.addat(hfieldR, self.long_ja, self.hfja.ravel())
        return hfield
# Running this module directly does nothing.
if __name__ == "__main__":
    pass