"""High-level library interface."""
import json
import os

import numpy as np
import pandas as pd
from .qualisys import read_qualisys_tsv
from .xsens import read_xsens_mvnx
from .pandas_utils import extract_segment
from . import normalization
from .plot import (plot_streams, plot_streams_in_rows,
                   plot_segmented_streams_in_rows)


def load(metadata=None, xsens_filename=None, qualisys_filename=None, verbose=0,
         **kwargs):
    """Load motion capture data into a :class:`Record`.

    Parameters
    ----------
    metadata : str, optional (default: None)
        Location of metadata file

    xsens_filename : str, optional (default: None)
        Location of XSens MVNX source file

    qualisys_filename : str, optional (default: None)
        Location of Qualisys TSV source file

    verbose : int, optional (default: 0)
        Verbosity level

    kwargs : dict
        Additional keyword arguments that will be passed to the data loader

    Returns
    -------
    record : Record
        The loaded motion capture record
    """
    record = Record(metadata=metadata,
                    xsens_filename=xsens_filename,
                    qualisys_filename=qualisys_filename,
                    verbose=verbose,
                    **kwargs)
    return record


class Record:
    """Motion capture record.

    The data is loaded into ``self.df`` (a pandas DataFrame). If a non-empty
    metadata file is given, its 'platform_type' and 'record_filename' entries
    decide which file is read. Otherwise the explicitly given XSens file, or
    failing that the Qualisys file, is used.

    Parameters
    ----------
    metadata : str or None
        Location of JSON metadata file

    xsens_filename : str or None
        Location of XSens MVNX source file

    qualisys_filename : str or None
        Location of Qualisys TSV source file

    verbose : int, optional (default: 0)
        Verbosity level

    kwargs : dict
        Additional keyword arguments that will be passed to the data loader

    Attributes
    ----------
    df : DataFrame
        Loaded time series. A 'Time' column is added when the data has none
        and the metadata contains a 'frequency' entry.

    metadata_content : dict
        Parsed content of the metadata file (empty dict if there is none)

    Raises
    ------
    ValueError
        If no data file can be determined, or if an explicitly given filename
        contradicts the record filename stored in the metadata.

    NotImplementedError
        If the metadata names a platform type without a parser.
    """
    def __init__(self, metadata, xsens_filename, qualisys_filename, verbose=0,
                 **kwargs):
        self.metadata = metadata
        self.xsens_filename = xsens_filename
        self.qualisys_filename = qualisys_filename
        self.verbose = verbose
        self.kwargs = kwargs

        self._load_data()

    def _load_data(self):
        """Load metadata, then the data file, then add a 'Time' column."""
        self._load_metadata()

        # A record file referenced by the metadata takes precedence over the
        # explicitly passed filenames.
        if not self._try_to_load_data_from_metadata():
            if self.xsens_filename is not None:
                self.df = read_xsens_mvnx(
                    self.xsens_filename, verbose=self.verbose, **self.kwargs)
            elif self.qualisys_filename is not None:
                self.df = read_qualisys_tsv(
                    self.qualisys_filename, verbose=self.verbose,
                    **self.kwargs)
            else:
                raise ValueError("Did not find any data file.")

        self._try_to_add_time_stream()

    def _load_metadata(self):
        """Parse the JSON metadata file into ``self.metadata_content``."""
        if self.metadata is None:
            self.metadata_content = {}
        else:
            # JSON text is UTF-8 by specification; do not depend on the
            # platform's default encoding.
            with open(self.metadata, "r", encoding="utf-8") as f:
                self.metadata_content = json.load(f)

    def _try_to_load_data_from_metadata(self):
        """Load ``self.df`` from the record file referenced by the metadata.

        Returns
        -------
        success : bool
            False if there is no metadata content, True if data was loaded

        Raises
        ------
        ValueError
            If an explicitly given filename of the same platform type differs
            from the record filename in the metadata.

        NotImplementedError
            If there is no parser for the metadata's platform type.
        """
        if not self.metadata_content:
            return False

        platform_type = self.metadata_content["platform_type"]
        filename = os.path.expanduser(self.metadata_content["record_filename"])
        if platform_type == ".tsv":
            self._check_consistent_filename(self.qualisys_filename, filename)
            # Keyword arguments must be unpacked with '**' so they reach the
            # loader as options (not as positional dict keys).
            self.df = read_qualisys_tsv(
                filename, verbose=self.verbose, **self.kwargs)
        elif platform_type == ".mvnx":
            self._check_consistent_filename(self.xsens_filename, filename)
            self.df = read_xsens_mvnx(
                filename, verbose=self.verbose, **self.kwargs)
        else:
            raise NotImplementedError(
                "No parser for platform type '%s' found." % platform_type)

        return True

    @staticmethod
    def _check_consistent_filename(given_filename, metadata_filename):
        """Raise ValueError if a given filename contradicts the metadata.

        An explicit check is used instead of ``assert``, which is removed
        when Python runs with ``-O``. The given filename is user-expanded
        because the metadata filename has already been expanded.
        """
        if (given_filename is not None
                and os.path.expanduser(given_filename) != metadata_filename):
            raise ValueError(
                "Filename '%s' does not match record filename '%s' from "
                "metadata." % (given_filename, metadata_filename))

    def _try_to_add_time_stream(self):
        """Add a 'Time' column derived from the metadata's 'frequency'.

        Nothing happens if the data already has a 'Time' column or if the
        metadata has no 'frequency' entry.
        """
        if "Time" not in self.df and "frequency" in self.metadata_content:
            dt = 1.0 / float(self.metadata_content["frequency"])
            # Scale integer sample indices instead of calling
            # np.arange(0, n * dt, dt). With a float step, rounding can make
            # that produce n + 1 values, and the assignment then fails.
            self.df["Time"] = np.arange(len(self.df)) * dt

    def get_available_streams(self):
        """Get names of available data streams.

        Returns
        -------
        stream_names : list
            Names
        """
        return list(self.df.columns)

    def get_segments_as_dataframes(
            self, label, streams, label_field="label",
            start_field="start_index", end_field="end_index"):
        """Get segments as pandas DataFrames.

        Parameters
        ----------
        label : str
            Label of the segments that should be extracted

        streams : list of str
            Regular expressions that will be used to find matching streams in
            the columns of the record's DataFrame

        label_field : str, optional (default: 'label')
            Field in the metadata file that contains the label of a segment.
            Could also be 'l1' or 'l2'.

        start_field : str, optional (default: 'start_index')
            Field in the metadata file that contains the start index of a
            segment. Could also be 'start_frame'.

        end_field : str, optional (default: 'end_index')
            Field in the metadata file that contains the end index of a
            segment. Could also be 'end_frame'.

        Returns
        -------
        trajectories : list of DataFrame
            A list of segments from the original time series

        Raises
        ------
        ValueError
            If no segment has the requested label.
        """
        dataframes = []
        for segment in self.metadata_content["segments"]:
            if segment[label_field] != label:
                continue
            start_index = int(segment[start_field])
            end_index = int(segment[end_field])
            dataframes.append(extract_segment(
                self.df, streams, start_index, end_index, keep_time=True))
        if not dataframes:
            raise ValueError("Found no segment with label '%s'" % label)
        return dataframes

    def get_segment_names(self, label_field="label"):
        """Get names of available segments.

        Parameters
        ----------
        label_field : str, optional (default: 'label')
            Field in the metadata file that contains the label of a segment.
            Could also be 'l1' or 'l2'.

        Returns
        -------
        segment_names : list of str
            A list of segment labels (one per segment, duplicates included)
        """
        return [segment[label_field]
                for segment in self.metadata_content["segments"]]

    def plot(self, streams, ax=None):
        """Plot streams with matplotlib.

        Parameters
        ----------
        streams : list of str
            Regular expressions that will be used to find matching streams in
            the columns of the record's DataFrame

        ax : Matplotlib axis, optional (default: new axis)
            Axis to which we plot the trajectory

        Returns
        -------
        ax : Matplotlib axis
            Axis to which we plotted the trajectory
        """
        return plot_streams(self.df, streams, ax)

    def plot_streams_in_rows(self, streams=None, axes=None):
        """Plot streams with matplotlib in rows.

        Parameters
        ----------
        streams : list of str
            Regular expressions that will be used to find matching streams in
            the columns of the record's DataFrame

        axes : list of Matplotlib axes, optional (default: new axes)
            Axes to which we plot the trajectories

        Returns
        -------
        axes : list of Matplotlib axes
            Axes to which we plotted the trajectories
        """
        return plot_streams_in_rows(self.df, streams, axes)

    def plot_segmented_streams_in_rows(self, label, streams=None, axes=None):
        """Plot segmented streams with matplotlib in rows.

        Parameters
        ----------
        label : str
            Label of the segments that should be plotted

        streams : list of str
            Regular expressions that will be used to find matching streams in
            the columns of the record's DataFrame

        axes : list of Matplotlib axes, optional (default: new axes)
            Axes to which we plot the trajectories

        Returns
        -------
        axes : list of Matplotlib axes
            Axes to which we plotted the trajectories
        """
        trajectories = self.get_segments_as_dataframes(label, streams)
        return plot_segmented_streams_in_rows(trajectories, streams, axes)


def to_frequency(trajectories, target_frequency):
    """Resample one or several trajectories to a new frequency.

    Parameters
    ----------
    trajectories : list of DataFrame or DataFrame
        Time series data. A single DataFrame is resampled directly; any other
        iterable is treated as a collection of DataFrames.

    target_frequency : float
        Target frequency

    Returns
    -------
    trajectories : list of DataFrame or DataFrame
        Resampled trajectories (a list unless a single DataFrame was given)
    """
    if not isinstance(trajectories, pd.DataFrame):
        return [normalization.to_frequency(trajectory, target_frequency)
                for trajectory in trajectories]
    return normalization.to_frequency(trajectories, target_frequency)


def scale_duration(trajectories, target_duration, dt, kind="linear"):
    """Temporally scale one or several trajectories to a target duration.

    Parameters
    ----------
    trajectories : list of DataFrame or DataFrame
        Time series data. A single DataFrame is scaled directly; any other
        iterable is treated as a collection of DataFrames.

    target_duration : float
        Target duration

    dt : float
        Time between each step

    kind : str, optional (default: 'linear')
        Kind of interpolation. See scipy.interpolate.interp1d for details.

    Returns
    -------
    trajectories : list of DataFrame or DataFrame
        Temporally scaled time series data (a list unless a single DataFrame
        was given)
    """
    if not isinstance(trajectories, pd.DataFrame):
        return [
            normalization.scale_duration(
                trajectory, target_duration, dt, kind)
            for trajectory in trajectories
        ]
    return normalization.scale_duration(
        trajectories, target_duration, dt, kind)


def start_at_origin(trajectories, marker):
    """Shift one or several trajectories so a marker starts at (0, 0, 0).

    Parameters
    ----------
    trajectories : list of DataFrame or DataFrame
        Time series data. A single DataFrame is shifted directly; any other
        iterable is treated as a collection of DataFrames.

    marker : str
        Marker name, each markers must have three corresponding streams:
        marker + ' X', marker + ' Y', and marker + ' Z'.

    Returns
    -------
    trajectories : list of DataFrame or DataFrame
        Time series data with marker starting at position (0, 0, 0) (a list
        unless a single DataFrame was given)
    """
    if not isinstance(trajectories, pd.DataFrame):
        return [normalization.start_at_origin(trajectory, marker)
                for trajectory in trajectories]
    return normalization.start_at_origin(trajectories, marker)
