import numpy as np
from pytransform3d.rotations import random_quaternion, quaternion_slerp
from pytransform3d.transformations import (
    exponential_coordinates_from_transform,
    transform_from_exponential_coordinates)
from mocap.cleaning import (
    smooth_quaternion_trajectory, smooth_exponential_coordinates,
    median_filter, mirror_screw_axis_direction)
from nose.tools import assert_true, assert_false
from numpy.testing import assert_array_almost_equal


def test_smooth_quaternion_trajectory():
    """Sign flips injected into a slerp trajectory are undone by smoothing.

    A reference trajectory is built by slerping between two random unit
    quaternions. Two segments of a copy are then negated (q and -q encode
    the same rotation). The smoothed copy must match the reference.
    """
    rng = np.random.RandomState(232)
    q_start = random_quaternion(rng)
    # NOTE(review): fixes the sign of the first sample via component 1;
    # presumably so the reference agrees with the sign the smoother keeps
    # for the first sample -- TODO confirm the intended convention.
    if q_start[1] < 0.0:
        q_start = -q_start
    q_goal = random_quaternion(rng)

    n_steps = 101
    Q = np.array([quaternion_slerp(q_start, q_goal, t)
                  for t in np.linspace(0, 1, n_steps)])

    # Negate a short interior segment and the whole tail.
    Q_broken = Q.copy()
    for flipped_rows in (slice(20, 23), slice(80, None)):
        Q_broken[flipped_rows] *= -1.0

    assert_array_almost_equal(smooth_quaternion_trajectory(Q_broken), Q)


def test_smooth_exponential_coordinates():
    """Smoothing removes large jumps between consecutive exponential coordinates.

    The fixture trajectory must contain at least one step whose L2 norm
    exceeds ``max_step``. After ``smooth_exponential_coordinates`` no step
    may exceed it.

    NOTE(review): the fixture path is relative to the current working
    directory, so this presumably has to be run from the repository root --
    confirm, or resolve it relative to this file.
    """
    max_step = 1.0  # step size (L2 norm) treated as a discontinuity
    Sthetas = np.loadtxt("test/data/screw_trajectory.txt")
    step_norms = np.linalg.norm(np.diff(Sthetas, axis=0), axis=1)
    # Sanity check on the fixture: without a discontinuity the final
    # assertion would pass trivially.
    assert np.any(step_norms > max_step), (
        "fixture has no discontinuity; test would be vacuous")

    Sthetas_smooth = smooth_exponential_coordinates(Sthetas)
    smooth_step_norms = np.linalg.norm(
        np.diff(Sthetas_smooth, axis=0), axis=1)
    # The message is only evaluated on failure, when the array is non-empty.
    assert not np.any(smooth_step_norms > max_step), (
        "largest remaining step after smoothing: %g"
        % smooth_step_norms.max())


def test_mirror_screw_axis():
    """Mirroring the screw axis direction must not change the encoded pose.

    The pose is converted to exponential coordinates and passed as a
    one-row batch to ``mirror_screw_axis_direction``. Converting the
    mirrored coordinates back must reproduce the original transform.
    """
    pose = np.array([[ 0.10156069, -0.02886784,  0.99441042,  0.6753021 ],
                     [-0.4892026 , -0.87182166,  0.02465395, -0.2085889 ],
                     [ 0.86623683, -0.48897203, -0.10266503,  0.30462221],
                     [ 0.        ,  0.        ,  0.        ,  1.        ]])

    Stheta = exponential_coordinates_from_transform(pose)
    Stheta_batch = Stheta.reshape(1, 6)
    Stheta_mirrored = mirror_screw_axis_direction(Stheta_batch)[0]

    pose_roundtrip = transform_from_exponential_coordinates(Stheta_mirrored)
    assert_array_almost_equal(pose, pose_roundtrip)


def test_median_filter():
    """A window-3 median filter removes isolated single-sample spikes.

    Case 1: one column of ones with a single huge spike.
    Case 2: five constant columns, three of which get one outlier each at
    different rows. Columns 0 and 4 are never corrupted.
    """
    # Case 1: single column, single spike.
    spiky = np.ones((100, 1))
    spiky[50, 0] = 10000.0
    assert_array_almost_equal(median_filter(spiky, 3), np.ones((100, 1)))

    # Case 2: several constant columns with outliers at different rows.
    clean = np.tile([-20.0, 1.0, -20.0, 30.0, -20.0], (100, 1))
    corrupted = clean.copy()
    for row, col, outlier in ((40, 1, 10000.0),
                              (50, 2, -10000.0),
                              (60, 3, 10000.0)):
        corrupted[row, col] = outlier
    assert_array_almost_equal(median_filter(corrupted, 3), clean)
