"""Qualisys data loader."""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from itertools import cycle
from .cleaning import interpolate_nan
from .pandas_utils import get_all_markers


def read_qualisys_tsv(filename, unit="m", verbose=0):
    """Reads motion capturing data from tsv into pandas data frame.

    Parameters
    ----------
    filename : str
        Source file

    unit : str, optional (default: 'm')
        Unit to measure positions. Either meters 'm' or millimeters 'mm'.
        The values in the file are taken as millimeters: for 'm' every
        column after 'Frame' and 'Time' is divided by 1000.

    verbose : int, optional (default: 0)
        Verbosity level

    Returns
    -------
    df : DataFrame
        Raw data streams from source file

    Raises
    ------
    ValueError
        If unit is neither 'm' nor 'mm'
    """
    # Validate before touching the file; an unknown unit would otherwise be
    # silently treated as millimeters.
    if unit not in ("m", "mm"):
        raise ValueError("Unit must be either 'm' or 'mm', got %r" % (unit,))

    n_kv, n_meta = _header_sizes(filename)
    # NOTE(review): only the first 7 header lines are parsed as two-column
    # key/value pairs; presumably later header lines carry more fields —
    # confirm against the export format.
    meta = pd.read_csv(
        filename, sep="\t", names=["Key", "Value"], header=None, nrows=7)
    meta = dict(zip(meta["Key"], meta["Value"]))
    if n_kv != n_meta:
        # Event lines sit between the key/value block and the data header.
        # NOTE(review): the trailing -1 presumably accounts for a separator
        # line before the data header — confirm.
        events = pd.read_csv(
            filename, sep="\t", header=None, skiprows=n_kv,
            names=["Event", "Type", "Frame", "Time"],
            nrows=n_meta - n_kv - 1)
    else:
        events = None

    # Lost markers are exported as "null" and become NaN here.
    df = pd.read_csv(
        filename, sep="\t", skiprows=n_meta, na_values=["null"])

    if unit == "m":
        # Get rid of "Time" and "Frame"
        marker_cols = df.columns[2:]
        df[marker_cols] /= 1000.0

    markers = [c[:-2] for c in df.columns if c.endswith(" X")]

    if verbose >= 1:
        print("[read_qualisys_tsv] Meta data:")
        print("  " + str(meta))
        print("[read_qualisys_tsv] Events:")
        print("  " + str(events))
        print("[read_qualisys_tsv] Available markers:")
        print("  " + (", ".join(markers)))
        print("[read_qualisys_tsv] Time delta: %g"
              % (1.0 / float(meta["FREQUENCY"])))

    return df


def _header_sizes(filename):
    """Determine number of lines in the header."""
    n_kv = 0    # Number of lines with metadata without events
    n_meta = 0  # Number of lines with metadata
    for i, l in enumerate(open(filename, "r")):
        if n_kv == 0 and l.startswith("EVENT"):
            n_kv = i
        elif l.startswith("Frame"):
            n_meta = i
            break
    if n_kv == 0:
        n_kv = n_meta
    return n_kv, n_meta


def get_trajectory(trajectories, marker, base_frame=None):
    """Get trajectory.

    Parameters
    ----------
    trajectories : DataFrame
        Trajectories from motion capture system

    marker : str
        Marker name

    base_frame : array, shape (3,) or (n_steps, 3), optional
        Position of base frame in which the trajectories will be represented.
        A single position is subtracted from every step; a full trajectory
        is subtracted step by step.

    Returns
    -------
    X : array, shape (n_steps, n_task_dims)
        Trajectory of the marker

    Raises
    ------
    ValueError
        If the marker is not contained in the trajectories
    """
    markers = get_all_markers(trajectories)
    if marker not in markers:
        raise ValueError("Marker %r must be in %r" % (marker, markers))

    X = np.array([trajectories[marker + " X"],
                  trajectories[marker + " Y"],
                  trajectories[marker + " Z"]]).T
    if base_frame is not None:
        # Out-of-place so that integer-typed columns can be combined with a
        # float base frame (in-place subtraction would fail to cast).
        X = X - base_frame
    return X


def get_segment(trajectories, split_indices, marker, segment_idx,
                base_frame=None):
    """Return one segment of a marker's trajectory.

    The full trajectory is cut at ``split_indices`` (as with ``np.split``)
    and the piece with number ``segment_idx`` is returned.

    Parameters
    ----------
    trajectories : DataFrame
        Trajectories from motion capture system

    split_indices : array-like, shape (n_segments - 1,)
        Indices at which segments are split

    marker : string
        Marker name

    segment_idx : int
        Index of the segment

    base_frame : array, shape (3,)
        Position of base frame in which the trajectories will be represented

    Returns
    -------
    X : array, shape (n_steps, n_task_dims)
        Trajectory segment of the marker

    Raises
    ------
    ValueError
        If segment_idx exceeds the number of split indices
    """
    n_splits = len(split_indices)
    if segment_idx > n_splits:
        raise ValueError("Segment index %d must be <= %d"
                         % (segment_idx, n_splits))

    full_trajectory = get_trajectory(trajectories, marker, base_frame)
    pieces = np.split(full_trajectory, split_indices)
    return pieces[segment_idx]


def get_slice(trajectory, slice):
    """Select a contiguous part of a trajectory.

    Parameters
    ----------
    trajectory : array-like, shape (n_steps, n_streams)
        Trajectory from motion capture system

    slice : Slice
        Slice object that is applied to the trajectory via indexing

    Returns
    -------
    sliced_trajectory : array-like, shape (n_slice_steps, n_streams)
        Part of trajectory
    """
    # The parameter shadows the builtin ``slice``; the name is kept because
    # callers may pass it by keyword.
    selected_part = trajectory[slice]
    return selected_part


def plot(trajectories, markers=None, show_frame=False, ax=None,
         split_indices=None, prepare_plot=True, figsize=None,
         base_frame=None, call_plot=True):
    """Plot trajectories.

    Triangles will mark the beginning and circles the end of trajectories.

    Parameters
    ----------
    trajectories : DataFrame
        Trajectories from motion capture system

    markers : list, optional (default: all markers)
        Marker names

    show_frame : bool, optional (default: False)
        Show origin and orientation of reference frame

    ax : Axis, optional (default: new axis)
        Matplotlib 3D axis

    split_indices : array-like, optional (default: no split)
        Indices at which we split the trajectories into parts

    prepare_plot : bool, optional (default: True)
        Add entries to legend, set labels

    figsize : tuple, optional (default: Matplotlib default)
        Size of the figure that is created when no axis is given

    base_frame : array, shape (3,)
        Position of base frame in which the trajectory will be plotted

    call_plot : bool, optional (default: True)
        Call show() for Matplotlib

    Returns
    -------
    ax : Axis
        Matplotlib 3D axis that contains the plot
    """
    # None instead of a shared mutable array as the default.
    if split_indices is None:
        split_indices = np.array([])

    if ax is None:
        fig = plt.figure(figsize=figsize)
        # NOTE(review): some Matplotlib versions raise NotImplementedError
        # for aspect="equal" on 3D axes — confirm against the pinned version.
        ax = fig.add_subplot(111, projection="3d", aspect="equal")

    if markers is None:
        markers = get_all_markers(trajectories)

    colors = plt.get_cmap("nipy_spectral")(np.linspace(0, 1, len(markers)))
    # HACK: yellow is not visible, shift to orange
    yellow_indices = np.logical_and(np.logical_and(colors[:, 0] > 0.5,
                                                   colors[:, 1] > 0.5),
                                    colors[:, 2] < 0.5)
    colors[yellow_indices, 1] -= 0.5

    for marker_idx, marker in enumerate(markers):
        color = colors[marker_idx]
        # Extract and split the trajectory once per marker rather than once
        # per segment.
        segments = np.split(
            get_trajectory(trajectories, marker, base_frame), split_indices)
        # Consecutive segments of one marker are told apart by line style.
        linestyles = cycle(["-", ":", "-.", "--"])
        for segment_idx, (part, linestyle) in enumerate(
                zip(segments, linestyles)):
            kwargs = {"color": color, "lw": 3, "linestyle": linestyle}
            if prepare_plot:
                # Label only the first segment so each marker appears once
                # in the legend.
                kwargs["label"] = marker if segment_idx == 0 else ""
            # Triangle at the start, default scatter marker at the end.
            ax.scatter(part[0, 0], part[0, 1], part[0, 2], marker="^",
                       color=color, s=100)
            ax.scatter(part[-1, 0], part[-1, 1], part[-1, 2],
                       color=color, s=100)
            ax.plot(part[:, 0], part[:, 1], part[:, 2], **kwargs)

    if show_frame:
        # Unit axes of the reference frame: x red, y green, z blue.
        ax.plot([0, 1], [0, 0], [0, 0], color="r", lw=5, alpha=0.2)
        ax.plot([0, 0], [0, 1], [0, 0], color="g", lw=5, alpha=0.2)
        ax.plot([0, 0], [0, 0], [0, 1], color="b", lw=5, alpha=0.2)

    if prepare_plot:
        ax.set_aspect("equal")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")
        ax.legend(loc="lower left")

    if call_plot:
        plt.show()

    return ax


class QualisysDemonstration(object):
    """Demonstrations that have been recorded with the Qualisys system.

    Parameters
    ----------
    filename : string
        File that contains motion capturing data

    unit : string, optional (default: 'm')
        Measurement unit for positions, must be in ["m", "mm"]

    verbose : int, optional (default: 0)
        Verbosity level

    Attributes
    ----------
    trajectories : DataFrame
        Raw data streams loaded from the file

    split_indices : array-like, shape (n_segments - 1,)
        Indices at which the data is split into segments (initially empty)

    base_frame : array or None
        Reference that is subtracted from all returned trajectories

    markers_ : list
        Names of markers

    dt_ : float
        Time between steps (NaN if the file contains fewer than two steps)
    """
    def __init__(self, filename, unit="m", verbose=0):
        self.filename = filename
        self.unit = unit
        self.verbose = verbose
        self.base_frame = None

        # read_qualisys_tsv returns only the data frame; the remaining
        # attributes are derived from it here.
        self.trajectories = read_qualisys_tsv(
            self.filename, self.unit, self.verbose)
        self.split_indices = np.array([])
        self.markers_ = [c[:-2] for c in self.trajectories.columns
                         if c.endswith(" X")]
        # The "Time" column is not rescaled by read_qualisys_tsv, so the
        # average spacing of its values is the time step.
        time = np.asarray(self.trajectories["Time"], dtype=float)
        if len(time) > 1:
            self.dt_ = float(time[-1] - time[0]) / (len(time) - 1)
        else:
            self.dt_ = float("nan")

    def make_relative_from(self, marker, first=False):
        """Subtract the markers positions from all generated trajectories.

        Parameters
        ----------
        marker : string
            Marker name

        first : bool, optional (default: False)
            Take only the first position as reference for all steps
        """
        # Take the reference in the original frame; otherwise a previous
        # base frame would be subtracted from the new one and repeated calls
        # would compound. Restore the old state if the lookup fails.
        previous, self.base_frame = self.base_frame, None
        try:
            reference = self.get_trajectory(marker)
        finally:
            self.base_frame = previous
        self.base_frame = reference[0] if first else reference

    def set_segmentation(self, split_indices):
        """Set segmentation.

        Parameters
        ----------
        split_indices : array-like, shape (n_segments - 1,)
            Indices that split the data into segments
        """
        self.split_indices = split_indices

    def get_segment(self, marker, segment_idx):
        """Get segment.

        Parameters
        ----------
        marker : string
            Marker name

        segment_idx : int
            Index of the segment

        Returns
        -------
        X : array, shape (n_steps, n_task_dims)
            Trajectory segment of the marker, relative to the base frame
        """
        return get_segment(
            self.trajectories, self.split_indices, marker, segment_idx,
            self.base_frame)

    def get_trajectory(self, marker):
        """Get trajectory.

        Parameters
        ----------
        marker : string
            Marker name

        Returns
        -------
        X : array, shape (n_steps, n_task_dims)
            Trajectory of the marker, relative to the base frame
        """
        return get_trajectory(
            self.trajectories, marker, self.base_frame)

    def plot(self, markers, show_frame=False, ax=None, prepare_plot=True):
        """Plot trajectories.

        Parameters
        ----------
        markers : list
            Marker names

        show_frame : bool, optional (default: False)
            Show origin and orientation of reference frame

        ax : Axis, optional (default: new axis)
            Matplotlib 3D axis

        prepare_plot : bool, optional (default: True)
            Add entries to legend, set labels

        Returns
        -------
        ax : Axis
            Matplotlib 3D axis that contains the plot
        """
        # Keywords are required: positionally, base_frame would land in the
        # figsize slot of the module-level plot().
        return plot(self.trajectories, markers, show_frame=show_frame, ax=ax,
                    split_indices=self.split_indices,
                    prepare_plot=prepare_plot, base_frame=self.base_frame)


class InterpolatedQualisysDemonstration(QualisysDemonstration):
    """Qualisys demonstration that fills in data where markers were lost.

    Trajectories returned by ``get_trajectory`` are passed through
    ``interpolate_nan``, so missing samples (loaded as NaN) are interpolated.

    Parameters
    ----------
    filename : string
        File that contains motion capturing data

    unit : string, optional (default: 'm')
        Measurement unit for positions, must be in ["m", "mm"]

    verbose : int, optional (default: 0)
        Verbosity level
    """
    def __init__(self, filename, unit="m", verbose=0):
        super(InterpolatedQualisysDemonstration, self).__init__(
            filename, unit, verbose)

    def get_trajectory(self, marker):
        """Get trajectory with interpolated data where markers were lost.

        Parameters
        ----------
        marker : string
            Marker name

        Returns
        -------
        X : array, shape (n_steps, n_task_dims)
            Trajectory of the marker
        """
        # Must delegate to the base class: calling self.get_trajectory here
        # would recurse forever.
        X = super(InterpolatedQualisysDemonstration, self).get_trajectory(
            marker)
        X = interpolate_nan(X)
        return X
