"""
Magnetic motion setup.
Set up environment and start calculation loop
"""
# External Dependencies
from numpy import (array, zeros, zeros_like,
                   sin, cos, ones, maximum,
                   sqrt, errstate, where, einsum)
from numpy.core.multiarray import c_einsum
from copy import deepcopy
from scipy import linalg
from math import pi
from contextlib import suppress
# Internal Dependencies
from mango.constants import c, nestedDict
from mango.position import position
from mango.io import write_data
from mango.initpositions import initpos
from mango.time import _time, end, grace
from mango.multiproc import mp_handle
from mango.debug import debug, profile
from mango.managers import addpid
def calc_wrap(calc, stat):
    """
    Wrap calculate to allow decoration in multiprocessing.

    Could be used in future for any function.

    Parameters
    ----------
    calc: instance
        object whose ``run`` method performs the calculation
        (``integrate`` passes a ``calculate`` instance)
    stat: int
        statistic number passed on to ``calc.run``

    Returns
    -------
    The value returned by ``calc.run`` (or by ``profile``), or None if the
    run was interrupted with KeyboardInterrupt.
    """
    with suppress(KeyboardInterrupt):
        # Only profile when it was requested and we are running in parallel
        if c.profile and calc.flg.parallel:
            return profile(calc.run, stat, file="./wprofiler.prof")
        return calc.run(stat)
def field(time, H0=167.0, nu=267.0e3):
    """
    Calculate the external magnetic field at a given time.

    Parameters
    ----------
    time: float
        current simulation time
    H0: float
        field amplitude
    nu: float
        field frequency; anything not greater than ``c.EPS`` gives a
        static field

    Returns
    -------
    H: float
        ``H0 * sin(2 pi nu time)`` for an oscillating field, else ``H0``
    """
    if nu > c.EPS:
        return H0 * sin(2.0 * pi * nu * time)
    return H0
def num_iters(iters, skip):
    """
    Get number of iterations including first and last iterations.

    Counts the iterations 0, skip, 2*skip, ... that are below ``iters``.

    Parameters
    ----------
    iters: int
        total number of iterations
    skip: int
        interval between stored iterations

    Returns
    -------
    int
        number of stored iterations
    """
    return 1 + ((iters - 1) // skip)
def sanity_checks(var, pos, mag, xyz):
    """
    Sanity check input.

    Compares the particles read in with the requested variables. Values that
    the user left at their defaults (listed in ``var.defaults``) are resized
    to the number of molecules read in; user-set values that disagree are
    reported through ``c.Error`` ('F' messages).

    Parameters
    ----------
    var: instance
        variables instance, updated in place
    pos: ndarray
        particle positions, first axis is the molecule index
    mag: ndarray or None
        particle magnetisations (one row per molecule), None if not read in
    xyz: bool
        whether the input came from an xyz file; xyz magnetisations are
        checked against 1, others against the saturation moment ms * vol
    """
    new_nm = pos.shape[0]
    new_size = zeros(new_nm)
    if "no_molecules" not in var.defaults and var.no_molecules != new_nm:
        c.Error("F Number of molecules is not consistant with input file")
    elif var.no_molecules != new_nm:
        var.no_molecules = new_nm
        # NOTE(review): only the single character name[-2] is replaced, which
        # assumes the previous molecule count in the name was one digit.
        name = list(var.name)
        name[-2] = str(var.no_molecules)
        var.name = "".join(name)
    if "radius_1" not in var.defaults and var.radius.shape[0] != new_nm:
        c.Error("F Number of molecule radii is not consistant with input file")
    elif "radius_1" in var.defaults:
        # Broadcast the default radius to every molecule
        var.radius = var.radius[0] + new_size
    if "dens_1" not in var.defaults and var.dens.shape[0] != new_nm:
        c.Error("F Number of molecule densities is not consistant with input file")
    elif "dens_1" in var.defaults:
        var.dens = var.dens[0] + new_size
        var.ms = var.ms[0] + new_size
    if mag is not None:
        abs_mag = linalg.norm(mag, axis=1)
        if xyz:
            limit = 1.0
        else:
            # var.vol is only (re)computed by new_vars(), which runs after this
            # check, so derive the volume from the current radii here.
            limit = var.ms * ((4.0 / 3.0) * pi * (var.radius**3))
        # A single limit per input type; the former XOR of the two tests
        # flagged valid moments and let moments above both limits through.
        if (abs_mag > limit + c.EPS).any():
            c.Error("F The magnetisation should be less than the saturation magnetisation")
def new_vars(var):
    """
    Calculate New variables from input.

    Derives volumes, mass, damping and noise parameters, the susceptibility
    and Brownian relaxation time, and draws random orientation angles for
    every molecule. All results are stored on ``var``.

    Parameters
    ----------
    var: instance
        variables instance, updated in place
    """
    # NOTE(review): presumably the hydrodynamic radius; currently the same
    # array object as var.radius.
    var.hradius = var.radius
    var.vol = (4.0 / 3.0) * pi * (var.radius**3)  # [nm^3]
    var.hvol = (4.0 / 3.0) * pi * (var.hradius**3)  # [nm^3]
    # NOTE(review): angles are drawn from normal distributions
    # (loc pi / pi/2, scale 2pi / pi), not uniformly; changing this would
    # also change the seeded random sequence used afterwards.
    var.phi = c.random.normal(pi, (2 * pi), size=var.no_molecules)
    var.theta = c.random.normal(pi / 2, pi, size=var.no_molecules)
    var.mass = var.vol * var.dens
    var.alpha = 6.0 * var.eta * c.GBARE / var.ms   # dimensionless damping parameter
    var.geff = c.GBARE / (1.0 + var.alpha**2)   # effective gyromagnetic ratio
    # size of the magnetic fluctuations
    # Note that c2m is divided here by sqrt(dt), but finally multiplied by dt.
    # The stochastic part correctly scales with sqrt(dt)
    var.c2 = sqrt(2.0 * c.KB * var.temp * var.alpha / (var.ms * c.GBARE * var.hvol * var.dt))
    # Treat negligible noise amplitudes as exactly zero
    var.c2 = where(var.c2 > c.EPS2, var.c2, 0)
    # Raise on e.g. division by a zero temperature so it can be reported
    with errstate(all='raise'):
        try:
            var.chi0 = (var.ms**2) * var.vol / (3.0 * c.KB * var.temp)
            var.tauB = 3.0 * var.eta * var.vol / (c.KB * var.temp)
        except FloatingPointError:
            c.Error(">W ZeroDivisionError\n{} and {} have been set to infinity".format(c.ustr.chi, c.ustr.taub))
            var.chi0 = float("inf") * ones(var.no_molecules)
            var.tauB = float("inf") * ones(var.no_molecules)
def get_particles(var, flg):
    """
    Read or generate the initial particle positions, momenta and magnetisations.

    ``var.location`` may hold several whitespace separated locations, one per
    statistic; a single location, a blank string or None gives one set.

    Parameters
    ----------
    var: instance
        variables instance
    flg: instance
        flags instance

    Returns
    -------
    pos, mom, mag: ndarray or dict
        arrays for one location, otherwise dicts keyed by statistic
        number (0, 1, ...)
    shape: tuple
        shape of a single position array
    """
    # TODO probably could improve splitting filenames
    # currently splits on all spaces
    # move splitting to arguments.py etc
    locations = var.location.split() if var.location is not None else []
    if len(locations) > 1:
        fpos, fmom, fmag = {}, {}, {}
        for stat, loc in enumerate(locations):
            fpos[stat], fmom[stat], fmag[stat] = _get_particles(var, flg, loc)
        return fpos, fmom, fmag, fpos[0].shape
    # One location, or none at all (a blank string previously left pos
    # unbound); these are not keyed by statistic.
    loc = locations[0] if locations else None
    pos, mom, mag = _get_particles(var, flg, loc)
    return pos, mom, mag, pos.shape
def _get_particles(var, flg, location):
    """
    Load (or generate) one set of particle positions, momenta and magnetisations.

    Runs the input sanity checks and derives the dependent variables
    (``new_vars``) before filling in any quantity that was not read in.

    Parameters
    ----------
    var: instance
        variables instance, updated in place
    flg: instance
        flags instance
    location: str or None
        input location passed straight to ``initpos``

    Returns
    -------
    pos, mom, mag: ndarray
        magnetisations are rescaled to magnitude ms * vol; missing momenta
        are zero
    """
    diameters = var.radius * 2
    particles, from_xyz = initpos(var.no_molecules, diameters, location, var.boxsize)
    pos, mom, mag = particles
    sanity_checks(var, pos, mag, from_xyz)
    new_vars(var)

    if mag is None:
        # Directions from the random angles drawn in new_vars, weighted by ms.
        # Transposed (ji not ij) due to theta and phi being row vectors
        directions = array([sin(var.theta) * cos(var.phi),
                            sin(var.theta) * sin(var.phi),
                            cos(var.theta)])
        subscripts = "i, ji -> ij" if var.no_molecules > 1 else "i, ji -> j"
        mag = zeros_like(pos)
        mag[:] = einsum(subscripts, var.ms, directions)

    # Rescale every moment to magnitude ms * vol; the c.EPS2 floor avoids
    # dividing by zero for zero-length vectors.
    lengths = sqrt(einsum("...iz,...iz->...i", mag, mag))
    scale = var.ms * var.vol / maximum(lengths, c.EPS2)
    mag = einsum("...iz,...i->...iz", mag, scale)

    if mom is None:
        mom = zeros_like(pos)
    elif flg.labview:
        # Subtract the mean momentum so the total momentum is zero
        # unit: 1.e-12 g*cm/s
        mom -= einsum("...iz->...z", mom) / mom.shape[0]
    return pos, mom, mag
# Main calculation
@debug(['core'])
def integrate(var=None, flg=None):
    """
    Set up of Calculation.

    Setting up of initial positions magnetisations and momentums
    as well as neel relaxations for each molecule.
    This is all passed to the parallelisation of the statistical separations.

    Parameters
    ----------
    var: instance
        variables instance (attribute access, e.g. ``var.name``)
    flg: instance
        flags instance (attribute access, e.g. ``flg.restart``)

    Returns
    -------
    timer: instance
        timer created from ``var.finishtime``
    name: str
        ``var.name``

    Raises
    ------
    TypeError
        if ``var`` or ``flg`` is not supplied
    """
    # None instead of shared mutable ``{}`` defaults; a plain dict has no
    # attribute access, so calling without arguments could never work.
    if var is None or flg is None:
        raise TypeError("integrate() requires both var and flg")
    c.reseed(54321)
    if flg.restart:
        # Restart data was stored on var; hand it to the position class
        pos = var.pos
        mag = var.mag
        mom = var.mom
        del var.pos
        del var.mom
        del var.mag
        movement = {"var": var, "flg": flg, "mom": mom, "mag": mag, "pos": pos}
    else:
        pos, mom, mag, shape = get_particles(var, flg)
        # Particle movement variables
        movement = {"var": var, "flg": flg, "mom": mom, "mag": mag, "pos": pos,
                    "op_noise": c.random.standard_normal(size=shape) if flg.op_thermal else None}
    print(var.name[:-1])
    # Time recording
    timer = end(var.finishtime)
    # Initial conditions
    calc = calculate(**{"Error": c.Error, "posit": position(**movement),
                        "var": var, "flg": flg, "timer": timer})
    if hasattr(integrate, "DBG"):
        var.print_contents()
        flg.print_contents()
    # Parallelisation of calculations
    mp_handle()(calc_wrap, calc, var.stats, flg.parallel)
    return timer, var.name
class calculate():
    """
    1st parallel section.

    This class does the boilerplate around the brunt of the calculation;
    ``run`` is called once for each statistic.
    All the data is then saved in one of three user specified forms
    1. hdf5
    2. pickle
    3. plain text
    """

    @debug(["calc"])
    def __init__(self, **argd):
        """
        Initialise constant and variables.

        ``integrate`` passes the keys: Error, posit (initialised position
        class), var, flg, timer. Every key becomes an instance attribute.

        Parameters
        ----------
        argd: dict
            dictionary of data
        """
        self.__dict__.update(argd)
        self.var.skip_iters = num_iters(self.var.nmax, self.var.skip)
        self.savechunk()
        self.dictcreate = dictionary_creation(self.var, self.flg)
        self.save = save_m(self.flg.neel, write_data(self.var, self.flg))
        self.grtime = grace(self.var.stats).time
        self.end = False
        self.nsplit = 0
        # Unit field direction (z) for every molecule
        self.h_axis = array([[0.0, 0.0, 1.0]]) * ones((self.var.no_molecules, 3))
        # Bind the per-step helpers once so the inner loops do not test flags
        if self.flg.noise:
            self.getnoise = self._returnnoise
        else:
            self.getnoise = self._noop
        if self.flg.suscep or self.flg.field is False:
            self.h_axis *= 0
            self.getfield = self._returnnoextfield
        else:
            self.getfield = self._returnextfield
        if self.flg.prog:
            self.progress = self._prog_verb
            self.progresstime = 1  # max(600 // self.flg.nout, 1)
        elif 3 > self.flg.nout > 1:
            self.progress = self._prog_nverb
            self.progresstime = max(600 // self.flg.nout, 1)
        else:
            self.progress = self._noop
            self.progresstime = 300
        if hasattr(self, 'DBG'):
            self._refresh_dict = self._refresh_dict_debug

    @staticmethod
    def _noop():
        """Do nothing; used when a per-step action is disabled."""
        pass

    def _prog_verb(self):
        """Report progress through ``c.progress``."""
        c.progress(f"bar {self.moldata_dict['name'][len(self.var.directory):]} {self.count / self.var.nmax} {self.stat}")

    def _prog_nverb(self):
        """Print a plain-text progress line."""
        print("{} of {} statistic {}".format(self.count, self.var.nmax, self.stat))

    def _returnnoise(self):
        """Draw fresh gaussian noise and add its first component, scaled by c2, to Hext."""
        c.random.standard_normal(out=self.noise_setup)
        self.Hext += c_einsum("ij, i -> ij", self.noise_setup[0], self.var.c2)

    def _returnnoextfield(self):
        """Reset Hext to ``h_axis`` (zeroed when this helper is selected)."""
        self.Hext[:] = self.h_axis

    def _returnextfield(self):
        """Set Hext to the time dependent external field along ``h_axis``."""
        self.Hext[:] = field(self.time.time, H0=self.var.H_0, nu=self.var.nu) * self.h_axis

    def progress_report(self):
        """
        Progress reporter.

        Also checks the remaining walltime and, once inside the grace period,
        saves the data collected so far with ``save.stop``.
        """
        self.count += self.var.skip
        self.countdown = self.timer.gettimegone()
        if (self.countdown - self.oldcount) > self.progresstime:
            self.oldcount = self.countdown
            self.progress()
            if self.timer.gettimeleft() <= self.grtime:
                # NOTE(review): the current memory loop still runs to the end
                # after stop(); disksaveloop only checks save.end afterwards.
                self.moldata_dict = self.save.stop(self.moldata_dict, self.posit.count)
                c.Error("W Hit walltime, trying to exit cleanly")

    def savechunk(self):
        """Adjust savechunk, keeping the requested value in ``bksavechunk``."""
        self.bksavechunk = self.var.savechunk
        if self.var.savechunk > self.var.skip_iters:
            # split can't be larger than no_iterations
            self.var.savechunk = self.var.skip_iters

    def iterationadjust(self):
        """
        Adjust iterations for writing and storing.

        On restart, works out how many iterations remain for this statistic
        and exits the process when the statistic is already complete.
        Then sets the number of disk writes and the sizes of the final
        memory and skip loops.
        """
        import sys

        if self.flg.restart:
            self.rstep = ((self.var.extra_iter[self.stname]) * self.var.skip)
            self.cstep = self.var.nmax - self.rstep
            # Doesn't include post optimise initial positions step
            # Does include post optimise initial positions step
            self.cwritten = self.var.skip_iters - self.var.extra_iter[self.stname] + 1
            print("Stat {} Written: {} / {}".format(self.stat, self.cwritten, self.var.skip_iters + 1))
            # sys.exit rather than the interactive-only builtin exit()
            if self.cwritten == self.var.skip_iters + 1:
                self.finishdata(0, self.cstep)
                sys.exit(0)
            elif self.cwritten > self.var.skip_iters + 1:
                print("Bigger?", self.cwritten, self.var.skip_iters, num_iters(self.rstep, self.var.skip))
                sys.exit(1)
            self.var.skip_iters = num_iters(self.rstep, self.var.skip)
            nmax = self.rstep
        else:
            nmax = self.var.nmax
        if self.flg.restart or isinstance(self.posit.pos, dict):
            # One set of particles per statistic: pick ours
            self.posit.pos = self.posit.pos[self.stat]
            self.posit.mag = self.posit.mag[self.stat]
            self.posit.mom = self.posit.mom[self.stat]
        self.savechunk()
        # savechunk as 0 == save at the end
        self.writes = nmax // (self.var.savechunk * self.var.skip) if self.var.savechunk != 0 else 0
        self.state_int = max(self.writes // 10, 1)
        self.finalmemoryloop = (nmax % (self.var.savechunk * self.var.skip)) / self.var.skip if self.var.savechunk != 0 else 0
        self.finalskiploop = nmax % self.var.skip
        if self.finalskiploop == 0 and self.finalmemoryloop % 1 != 0:
            # finalmemory loop not whole number and skip fits exactly into nmax
            self.finalskiploop = self.var.skip
        # whole number
        self.finalmemoryloop = int(self.finalmemoryloop)
        # last dictionary size: Number of memory loops + last skip loop
        self.remainder = self.finalmemoryloop + 1 if self.finalskiploop > 0 else 0
        if self.finalmemoryloop == 0 and self.finalskiploop == 0:
            # Always run final loop at least once
            self.finalmemoryloop = self.var.savechunk
            self.writes -= 1 if self.writes > 0 else 0
        if hasattr(self, 'DBG'):
            print(f"Stat {self.stat}", f"Write_loop: {self.writes}",
                  f"Savechunk: {self.var.savechunk}", f"Skip: {self.var.skip}",
                  f"Final_M_loop: {self.finalmemoryloop}",
                  f"Final_S_loop: {self.finalskiploop}",
                  f"Remainder: {self.remainder}",
                  f"Total_tosave: {self.var.skip_iters}")

    def randomnumbers(self):
        """Set up random number starting point (restored state or a per-stat jump)."""
        if self.flg.restart:
            c.rgen.state = self.var.RandNoState[self.stname]
        else:
            c.reseed(54321)
            c.jump(self.stat)

    def dictionarys(self):
        """
        Create dictionaries for data storage.

        ``moldata_dict`` is refreshed from ``dictcreate`` with ``.copy()``
        (``deepcopy`` in debug mode) and given this statistic's file name.
        """
        self.moldata_dict_end = dictionary_creation(self.var, self.flg, self.remainder)
        if self.bksavechunk != self.var.savechunk:
            self.dictcreate = dictionary_creation(self.var, self.flg)
        self.moldata_dict = self.dictcreate.copy()
        # File naming
        self.moldata_dict['name'] += "{:g}.{}".format(self.stat + 1, self.flg.save_type)
        self.name = self.moldata_dict['name']

    def _refresh_dict(self, dictionary):
        """Start a new storage block from ``dictionary`` (``.copy()``)."""
        self.moldata_dict = dictionary.copy()
        self.moldata_dict['name'] = self.name
        self.save.remove_state(dictionary)

    def _refresh_dict_debug(self, dictionary):
        """Debug variant of ``_refresh_dict`` using ``deepcopy``."""
        self.moldata_dict = deepcopy(dictionary)
        self.moldata_dict['name'] = self.name
        self.save.remove_state(dictionary)

    @addpid("Error")
    def setup(self):
        """Set up instance for each parallel calculation."""
        self.noise_setup = zeros((3, self.var.no_molecules, 3))
        self.Hext = zeros((self.var.no_molecules, 3))
        self.stname = self.save.stname = f"stat{self.stat}"
        self.iterationadjust()
        self.randomnumbers()
        self.dictionarys()
        self.flg.save_type = self.save.setup(self.name)
        self.starttime()
        self.firststep()
        self.prerange()
        # Progress report varables
        self.oldcount = self.timer.gettimegone()

    def prerange(self):
        """Preallocate ranges."""
        self.dlrange = range(self.writes)
        self.mlrange = range(self.var.savechunk)
        self.slrange = range(self.var.skip)
        self.mlrangefinal = range(self.finalmemoryloop)
        self.slrangefinal = range(self.finalskiploop)

    def run(self, stat):
        """
        Run calculation.

        Parameters
        ----------
        stat: int
            statistic number
        """
        self.stat = stat
        self.setup()
        self.disksaveloop()

    def starttime(self):
        """Start timestep recording."""
        self.time = _time(self.var.dt, self.var.time)
        self.posit.time = self.time

    def firststep(self):
        """Set up initial conditions and calculate first step."""
        self.posit.initialconditions()
        # Same truthiness test as the rest of the class: cstep and cwritten
        # only exist when iterationadjust ran its restart branch.
        if not self.flg.restart:
            self.count = 0
            first_step = dictionary_creation(self.var, self.flg, 1)
            first_step["name"] = self.moldata_dict['name']
            self.save.file(self.save.memory(first_step, self.posit.prop, self.time))
        else:
            self.count = self.cstep
            self.save.restart(self.moldata_dict['name'], self.var.extra_iter, self.var.SCFcount, self.cwritten)

    def disksaveloop(self):
        """Save to disk loop."""
        for dl in self.dlrange:
            self.memorysaveloop(self.mlrange)
            if self.save.end:
                return
            # Periodically store the restart state with the data block
            if dl % self.state_int == 0 and dl != 0:
                self.save.state(self.moldata_dict, self.posit.count)
            self.save.file(self.moldata_dict, self.posit.count)
            self._refresh_dict(self.dictcreate)
        self.finaliteration()

    def memorysaveloop(self, loop):
        """
        Save to memory loop.

        Parameters
        ----------
        loop: range
            iterations to run, one stored step each
        """
        for _ in loop:
            self.skiploop(self.slrange)
            # Collect data
            self.moldata_dict = self.save.memory(self.moldata_dict, self.posit.prop, self.time)
            self.progress_report()

    def skiploop(self, loop):
        """
        Skip loop, no saving required.

        Parameters
        ----------
        loop: range
            number of propagation steps to run
        """
        for _ in loop:
            # External field
            self.getfield()
            # System noise
            self.getnoise()
            # Iterate
            self.posit.propagate(self.Hext, self.noise_setup)
            # time update
            self.time.time_update()

    def finaliteration(self):
        """
        Last iteration loop.

        Setup for last step saving useful things for restart.
        """
        if self.moldata_dict_end is not None:
            self._refresh_dict(self.moldata_dict_end)
        else:
            self._refresh_dict(self.dictcreate)
        self.memorysaveloop(self.mlrangefinal)
        if self.finalskiploop > 0:
            self.count += self.finalskiploop
            self.skiploop(self.slrangefinal)
            self.moldata_dict = self.save.memory(self.moldata_dict, self.posit.prop, self.time)
        # end=True makes save.file also store the restart state
        self.save.end = True
        self.save.file(self.moldata_dict, self.posit.count)
        self.progress()
        self.finishdata(self.posit.count, self.count)

    def finishdata(self, SCFcount, count):
        """
        Print a summary for this statistic.

        Parameters
        ----------
        SCFcount: int
            SCF cycles of this run; on restart the stored count is added
        count: int
            completed iterations
        """
        SCFcount += self.var.SCFcount if self.flg.restart else 0
        print(f"\nStat {self.stat}\n",
              f"SCF cycles: {SCFcount}\n",
              f"Total Iterations: {self.var.nmax}\n",
              f"Completed Iterations: {count}\n",
              f"SCF per Iteration: {(SCFcount / count) if count > 0 else 0}")
class save_m():
    """
    Storing function.

    Stores data in memory until asked to save to file where it calls the writer.

    Parameters
    ----------
    neel: bool
        Is neel relaxation required?
    wd: instance
        instance of writing class
    """

    @debug(["save"])
    def __init__(self, neel, wd):
        """Initialise save routine."""
        self.wd = wd
        self.end = False  # when True, file() also stores the restart state
        self.neel = neel
        # Choose the neel storage routine once instead of testing every step
        if self.neel:
            self.neel_save = self._neel_save
        else:
            self.neel_save = self._no_neel_save
        self.written = 0  # stored steps so far (overwritten on restart)
        self.SCFcount = 0
        self._reset_ind()

    def _neel_save(self, data_dict, data):
        """Store the neel relaxation count in the current row."""
        data_dict['neel_relaxation'][self.splitcount, :] = data["neel_count"].count
        return data_dict

    @staticmethod
    def _no_neel_save(data_dict, data):
        """Neel relaxation disabled: return the dictionary unchanged."""
        return data_dict

    def _reset_ind(self):
        """Start filling the storage arrays from row 0 again."""
        self.splitcount = 0

    def _incr_splitcount(self):
        """Move to the next storage row and count the stored step."""
        self.splitcount += 1
        self.written += 1

    def setup(self, name):
        """Set up writing class; returns the result of ``wd.setup(name)``."""
        return self.wd.setup(name)

    def restart(self, name, extra_iter, SCFcount, written):
        """Save extra_iters, SCFcount and current written."""
        self.written = written
        self.SCFcount = SCFcount
        self.wd.write({'name': name, 'vars': {'extra_iter': extra_iter}})

    def file(self, data_dict, SCFcount=0):
        """
        Save current data block to file.

        Parameters
        ----------
        data_dict: dict
            Storage dictionary
        SCFcount: int
            SCF cycles of this run, stored with the state when ``end`` is set
        """
        if self.end:
            self.state(data_dict, SCFcount)
        self.wd.write(data_dict)
        self._reset_ind()

    def state(self, data_dict, SCFcount):
        """Save current state (random generator state, written steps, SCF count)."""
        data_dict['vars']['RandNoState'][self.stname] = c.rgen.state
        data_dict['vars']['written'] = self.written
        data_dict['vars']['SCFcount'] = self.SCFcount + SCFcount

    def remove_state(self, data_dict):
        """Remove current state."""
        for val in ['RandNoState', 'written', 'SCFcount']:
            data_dict['vars'][val] = {}

    def stop(self, data_dict, SCFcount=0):
        """
        Emergency Stop saving.

        Writes the rows filled so far (with the restart state) to file and
        returns the unfilled tail of ``data_dict`` for further storage.

        Parameters
        ----------
        data_dict: dict
            Storage dictionary
        SCFcount: int
            SCF cycles of this run

        Returns
        -------
        data_dict: dict
            ``data_dict`` cut down to the rows not yet written
        """
        self.end = True
        sc = self.splitcount  # save to avoid reset
        self.file(self.cutter(data_dict, deepcopy(data_dict), None, sc), SCFcount)
        return self.cutter(data_dict, data_dict, sc, None)

    def cutter(self, orig, copy, start, end):
        """
        Return a section of input dictionary.

        Every per-step array of ``orig`` is sliced to rows ``start:end`` and
        stored in ``copy`` (which may be ``orig`` itself).
        """
        copy['position'] = orig['position'][start:end, :, :]
        copy['iter_time'] = orig['iter_time'][start:end, :]
        copy['magnetisation'] = orig['magnetisation'][start:end, :, :]
        copy['forces'] = orig['forces'][start:end, :, :]
        copy['momentum'] = orig['momentum'][start:end, :, :]
        if self.neel:
            # Slice like the other arrays; assigning into a slice of the
            # full-size array left neel_relaxation uncut.
            copy['neel_relaxation'] = orig['neel_relaxation'][start:end, :]
        return copy

    def memory(self, data_dict, data, time):
        """
        Save to memory current iteration data.

        Parameters
        ----------
        data_dict: dict
            Storage dictionary
        data: dict
            data to be stored
        time: instance
            timestep instance

        Returns
        -------
        data_dict: dict
            Storage dictionary
        """
        data_dict['position'][self.splitcount, :, :] = data["pos"]
        data_dict['iter_time'][self.splitcount, :] = time.iter, time.time
        data_dict['magnetisation'][self.splitcount, :, :] = data["mag"]
        data_dict['forces'][self.splitcount, :, :] = data["forces"]
        data_dict['momentum'][self.splitcount, :, :] = data["mom"]
        data_dict = self.neel_save(data_dict, data)
        self._incr_splitcount()
        return data_dict
@debug(['core'])
def dictionary_creation(var, flg, extra=False):
    """
    Creation of initial dictionary.

    - all data is the same at this point
    - Uses nested dictionaries for ease of storage

    Parameters
    ----------
    var: instance
        variables instance; its attributes are copied into ``['vars']``
    flg: instance
        flags instance; its attributes are copied into ``['flags']``
    extra: bool or int
        False: arrays get ``var.savechunk`` rows;
        a positive number: arrays get that many rows;
        anything else (e.g. 0): no dictionary is created

    Returns
    -------
    moldata_dict: nestedDict or None
    """
    if extra is False:
        size = var.savechunk
    elif extra > 0:
        size = extra
    else:
        return None
    moldata_dict = _getdicts(size, var.no_molecules, flg.neel)
    moldata_dict['mango_version'] = c.version
    moldata_dict['name'] = var.name
    moldata_dict['flags'] = flg.__dict__.copy()
    moldata_dict['vars'] = var.__dict__.copy()
    # Drop entries that are not stored with the data
    del moldata_dict['vars']['finishtime']
    with suppress(KeyError):
        del moldata_dict['vars']['defaults']
    del moldata_dict['vars']['time']
    del moldata_dict['vars']['name']
    # Restart bookkeeping is filled in when saving
    for val in ['RandNoState', 'extra_iter', 'SCFcount', 'written']:
        moldata_dict['vars'][val] = {}
    return moldata_dict
def _getdicts(x, y, neel):
    """
    Allocate zeroed storage arrays.

    Parameters
    ----------
    x: int
        number of stored steps (rows)
    y: int
        number of molecules
    neel: bool
        also allocate the ``neel_relaxation`` array

    Returns
    -------
    nestedDict
        'iter_time' (x, 2); 'position', 'magnetisation', 'forces',
        'momentum' (x, y, 3); optionally 'neel_relaxation' (x, y)
    """
    storage = nestedDict()
    storage['iter_time'] = zeros((x, 2))
    storage['position'] = zeros((x, y, 3))
    for key in ('magnetisation', 'forces', 'momentum'):
        storage[key] = zeros_like(storage['position'])
    # Centre of mass, momenta and energy diagnostics are not allocated.
    if neel:
        storage['neel_relaxation'] = zeros((x, y))
    return storage
if __name__ == "__main__":
    # Nothing runs when this module is executed directly; it only provides definitions.
    pass