"""Visualization utilities based on Open3D."""
import numpy as np
from .pandas_utils import extract_markers
from pytransform3d import visualizer as pv
import open3d as o3d


def plot_markers(figure, df, markers, colors=None):
    """Plot the trajectories of selected markers.

    For every marker name, the marker's columns are extracted from ``df``
    (without the time column) and the resulting array is handed to
    ``figure.plot``.

    Parameters
    ----------
    figure : Figure
        Figure to which the artists will be added.

    df : DataFrame
        Time series data

    markers : list of str
        Names of Qualisys markers that should be displayed

    colors : list of tuple, optional (default: None)
        One color per marker. A color is represented by 3 values between
        0 and 1 that indicate red, green, and blue respectively. If None,
        ``c=None`` is passed to ``figure.plot`` for every marker.
        NOTE(review): markers and colors are paired with ``zip``, so if
        fewer colors than markers are given, the trailing markers are not
        plotted.
    """
    marker_colors = [None] * len(markers) if colors is None else colors
    for name, rgb in zip(markers, marker_colors):
        trajectory = extract_markers(df, [name], keep_time=False).to_numpy()
        figure.plot(trajectory, c=rgb)


class PointCollection(pv.Artist):
    """Collection of points, each drawn as a sphere mesh.

    Parameters
    ----------
    P : array-like, shape (n_points, 3)
        Initial positions of the points. A point that contains NaN stays at
        the origin until a later call to ``set_data`` provides a valid
        position.

    s : float, optional (default: 0.05)
        Radius of the spheres that will be drawn

    c : array-like, shape (3,) or (n_points, 3), optional (default: None)
        Color(s) of the spheres. A color is represented by 3 values between
        0 and 1 that indicate red, green, and blue respectively. A single
        color of shape (3,) is applied to every point. If None, no vertex
        colors are assigned and the visualizer's default mesh coloring is
        used.
    """
    def __init__(self, P, s=0.05, c=None):
        # One sphere per point. All spheres are created at the origin and
        # self.P tracks their current positions, so set_data() can move
        # them by the displacement from the origin.
        self.markers = []
        self.P = np.zeros_like(P)

        if c is not None:
            c = np.asarray(c)
            if c.ndim == 1:
                # Broadcast a single color to every point.
                c = np.tile(c, (len(P), 1))

        for i in range(len(P)):
            marker = o3d.geometry.TriangleMesh.create_sphere(radius=s)
            if c is not None:
                # Paint every vertex of this sphere with the point's color.
                marker.vertex_colors = o3d.utility.Vector3dVector(
                    np.full((len(marker.vertices), 3), c[i], dtype=float))
            marker.compute_vertex_normals()
            self.markers.append(marker)

        self.set_data(P)

    def set_data(self, P):
        """Move the spheres to new positions.

        Parameters
        ----------
        P : array-like, shape (n_points, 3)
            New positions. A point that contains any NaN keeps its previous
            position.
        """
        P = np.copy(P)
        for i, (marker, p, previous_p) in enumerate(
                zip(self.markers, P, self.P)):
            if np.any(np.isnan(p)):
                # Missing measurement: keep the sphere where it is and
                # record that position as the current one.
                P[i] = previous_p
            else:
                # translate() is applied relative to the current position,
                # so move by the displacement since the last update.
                marker.translate(p - previous_p)

        self.P = P

    @property
    def geometries(self):
        """Expose geometries.

        Returns
        -------
        geometries : list
            List of geometries that can be added to the visualizer.
        """
        return self.markers


def scatter(figure, P, s=0.05, c=None):
    """Draw points as spheres and add them to a figure.

    Parameters
    ----------
    figure : Figure
        Figure to which the artist will be added.

    P : array, shape (n_points, 3)
        Points

    s : float, optional (default: 0.05)
        Radius of the spheres that will be drawn

    c : array-like, shape (3,) or (n_points, 3), optional (default: None)
        Color(s) of the spheres. A color is represented by 3 values between
        0 and 1 that indicate red, green, and blue respectively.

    Returns
    -------
    artist : PointCollection
        Artist that has been added to the figure
    """
    point_collection = PointCollection(P, s, c)
    point_collection.add_artist(figure)
    return point_collection
