# conda create -c conda-forge -n vispy_env python=3 vispy pyqt5 numpy
# conda install pyopengl
# conda install imageio imageio-ffmpeg
# conda install pillow
"""
Mango particle visualiser.
Controls:
* 0    - reset cameras
* 1    - toggle camera between regular 3D (turntable), first person (fly),
         arcball and Pan Zoom
* 2    - toggle between real particles and vector spheres
* +/-  - speed up/slow down simulation
* !    - start/stop simulation
* .    - reset simulation speed
* I    - save image of current frame
With fly camera:
* WASD or arrow keys - move around
* SPACE - brake
* FC - move up-down
* IJKL or mouse - look around
TODO:
* Change arrow size based on strength
* speed up read in
"""
from sys import argv
import numpy as np
import imageio
from vispy import app, scene, util
from vispy.visuals.transforms import STTransform, MatrixTransform
from vispy.visuals.transforms._util import as_vec4
from vispy.visuals.visual import CompoundVisual
from vispy.visuals.mesh import MeshVisual
from vispy.scene import visuals
from vispy.geometry import create_arrow
from mango.managers import serverwrapper
from mango import c
from mango.debug import DBG
def read(filename, speed):
    """
    Read in an extended xyz file and name the columns of the recorded array.

    The file is expected to be a sequence of frames, each made of a particle
    count line, a comment line and ``no_mol`` data lines.  Only the first
    count and comment lines are parsed; later ones are skipped assuming every
    frame has the same ``no_mol``.

    Parameters
    ----------
    filename : str
        Path of the xyz file.
    speed : int
        Thinning stride, counted in data lines: for every ``speed`` data lines
        the frame that starts inside the first ``no_mol`` of them is kept.
        The decision is made once per frame so a frame is always kept or
        dropped as a whole, also when ``speed`` is not a multiple of
        ``no_mol``.  A stride smaller than ``no_mol`` keeps every frame.

    Returns
    -------
    dict
        ``xyz_dump`` (record array with fields ``name``, ``x``, ``y``, ``z``,
        ``col0``...), ``size`` (number of kept data lines), ``no_mol`` and the
        ``mag``, ``force``, ``momentum`` and ``box`` values taken from the
        first comment line (``False`` when absent).

    Raises
    ------
    ValueError
        If the file holds no particle data after thinning.

    Future:
    * Error checking
    * crash gracefully when memory gets too full
    """
    progress_every = 5_000_000
    data_list = []
    comments = {}
    no_mol = None
    num = 0  # running index over data lines (header lines excluded)
    with open(filename) as f:
        for i, line in enumerate(f):
            if i == 0:
                no_mol = int(line)
                continue
            if i == 1:
                # key=value tokens; values are floats (an optional "[unit]"
                # suffix is dropped) or the literal True/False
                for comment in line.split():
                    com = comment.split("=")
                    try:
                        comments[com[0]] = float(com[1].split('[')[0])
                    except ValueError:
                        comments[com[0]] = com[1] == 'True'
                    except IndexError:
                        # token without "=": not a setting
                        pass
                continue
            if i % progress_every == 0:
                print(f'Read {i // (no_mol + 2)} frames')
            if i % (no_mol + 2) in (0, 1):
                # ignoring future count/comment lines other than the first
                continue
            # Decide from the first line of the frame this line belongs to, so
            # that all lines of one frame share the same fate.
            frame_start = num - num % no_mol
            if frame_start % speed < no_mol:
                data_list.append(line.split())
            num += 1
    if not data_list:
        raise ValueError(f'No particle data found in {filename}')
    # Naming columns
    ncols = len(data_list[0])
    dtypes = [("name", "U5"), ('x', float), ('y', float), ('z', float)]
    dtypes += [('col{}'.format(k), float) for k in range(ncols - 4)]
    dtypes = dtypes[:ncols]
    xyz_dump = np.rec.array([tuple(x) for x in data_list], dtype=dtypes)
    size = xyz_dump.x.shape[0]
    print('Number of frames: ', size // no_mol)
    return {"xyz_dump": xyz_dump, "size": size, "no_mol": no_mol,
            "mag": comments.get('mag', False),
            "force": comments.get('force', False),
            "momentum": comments.get('momentum', False),
            'box': comments.get('boxsize', False)}
class Arrow3DVisual(CompoundVisual):
    """
    Compound visual made of a filled arrow mesh and an optional wireframe.

    The geometry comes from ``create_arrow(rows, cols, radius, length)``.
    ``vertex_colors``, ``face_colors`` and ``color`` are handed to the filled
    mesh; ``edge_color``, when truthy, draws the mesh edges as lines in that
    colour.  Remaining keyword arguments go to ``CompoundVisual``.
    """

    def __init__(self, radius=1.0, length=1.0, cols=30, rows=30, vertex_colors=None, face_colors=None,
                 color=(0.5, 0.5, 1, 1), edge_color=None, **kwargs):
        arrow = create_arrow(rows, cols, radius=radius, length=length)

        # Filled body of the arrow
        body_args = dict(vertices=arrow.get_vertices(), faces=arrow.get_faces(),
                         vertex_colors=vertex_colors, face_colors=face_colors,
                         color=color)
        self._mesh = MeshVisual(**body_args)

        # Wireframe is an empty mesh unless an edge colour was requested
        if not edge_color:
            self._border = MeshVisual()
        else:
            outline_args = dict(vertices=arrow.get_vertices(), faces=arrow.get_edges(),
                                color=edge_color, mode='lines')
            self._border = MeshVisual(**outline_args)

        CompoundVisual.__init__(self, [self._mesh, self._border], **kwargs)
        # Offset the filled polygons so the outline drawn on top stays visible
        self.mesh.set_gl_state(polygon_offset_fill=True,
                               polygon_offset=(1, 1), depth_test=True)

    @property
    def mesh(self):
        """The filled ``MeshVisual`` of the arrow."""
        return self._mesh

    @property
    def border(self):
        """The ``MeshVisual`` holding the arrow's edges (empty without ``edge_color``)."""
        return self._border
class scene_setup():
    """
    Build the particle scene for one trajectory file and run the viewer.

    Constructing an instance creates the canvas, the cameras and the visuals,
    loads the trajectory named by ``argv[1]`` and then blocks in
    ``application.run()`` until the application exits.
    """

    def __init__(self, application, radii, skip=1, freeze=False, true_size=True, fps=False, video=False, sm_box=True, loop=True):
        """
        Parameters
        ----------
        application
            Object whose ``run()`` starts the event loop; also given to the
            animation timer.
        radii
            Particle radius (or radii), broadcast to one value per particle.
        skip : int
            Thinning stride passed to ``read``.
        freeze : bool
            Shift every position by the position of the middle particle.
        true_size : bool
            When False the translucent full-size spheres are not created.
        fps : bool
            Ask the canvas to measure frames per second.
        video : bool
            Append every drawn frame to ``animation.mp4``.
        sm_box : bool
            Size the box from the particle radii instead of the file's
            ``boxsize``.
        loop : bool
            When False playback stops on the last frame.
        """
        self.app = application
        # Prepare canvas
        canvas = scene.SceneCanvas(keys='interactive', size=(800, 608), show=True, bgcolor='white')
        if fps:
            canvas.measure_fps()
        canvas.events.mouse_move.connect(self.on_mouse_move)
        canvas.events.key_press.connect(self.on_key_press)
        self.render = canvas.render
        self.csize = canvas.size
        self.view = canvas.central_widget.add_view()
        if video:
            self.writer = imageio.get_writer('animation.mp4', quality=10, fps=20)
            self.video = self.makevideo
            self.canvas = canvas
        else:
            self.video = self._noop
        fov = 60.
        self.loops = 1
        self.f = 0  # index of the frame currently shown
        self.inc = 1  # frames advanced per timer tick
        self.stop_at_end = not loop
        self.cont = False  # True once playback has stopped on the last frame
        try:
            self.camera(fov)
            # Floating point problems raise so that angle() can catch them
            with np.errstate(all='raise'):
                if not true_size:
                    self.cr_bsph = self._noop
                self.scene(*self.get_data(radii, freeze, sm_box, skip))
            self.timer = app.Timer('auto', connect=self.on_timer, start=False, app=self.app)
            try:
                self.app.run()
            except KeyboardInterrupt:
                pass
            except Exception as err:
                # Keep the best-effort shutdown but say why the viewer stopped
                print(f'Visualiser stopped: {err!r}')
        finally:
            # Close the writer even when building the scene failed
            if video:
                self.writer.close()

    @staticmethod
    def _noop(*args):
        """Placeholder used wherever an optional step is switched off."""
        pass

    def makevideo(self):
        """Render the canvas and append the image to the video writer."""
        self.writer.append_data(self.canvas.render())

    def on_timer(self, event):
        """
        Advance the animation by ``self.inc`` frames and redraw.

        Past the end of the trajectory playback either wraps to frame 0 or,
        when looping is off, stops the timer on the last frame.  Starting the
        timer again after such a stop restarts from frame 0.
        """
        self.f += self.inc
        if self.f >= self.no_iters:
            if self.stop_at_end and not self.cont:
                self.timer.stop()
                self.cont = True
                print("LastFrame        ", end='\r')
                last = self.no_iters - 1
                last_already_shown = self.f - self.inc >= last
                # Clamp so self.f always names a real frame; with inc > 1 the
                # step may jump over the last frame, which is then drawn below.
                self.f = last
                if last_already_shown:
                    return
            else:
                if self.stop_at_end:
                    self.cont = False
                self.f = 0
                print(f"Start x{self.loops}     ", end='\r')
                self.loops += 1
        for i in self.ispheres:
            position = self.trans[i, self.f]
            self.ispheres[i].transform.move_to(position)
            # Full-size spheres do not exist when true_size is off
            if i in self.spheres:
                self.spheres[i].transform.move_to(position)
            self.move_arr(i)
        self.video()

    def get_data(self, radii, freeze, sm_box, skip=1):
        """
        Load the trajectory in ``argv[1]`` and prepare positions and arrows.

        Returns the positions, the particle count, the per-particle radii and
        the box edge length, in the order ``scene`` takes them.
        """
        a = read(argv[1], skip)
        n = a['no_mol']
        pradius = np.zeros(n) + radii
        # Zero-padding width for saved image names: digits of the last frame index
        n_frames = a['size'] // n
        self.framefilelength = len(str(max(n_frames - 1, 0)))
        xyz = self.positions(a, freeze)
        self.arrows(a)
        self.no_iters = xyz.shape[1]
        box = np.sum(pradius * 2) * 1.5 if sm_box or not a['box'] else a['box']
        print('BOXSIZE', box)
        return xyz, n, pradius, box

    def positions(self, a, freeze):
        """
        Arrange the coordinates as ``xyz[particle, frame]`` and store the
        matching 4-vectors in ``self.trans`` for use as translations.
        """
        n = a['no_mol']
        dump = a['xyz_dump']
        xyz = np.array([dump.x, dump.y, dump.z]).T
        xyz = xyz.reshape(n, -1, 3, order='F')
        if freeze:
            xyz -= xyz[int(n / 2)][None] + 0.01
        self.trans = np.array([[as_vec4(pos) for pos in track] for track in xyz])
        return xyz

    def arrows(self, a):
        """
        Select the arrow create/move callbacks and precompute arrow rotations.

        Force and magnetisation columns are read from ``col0``-``col5`` or,
        when the file carries momentum, from ``col3``-``col8``.  With both
        present force comes first; with only one present it uses the first
        three of those columns.
        """
        n = a['no_mol']
        dump = a['xyz_dump']
        first = 3 if a['momentum'] else 0

        def vectors(start):
            # Three consecutive columns arranged as [particle, frame, component]
            cols = [getattr(dump, f'col{start + k}') for k in range(3)]
            return np.array(cols).T.reshape(n, -1, 3, order='F')

        if a['force'] and a['mag']:
            force_start, mag_start = first, first + 3
            self.cr_arrs = self.cr_arr2
            self.move_arr = self._mv_arr2
        elif a['force']:
            force_start = first
            self.cr_arrs = self._cr_arrf
            self.move_arr = self._mv_arrf
        elif a['mag']:
            mag_start = first
            self.cr_arrs = self._cr_arrm
            self.move_arr = self._mv_arrm
        else:
            self.cr_arrs = self.move_arr = self._noop
        if a['force']:
            self.frotater, self.initf, self.initfa = self.cone_rotation(vectors(force_start))
        if a['mag']:
            self.mrotater, self.initm, self.initma = self.cone_rotation(vectors(mag_start))
            if DBG:
                print(self.initm, self.initma)

    def cone_rotation(self, arr1):
        """
        Compute per-frame rotation matrices for arrows following ``arr1``.

        Returns the matrices together with the angle and axis that take the
        initial arrow direction (the y axis) to the first-frame vector.
        """
        if arr1.shape[1] > 1:
            # Pair every frame with the following one (wrapping at the end)
            arr2 = self.spinner(arr1)
            angle, rotax = self.angle(arr1, arr2, 'xyz, nix, niy ->niz')
            angle = self.spinner2(angle)
            rotax = self.spinner2(rotax)
        else:
            angle = rotax = None
        initial_dir = np.zeros((arr1.shape[0], 3))
        initial_dir[:, 1] = 1
        if DBG:
            print(arr1.shape)
        init_angle, init_rot = self.angle(initial_dir, arr1[:, 0], 'xyz, ix, iy ->iz')
        return self.rotate_array(angle, rotax, init_angle, init_rot), init_angle, init_rot

    def rotate_array(self, angle, rotax, ia, ir):
        """
        Build the rotation matrix of every particle for every frame.

        Starts from the initial rotation (``ia`` about ``ir``) and applies the
        frame-to-frame rotations cumulatively.
        """
        matrix = self.rotdp(np.tile(np.eye(4), (ia.shape[0], 1, 1)), ia, ir)
        if angle is not None:
            rotater = np.zeros((angle.shape[0], angle.shape[1], matrix.shape[1], matrix.shape[2]))
            for n2 in range(angle.shape[1]):
                # rotdp updates `matrix` in place, so rotations accumulate
                rotater[:, n2] = self.rotdp(matrix, angle[:, n2], rotax[:, n2]).copy()
        else:
            rotater = matrix[:, None, ...]
        return rotater

    @staticmethod
    def rotdp(matrix, angle, axis):
        """
        Right-multiply each ``matrix[n]`` by a rotation of ``angle[n]`` about
        ``axis[n]``.  ``matrix`` is modified in place and returned.
        """
        for n in range(angle.shape[0]):
            if angle[n] != 0:
                matrix[n] = np.dot(matrix[n], util.transforms.rotate(angle[n], axis[n]))
        return matrix

    def angle(self, arr1, arr2, string=None):
        """
        Return an angle in degrees and a rotation axis relating ``arr1`` and
        ``arr2``; ``string`` is the einsum signature used for the cross
        product with ``c.eijk``.
        """
        arr1_ = np.linalg.norm(arr1, axis=-1)[..., None]
        arr2_ = np.linalg.norm(arr2, axis=-1)[..., None]
        # Normalised dot product, i.e. the cosine of the angle between them
        dp = np.einsum('...j, ...j -> ...', arr1, arr2)[..., None] / (arr1_ * arr2_)
        try:
            # NOTE(review): dp is a cosine, so sin(dp) is not the sine of the
            # angle - confirm this normalisation is intended.
            cp = np.einsum(string, c.eijk, arr1, arr2) / (arr1_ * arr2_ * np.sin(dp))
        except FloatingPointError:
            # Only raised because __init__ runs this under np.errstate(all='raise')
            cp = np.zeros_like(dp)
        return (np.pi / 2 - np.arccos(dp)) * (180 / np.pi), cp

    def scene(self, xyz, n, pradius, bsize):
        """Create the axis, the box and, for each particle, its spheres and arrows."""
        self.spheres = {}
        self.ispheres = {}
        self.fcones = {}
        self.mcones = {}
        # Create an XYZAxis visual
        self.axis = scene.visuals.XYZAxis(parent=self.view)
        self.axis.transform = STTransform(translate=(50, 50), scale=(50, 50, 50, 1)).as_matrix()
        self.box = visuals.Box(width=bsize, height=bsize, depth=bsize, color=None, edge_color='black', parent=self.view.scene)
        self.box.transform = STTransform(translate=[0, 0, 0])
        self.translucent(self.box)
        for i in range(n):
            # Small solid marker at the particle centre
            self.ispheres[i] = visuals.Sphere(radius=pradius[i] * 0.1, method='ico', parent=self.view.scene,
                                              color='black', subdivisions=2)
            self.ispheres[i].transform = STTransform(translate=xyz[i, self.f])
            self.cr_bsph(i, pradius[i], xyz[i, self.f])
            self.cr_arrs(i, pradius[i])

    def camera(self, fov):
        """Create the four cameras and select the turntable camera."""
        # Create four cameras (Fly, Turntable, Arcball and PanZoom)
        self.cam1 = scene.cameras.TurntableCamera(parent=self.view.scene, fov=fov, scale_factor=10,
                                                  name='Turntable')
        self.cam2 = scene.cameras.FlyCamera(parent=self.view.scene, fov=fov, scale_factor=10, name='Fly')
        self.cam3 = scene.cameras.ArcballCamera(parent=self.view.scene, fov=fov, scale_factor=10, name='Arcball')
        self.cam4 = scene.cameras.PanZoomCamera(parent=self.view.scene, aspect=1, name='PanZoom')
        self.cam4.zoom(10, center=(0.5, 0.5, 0))
        self.view.camera = self.cam1  # Select turntable at first

    def _mv_arrf(self, i):
        """Update the force arrow of particle ``i`` for the current frame."""
        self._arrow_move_to(self.fcones[i], self.frotater[i, self.f], self.trans[i, self.f])

    def _mv_arrm(self, i):
        """Update the magnetisation arrow of particle ``i`` for the current frame."""
        self._arrow_move_to(self.mcones[i], self.mrotater[i, self.f], self.trans[i, self.f])

    def _mv_arr2(self, i):
        """Update both arrows of particle ``i`` for the current frame."""
        self._arrow_move_to(self.fcones[i], self.frotater[i, self.f], self.trans[i, self.f])
        self._arrow_move_to(self.mcones[i], self.mrotater[i, self.f], self.trans[i, self.f])

    def cr_bsph(self, i, pradius, xyz):
        """Create the translucent full-size sphere of particle ``i`` at ``xyz``."""
        self.spheres[i] = visuals.Sphere(radius=pradius, method='ico', parent=self.view.scene,
                                         color=(0.5, 0.6, 1, 0.1), edge_color=(0.5, 0.6, 1, 1), subdivisions=2)
        self.translucent(self.spheres[i])
        self.spheres[i].transform = STTransform(translate=xyz)

    def _cr_arrf(self, i, pradius):
        """Create the force arrow of particle ``i`` in its initial orientation."""
        self.fcones[i] = visuals.Arrow3D(radius=pradius * 0.01, length=pradius * 0.3, parent=self.view.scene, cols=20,
                                         rows=4, color=(0.5, 0.6, 1, 1), edge_color=(0.5, 0.5, 1, 1))
        self.fcones[i].transform = MatrixTransform()
        self._arrow_move(self.fcones[i], self.initf[i], self.initfa[i], self.trans[i, self.f])

    def _cr_arrm(self, i, pradius):
        """Create the magnetisation arrow of particle ``i`` in its initial orientation."""
        self.mcones[i] = visuals.Arrow3D(radius=pradius * 0.01, length=pradius * 0.3, parent=self.view.scene, cols=20,
                                         rows=4, color=(1.0, 0.5, 0.5, 1.0), edge_color=(1.0, 0.0, 0.0, 1.0))
        self.mcones[i].transform = MatrixTransform()
        self._arrow_move(self.mcones[i], self.initm[i], self.initma[i], self.trans[i, self.f])

    def cr_arr2(self, i, pradius):
        """Create both the force and the magnetisation arrow of particle ``i``."""
        self._cr_arrf(i, pradius)
        self._cr_arrm(i, pradius)

    @staticmethod
    def spinner(arr):
        """Return a copy of ``arr`` shifted back by one along axis 1, wrapping round."""
        return np.roll(arr, -1, axis=1)

    @staticmethod
    def spinner2(arr):
        """
        Return a copy of ``arr`` shifted forward by one along its last axis,
        wrapping round.
        """
        # NOTE(review): cone_rotation passes arrays whose last axis is the
        # component axis, not the frame axis - confirm the axis is intended.
        return np.roll(arr, 1, axis=-1)

    @staticmethod
    def translucent(obj):
        """Switch the mesh of ``obj`` to translucent drawing without depth testing."""
        obj.mesh.set_gl_state('translucent', polygon_offset_fill=True,
                              polygon_offset=(1, 1), depth_test=False)

    @staticmethod
    def _arrow_move(obj, angle, raxis, trans):
        """Rotate ``obj`` by ``angle`` about ``raxis`` (if non-zero) and place it at ``trans``."""
        if angle != 0:
            obj.transform.rotate(angle, raxis)
        obj.transform.translate_to(trans)

    @staticmethod
    def _arrow_move_to(obj, rot, trans):
        """Replace the rotation of ``obj`` with ``rot`` and place it at ``trans``."""
        obj.transform.rotate_to(rot)
        obj.transform.translate_to(trans)

    def on_mouse_move(self, event):
        """Keep the axis indicator aligned with the turntable camera while dragging."""
        # Implement axis connection with cam1
        if event.button == 1 and event.is_dragging:
            self.axis.transform.reset()
            self.axis.transform.rotate(self.cam1.roll, (0, 0, 1))
            self.axis.transform.rotate(self.cam1.elevation, (1, 0, 0))
            self.axis.transform.rotate(self.cam1.azimuth, (0, 1, 0))
            self.axis.transform.scale((50, 50, 0.001))
            self.axis.transform.translate((50., 50.))
            self.axis.update()

    def on_key_press(self, event):
        """Handle the keyboard controls: camera, speed, start/stop, spheres and snapshot."""
        # Implement key presses
        if event.text == '1':
            cam_toggle = {self.cam1: self.cam2, self.cam2: self.cam3,
                          self.cam3: self.cam4, self.cam4: self.cam1}
            self.view.camera = cam_toggle.get(self.view.camera, self.cam2)
            print(self.view.camera.name + ' camera')
            self.axis.visible = self.view.camera is self.cam1
        elif event.text == '0':
            self.cam1.set_range()
            self.cam2.set_range()
            self.cam3.set_range()
        elif event.text == '+':
            # Shorten the timer interval first, then start skipping frames
            if self.timer.interval > 0.01:
                self.timer.interval -= 0.01
            else:
                self.inc *= 2
            print(self.inc, self.timer.interval)
        elif event.text == '-':
            # Undo frame skipping first, then lengthen the timer interval
            if self.inc >= 2:
                self.inc //= 2
            else:
                self.timer.interval += 0.01
            print(self.inc, self.timer.interval)
        elif event.text == '.':
            self.inc = 1
            self.timer.interval = 1 / 60
        elif event.text == '!':
            if self.timer.running:
                self.timer.stop()
            else:
                self.timer.start()
        elif event.text == '2':
            for i in self.spheres.values():
                i.visible = not i.visible
        elif event.text == 'I':
            name = f'{self.f:0{self.framefilelength}d}.png'
            imageio.imwrite(name, self.render())
            print(f'frame {name} saved')
def move_to(self, trans):
    """
    Set the translation of a transform to ``trans``.

    Attached to ``STTransform`` at module level so a sphere can be placed
    with a single call per frame.
    """
    self.translate = trans
def translate_to(self, trans):
    """
    Overwrite the last row of the transform's matrix with ``trans``.

    Attached to ``MatrixTransform`` at module level; unlike an incremental
    translate this replaces the stored translation.
    """
    self.matrix[-1] = trans
def rotate_to(self, rot):
    """
    Replace the transform's whole matrix with ``rot``.

    Attached to ``MatrixTransform`` at module level.  Any translation held
    in the old matrix is discarded, so callers apply ``translate_to``
    afterwards.
    """
    self.matrix = rot
@serverwrapper('Error')
def main():
    """
    Command-line entry point: parse ``argv`` and start the visualiser.

    Prints the module documentation, handles ``-h``, ``-s <n>``, ``-v``,
    ``-m`` and ``-l``, removes those flags from ``argv`` (the scene reads the
    file name from ``argv[1]``) and expects exactly the xyz file and one
    radius to remain.  Exits with status 1 on bad arguments.
    """
    print(__doc__)
    helpstring = '{}{}{}'.format("Run with:\n mango_vis <xyzfile> <radii> <-s -v -m -l>\n",
                                 " -s skipframes (default: 300)\n -v verbose\n",
                                 " -m record video\n -l loop visualisation")
    if '-h' in argv:
        print(helpstring)
        return
    fps = ('-v' in argv)
    video = ('-m' in argv)
    loop = ('-l' in argv)
    if '-s' in argv:
        sind = argv.index('-s')
        try:
            skip = int(argv[sind + 1])
        except (IndexError, ValueError):
            print('-s must be followed by an integer')
            print(helpstring)
            raise SystemExit(1) from None
        # Drop the flag together with its value
        del argv[sind:sind + 2]
    else:
        skip = False
    for v in ['-v', '-m', '-l']:
        if v in argv:
            argv.remove(v)
    if len(argv) != 3:
        print('xyz file or radii not given')
        print(helpstring)
        # SystemExit rather than exit(): the latter only exists when `site` is loaded
        raise SystemExit(1)
    radii = np.array(argv[2:], dtype=float)
    # A missing or zero skip falls back to the default stride of 300
    scene_setup(app.Application(), radii=radii, skip=skip or 300, freeze=False, fps=fps, video=video, sm_box=False, loop=loop)
# Attach the per-frame helpers defined in this module to vispy's transform
# classes, so that visuals can be repositioned with a single method call.
setattr(STTransform, 'move_to', move_to)
setattr(MatrixTransform, 'translate_to', translate_to)
setattr(MatrixTransform, 'rotate_to', rotate_to)

# Expose the arrow visual as a scene node next to vispy's built-in visuals.
setattr(visuals, 'Arrow3D', visuals.create_visual_node(Arrow3DVisual))

# Run the viewer only when executed as a script, not on import.
if '__main__' == __name__:
    main()