"""Plotting functionality."""
import matplotlib.pyplot as plt
from .pandas_utils import match_columns
from itertools import cycle


# Hex colour codes of the seaborn "deep" palette (per the name), in order.
# The plotting functions in this module cycle through these entries to give
# successive streams distinct colours.
SEABORN_PALETTE_DEEP = [
    "#4C72B0",
    "#DD8452",
    "#55A868",
    "#C44E52",
    "#8172B3",
    "#937860",
    "#DA8BC3",
    "#8C8C8C",
    "#CCB974",
    "#64B5CD",
]


def plot_streams(trajectory, streams, ax=None):
    """Plot every matching stream of a trajectory on a single axis.

    Parameters
    ----------
    trajectory : DataFrame
        Time series data. If it contains a "Time" column, that column is
        used as the x-axis; otherwise positional step numbers 0, 1, ...
        are used.

    streams : list of str
        Regular expressions used to select the columns of 'trajectory'
        that are plotted.

    ax : Matplotlib axis, optional (default: new axis)
        Axis that receives the curves. When omitted, a new subplot is
        created, and that new axis additionally gets a legend and an
        x-axis label.

    Returns
    -------
    ax : Matplotlib axis
        The axis the streams were drawn on.
    """
    created_here = ax is None
    if created_here:
        ax = plt.subplot(111)

    selected = match_columns(trajectory, streams, keep_time=False)

    has_time = "Time" in trajectory
    x_values = trajectory["Time"] if has_time else range(len(trajectory))
    x_label = "Time [s]" if has_time else "Step"

    # Pair each selected column with the next palette colour, wrapping
    # around when there are more columns than colours.
    for column, color in zip(selected, cycle(SEABORN_PALETTE_DEEP)):
        ax.plot(x_values, trajectory[column], color=color, label=column)

    # Decorations are only added to axes this function created itself, so
    # a caller-provided axis keeps whatever styling the caller chose.
    if created_here:
        ax.legend()
        ax.set_xlabel(x_label)

    return ax


def plot_segmented_streams_in_rows(trajectories, streams=None, axes=None):
    """Plot the streams of several trajectory segments onto shared rows.

    Each segment is drawn with ``plot_streams_in_rows``. The axes returned
    by one call are passed to the next, so every segment is drawn onto the
    same set of rows.

    Parameters
    ----------
    trajectories : list of DataFrame
        Segments of time series data, plotted in order.

    streams : list of str
        Regular expressions used to select the columns of each segment.
        The value is forwarded unchanged to ``plot_streams_in_rows``.

    axes : list of Matplotlib axes, optional (default: new axes)
        Axes to draw into, one per matched stream. When omitted, the first
        segment creates them.

    Returns
    -------
    axes : list of Matplotlib axes
        The axes holding the plotted segments. If 'trajectories' is empty,
        this is the 'axes' argument itself.
    """
    shared_axes = axes
    for segment in trajectories:
        # Thread the axes through so later segments reuse earlier rows.
        shared_axes = plot_streams_in_rows(segment, streams, shared_axes)
    return shared_axes


def plot_streams_in_rows(trajectory, streams, axes=None):
    """Plot each matching stream of a trajectory in its own row.

    Parameters
    ----------
    trajectory : DataFrame
        Time series data. If it contains a "Time" column, that column is
        used as the x-axis; otherwise positional step numbers 0, 1, ...
        are used.

    streams : list of str
        Regular expressions used to select the columns of 'trajectory'
        that are plotted. The i-th matched column is drawn on axis i.

    axes : list of Matplotlib axes, optional (default: new axes)
        Indexable sequence holding at least one axis per matched column.
        When omitted, a vertical stack of subplots (one per matched
        column) is created, and each new axis also gets a legend and an
        x-axis label.

    Returns
    -------
    axes : list of Matplotlib axes
        The axes the streams were drawn on.

    Raises
    ------
    IndexError
        If 'axes' is given but holds fewer axes than there are matched
        columns. This is checked before anything is drawn.
    """
    columns = match_columns(trajectory, streams, keep_time=False)
    n_rows = len(columns)

    newaxes = axes is None
    if newaxes:
        axes = [plt.subplot(n_rows, 1, 1 + i) for i in range(n_rows)]
    elif len(axes) < n_rows:
        # Fail up front rather than drawing some rows and then hitting a
        # bare "list index out of range" part-way through the loop below.
        raise IndexError(
            "{} streams matched but only {} axes were given".format(
                n_rows, len(axes)))

    if "Time" in trajectory:
        t = trajectory["Time"]
        xlabel = "Time [s]"
    else:
        t = range(len(trajectory))
        xlabel = "Step"

    # Colours restart from the first palette entry on every call, so the
    # same row gets the same colour when this is called repeatedly with
    # the same axes (e.g. once per trajectory segment).
    colors = cycle(SEABORN_PALETTE_DEEP)

    for i, c in enumerate(columns):
        color = next(colors)
        axes[i].plot(t, trajectory[c], color=color, label=c)
        # Only decorate axes created here; caller-provided axes are left
        # as the caller configured them.
        if newaxes:
            axes[i].legend()
            axes[i].set_xlabel(xlabel)

    return axes
