from numpy import einsum, sqrt, eye, copy, amax, sign, prod, linalg, zeros, arange, zeros_like
from mango.constants import c
class CoM_motion_rm():
    """Remove centre-of-mass motion from trajectory data and align its frames.

    Positions and momenta are centred in place by ``remove_CoMm``.  When there is
    more than one molecule, every frame is additionally rotated (in place, together
    with momenta and magnetisations) onto the first frame using Kabsch's algorithm,
    expressed in the principal-axis frame of the first frame's inertia tensor.

    NOTE(review): the data arrays are assumed to be shaped
    (runs, frames, molecules, 3).  This is inferred from the einsum subscripts and
    from ``pos[:, 0]`` being used as the initial configuration; confirm against callers.
    """

    def __init__(self, mass, no_mol):
        """Store the molecule masses and count.

        Parameters
        ----------
        mass : weights contracted against the molecule axis of the data.
        no_mol : number of molecules; frame alignment only runs when it is > 1.
        """
        self.mass = mass
        self.no_mol = no_mol
        self.stats = []  # not used by the methods visible in this class

    def suscep_align(self, sus_data):
        """Centre/align the 'pos', 'mom' and 'mag' arrays of ``sus_data`` in place.

        Stores the results under 'rotation', 'inertia' and 'angular' and returns
        the same dict.
        """
        pos = sus_data['pos']
        mom = sus_data['mom']
        mag = sus_data['mag']
        rot, inertia, angular = self._rm(pos, mom, mag)
        sus_data['rotation'] = rot
        sus_data['inertia'] = inertia
        sus_data['angular'] = angular
        return sus_data

    def align(self, xyz_data, sus_data):
        """Align whichever data source is available.

        If only ``sus_data`` is given it is aligned in place.  If ``xyz_data`` is
        given it is aligned in place (columns 0:3 pos, 3:6 mom, 6:9 mag), and the
        aligned columns plus rotation/inertia/angular are copied into ``sus_data``
        when that is also given.

        Returns
        -------
        (xyz_data, sus_data) : the (possibly modified) inputs.
        """
        if xyz_data is None and sus_data is not None:
            self.suscep_align(sus_data)
        elif xyz_data is not None:
            rot, inertia, angular = self.xyz_align(xyz_data)
            if sus_data is not None:
                sus_data['pos'][:] = xyz_data[..., 0:3]
                sus_data['mom'][:] = xyz_data[..., 3:6]
                sus_data['mag'][:] = xyz_data[..., 6:9]
                sus_data['rotation'] = rot
                sus_data['inertia'] = inertia
                sus_data['angular'] = angular
        return xyz_data, sus_data

    def xyz_align(self, xyz_data):
        """Centre/align ``xyz_data`` in place via views of its pos/mom/mag columns.

        Returns ``(rot, inertia, angular)`` as produced by ``_rm``.
        """
        pos = xyz_data[..., 0:3]
        mom = xyz_data[..., 3:6]
        mag = xyz_data[..., 6:9]
        rot, inertia, angular = self._rm(pos, mom, mag)
        return rot, inertia, angular

    def inertia_angular(self, xyz_data, sus_data):
        """Centre ``sus_data`` pos/mom in place and return ``(angular, inertia)``.

        ``xyz_data`` is accepted for interface compatibility but is not used.
        """
        pos = sus_data['pos']
        mom = sus_data['mom']
        pos[:], mom[:], pos_com, mom_com = remove_CoMm(pos, mom, self.mass, self.no_mol)
        # pos/mom are already centred in place, so measure about the origin;
        # passing pos_com/mom_com again would subtract the centre twice.
        origin = zeros_like(pos_com)
        return (get_angular(pos, origin, mom, zeros_like(mom_com)),
                get_inertia(pos, origin, self.mass))

    def _rm(self, pos, mom, mag):
        """Centre, optionally align, and characterise one data set (in place).

        Parameters
        ----------
        pos, mom, mag : arrays modified in place.  They are centred and, when
            ``no_mol > 1``, rotated so each frame best matches the first frame
            expressed in that frame's principal-axis system.

        Returns
        -------
        rot : array shaped like ``inertia``.  It holds the per-frame transpose of
            the applied rotation, and is left all zeros when no alignment is done.
        inertia : inertia tensor about the centre of mass (``get_inertia``).
        angular : ``get_angular`` about the centre of mass.
        """
        pos[:], mom[:], pos_com, mom_com = remove_CoMm(pos, mom, self.mass, self.no_mol)
        # remove_CoMm shifted pos in place, so the reference point is now the origin.
        inertia = get_inertia(pos, zeros_like(pos_com), self.mass)
        rot = zeros_like(inertia)  # stays all-zero when no alignment is performed
        if self.no_mol > 1:
            # Principal-axis frame of the initial configuration (proper rotation)
            fmat = self._principal_axes(inertia[:, 0])
            print("# Frame alignment (Kabsch)")
            target = copy(pos[:, 0])  # Initial configuration
            rmsd_max = self._max_rmsd(pos, target, self.no_mol)
            print("# Maximum RMSD before alignment: %12.5g" % (rmsd_max))
            # Align each frame onto the initial one, then rotate into the principal frame
            rmat = einsum('aij,abjk->abik', fmat, self._kabsch_rotation(pos, target))
            # Rotation
            pos[:] = einsum('abij,abxj->abxi', rmat, pos)
            mom[:] = einsum('abij,abxj->abxi', rmat, mom)
            mag[:] = einsum('abij,abxj->abxi', rmat, mag)
            rot[:] = einsum('abji', rmat)
            target[:] = einsum('axk,ajk->ajx', fmat, target)
            rmsd_max = self._max_rmsd(pos, target, self.no_mol)
            print("# Maximum RMSD after  alignment: %12.5g" % (rmsd_max))
            # Re-centre after rotating (the centre was already ~0; this absorbs round-off)
            pos[:], mom[:], pos_com, mom_com = remove_CoMm(pos, mom, self.mass, self.no_mol)
            inertia = get_inertia(pos, zeros_like(pos_com), self.mass)
        angular = get_angular(pos, zeros_like(pos_com), mom, zeros_like(mom_com))
        return rot, inertia, angular

    @staticmethod
    def _principal_axes(inertia0):
        """Return proper rotations whose rows are the eigenvectors of ``inertia0``.

        ``inertia0`` is a stack of symmetric 3x3 tensors.  ``eigh`` chooses the
        eigenvector signs arbitrarily, so the last axis is flipped wherever needed
        to make det = +1.  Without this, the "rotation" could mirror the system.
        """
        _, evecs = linalg.eigh(inertia0)
        frame = evecs.swapaxes(-1, -2).copy()
        frame[..., 2, :] *= sign(linalg.det(frame))[..., None]
        return frame

    @staticmethod
    def _kabsch_rotation(pos, target):
        """Return per-frame proper rotations R minimising sum |R @ pos - target|^2.

        ``pos`` has one more leading (frame) axis than ``target``.  The result has
        shape ``pos.shape[:2] + (3, 3)``.
        """
        cov = einsum('abxi,axj->abij', pos, target)  # H = P^T Q for each frame
        left, _, right_t = linalg.svd(cov)
        # Reflection guard: singular values are never negative, so handedness must
        # come from det(V U^T) = det(U) * det(V^T).
        handed = sign(linalg.det(left) * linalg.det(right_t))
        dmat = zeros(cov.shape)
        loc = arange(3)
        dmat[..., loc, loc] = 1
        dmat[..., 2, 2] = handed
        # R = V D U^T, with V = right_t^T and U = left
        return einsum('abki,abkl,abml->abim', right_t, dmat, left)

    @staticmethod
    def _max_rmsd(pos, target, no_mol):
        """Return the largest per-frame root-mean-square deviation from ``target``."""
        diff = pos - target[:, None]
        msd = einsum('aijk,aijk->ai', diff, diff) / no_mol
        return sqrt(amax(msd))
def get_inertia(pos, pos_com, mass):
    """Return the tensor of inertia of each frame about ``pos_com``.

    Computes I = sum_m mass_m * (|d_m|^2 * E - d_m d_m^T), with d_m = pos_m - pos_com.

    Parameters
    ----------
    pos : 4-D array.  Axis 2 is weighted by ``mass``; the last axis holds the
        Cartesian components.
    pos_com : reference point for each (axis 0, axis 1) entry, broadcast over axis 2.
    mass : weights for axis 2 of ``pos``.

    Returns
    -------
    Array of shape ``pos.shape[:2] + (3, 3)``.
    """
    displacement = pos - pos_com[:, :, None, :]
    root_mass = sqrt(mass)
    # Scale by sqrt(mass) so the outer product below carries exactly one factor of mass
    weighted = einsum('aijk,j->aijk', displacement, root_mass)
    second_moment = einsum('aixj,aixk->aijk', weighted, weighted)
    isotropic = einsum('aixx,jk->aijk', second_moment, eye(3))
    return isotropic - second_moment
def get_angular(pos, pos_com, mom, mom_com):
    """Return the per-molecule cross product (pos - pos_com) x (mom - mom_com).

    ``pos_com`` and ``mom_com`` are broadcast over axis 2 of ``pos``/``mom``.
    NOTE(review): this assumes ``c.eijk`` (from mango.constants) is the rank-3
    Levi-Civita tensor.  The contraction eps_xyz r_x p_y -> z is a cross product
    only in that case; confirm.
    """
    levi_civita = c.eijk
    rel_pos = pos - pos_com[:, :, None, :]
    rel_mom = mom - mom_com[:, :, None, :]
    return einsum('xyz,aijx, aijy ->aijz', levi_civita, rel_pos, rel_mom, optimize=True)
def remove_CoMm(pos, mom, mass, no_mol):
    """Remove centre-of-mass motion from ``pos`` and ``mom`` in place.

    Positions are shifted by their ``mass``-weighted mean over axis 2.  Momenta are
    shifted by their sum over axis 2 divided by ``no_mol``.

    NOTE(review): the momentum shift is identical for every molecule.  If ``mom``
    holds momenta, the centre-of-mass correction would be ``mass_m * P / M``, which
    coincides with this only when all masses are equal; confirm this is intended.

    Returns
    -------
    (pos, mom, pos_com, mom_com) : the same input array objects (now modified),
    plus the subtracted position centre and mean momentum, each with axis 2 removed.
    """
    weighted_total = einsum("aijk,j->aik", pos, mass)
    pos_com = weighted_total / einsum("j->", mass)
    pos -= pos_com[:, :, None, :]

    summed_mom = einsum("aijk->aik", mom)
    mom_com = summed_mom / no_mol
    mom -= mom_com[:, :, None, :]

    return pos, mom, pos_com, mom_com