"""Data cleaning operations."""
import pandas as pd
import numpy as np
from scipy import interp
from scipy.signal import medfilt
try:
    from pytransform3d.batch_rotations import smooth_quaternion_trajectory
    from pytransform3d.trajectories import mirror_screw_axis_direction
except ImportError:
    raise ImportError(
        "Install the latest version of pytransform3d (develop branch)")


def interpolate_nan(X):
    """Remove NaNs with linear interpolation.

    This function accepts DataFrame objects and numpy arrays. When a NumPy
    array is given, exact zeros are interpreted as NaNs, too. Gaps at the
    beginning or end of a dimension are filled with the nearest valid value.

    Parameters
    ----------
    X : array, shape (n_steps, n_dims) or DataFrame
        Trajectory. A NumPy array is modified in place, a DataFrame is not.

    Returns
    -------
    X : array, shape (n_steps, n_dims) or DataFrame
        Trajectory without NaN. For array input this is the same object
        that was passed in, for DataFrame input it is a new DataFrame.

    Raises
    ------
    ValueError
        If a NumPy array is given and the whole trajectory or at least one
        of its dimensions contains only NaNs (or exact zeros), so that there
        is nothing to interpolate from.
    """
    if isinstance(X, pd.DataFrame):
        return X.interpolate(method="linear", limit_direction="both")

    missing = np.logical_or(np.isnan(X), X == 0.0)

    if np.all(missing):
        raise ValueError("Only NaN")
    # Check every dimension before touching X so that a failure does not
    # leave the caller with a partially interpolated array.
    if np.any(np.all(missing, axis=0)):
        raise ValueError("Only NaN in at least one dimension")

    steps = np.arange(X.shape[0])
    for d in range(X.shape[1]):
        gaps = missing[:, d]
        # np.interp replaces scipy.interp, a deprecated alias of the same
        # function that newer SciPy releases no longer provide.
        X[gaps, d] = np.interp(steps[gaps], steps[~gaps], X[~gaps, d])
    return X


def median_filter(X, window_size):
    """Apply a median filter to every dimension of a trajectory.

    Median filtering suppresses large jumps that stem from noisy
    measurements or from interpolation artifacts, which often appear after
    normalizing ambiguous orientation representations (such as quaternions).

    Parameters
    ----------
    X : array, shape (n_steps, n_dims) or DataFrame
        Trajectory

    window_size : int
        Number of consecutive steps over which the median is computed.
        It is passed on unchanged to ``scipy.signal.medfilt`` for arrays
        and to ``DataFrame.rolling`` for DataFrames.

    Returns
    -------
    X : array, shape (n_steps, n_dims) or DataFrame
        Filtered trajectory. Note that the two input types are handled by
        different backends: arrays are filtered column by column with
        ``scipy.signal.medfilt``, whereas DataFrames use a pandas rolling
        window, so the window alignment and the treatment of the borders
        differ between them.
    """
    if not isinstance(X, pd.DataFrame):
        filtered_columns = []
        for dim in range(X.shape[1]):
            filtered_columns.append(medfilt(X[:, dim], window_size))
        return np.column_stack(filtered_columns)

    rolling_window = X.rolling(window_size)
    return rolling_window.median()


def smooth_exponential_coordinates(Sthetas):
    """Smooth trajectories of exponential coordinates.

    Exponential coordinates of transformation are not unique. There are at
    least 2 representations of one transformation matrix. This function
    makes trajectories in exponential coordinates more smooth by detecting
    steps at which the rotational part flips its direction and replacing
    every flipped segment with the result of pytransform3d's
    ``mirror_screw_axis_direction``.

    A flipped segment that consists of a single step between two jumps is
    regarded as an outlier: it is left untouched and should be removed with
    a filter afterwards (e.g., a median filter).

    Parameters
    ----------
    Sthetas : array-like, shape (n_steps, 6)
        Exponential coordinates of transformation:
        (omega_x, omega_y, omega_z, v_x, v_y, v_z)

    Returns
    -------
    Sthetas : array, shape (n_steps, 6)
        Exponential coordinates of transformation:
        (omega_x, omega_y, omega_z, v_x, v_y, v_z)
        This is a modified copy, the input is not changed.
    """
    Sthetas = np.copy(Sthetas)
    # A jump after step k is detected when the rotational part of step k + 1
    # is closer to the negated rotational part of step k than to step k.
    diffs = np.linalg.norm(Sthetas[:-1, :3] - Sthetas[1:, :3], axis=1)
    sums = np.linalg.norm(Sthetas[:-1, :3] + Sthetas[1:, :3], axis=1)
    before_jump_indices = np.where(diffs > sums)[0].tolist()

    # Jumps come in pairs: the 1st, 3rd, ... jump opens a flipped segment,
    # the following jump closes it. An unpaired last jump opens a segment
    # that extends to the end of the trajectory.
    for k in range(0, len(before_jump_indices), 2):
        start = before_jump_indices[k] + 1
        if k + 1 < len(before_jump_indices):
            # The closing jump happens *after* this index, so the step at
            # the index itself is still flipped and has to be included.
            stop = before_jump_indices[k + 1] + 1
            if stop - start == 1:
                # outlier, will be ignored and should be filtered
                continue
        else:
            stop = len(Sthetas)
        Sthetas[start:stop] = mirror_screw_axis_direction(Sthetas[start:stop])
    return Sthetas
