import numpy as np
from active_sensing import design_1d_variance as d1v
from misc.distances import circular_1d


def plan_planar_1d_var_orientations(
        x, point_x, var_bounds, lambda_diff=0.1,
        var_fcn=None, cost_fcn=None,
        n_angles=16, var_fcn_kwargs=None, cost_fcn_kwargs=None):
    """
    Plan angles in plane to satisfy sensing accuracy based on
    spherical variance model.
    :param x: k, 2 | already planned positions in plane
        at which to choose orientation
        - assume these are equally spaced in time
        (so costs based on transitions are consistent)
    :param point_x: m, 2 | points of interest
    :param var_bounds: m, | maximum variance allowed at final timestep
        for each of m points
    :param lambda_diff: scalar weight for cost of changing choice over time
    :param var_fcn: (x, angles, y) -> M_inv: (m, k, n_angles)
        - facing at angle from x, variance of observation of y position
        - default: distance_angle_spherical_var
        - its elementwise reciprocal is what is handed to the solver
    :param cost_fcn: (x, angles) -> c: (k, n_angles)
        - cost of each choice of angle at each timestep
        - default: face_movement_direction_cost
    :param n_angles: number of candidate headings, evenly spaced from 0
        (see grid_1d_angles)
    :param var_fcn_kwargs: extra keyword arguments forwarded to var_fcn
    :param cost_fcn_kwargs: extra keyword arguments forwarded to cost_fcn;
        a non-None 'angle_last' entry is also used as the final yaw of the
        fallback plan
    :return:
        yaw: k, | feasible plan of angle for each timestep; if no feasible
            plan is found, the direction of motion at each timestep, with the
            last entry replaced by cost_fcn_kwargs['angle_last'] when given
        is_valid: True iff feasible plan found
    """
    var_fcn = var_fcn if var_fcn else distance_angle_spherical_var
    cost_fcn = cost_fcn if cost_fcn else face_movement_direction_cost
    var_fcn_kwargs = var_fcn_kwargs if var_fcn_kwargs else dict()
    cost_fcn_kwargs = cost_fcn_kwargs if cost_fcn_kwargs else dict()
    angles = grid_1d_angles(n_angles)

    M_inv = var_fcn(x, angles, point_x, **var_fcn_kwargs)
    c = cost_fcn(x, angles, **cost_fcn_kwargs)
    # the solver receives the reciprocal of the variances
    yaw_choices, is_valid = d1v.solve_instance(c, lambda_diff, 1/M_inv, var_bounds)

    if not is_valid:
        # fallback: face the direction of motion
        yaw = make_movement_direction_angles(x)
        # None means "unspecified" (as in face_movement_direction_cost),
        # so it must not overwrite the fallback heading
        angle_last = cost_fcn_kwargs.get('angle_last')
        if angle_last is not None:
            yaw[-1] = angle_last
        return yaw, is_valid

    # Extract each timestep's angle as the weighted *circular* mean of the
    # choice weights. NOTE(review): assumes yaw_choices is (k, n_angles)
    # weights over `angles` (one-hot or relaxed) -- confirm against
    # d1v.solve_instance. A plain weighted sum of angles breaks across the
    # 0 / 2*pi seam (mass near 0 and near 2*pi would average to ~pi);
    # averaging unit vectors does not.
    weights = np.asarray(yaw_choices)
    mean_sin = (weights * np.sin(angles)).sum(axis=1)
    mean_cos = (weights * np.cos(angles)).sum(axis=1)
    # wrap into the grid's [0, 2*pi) convention (up to rounding)
    yaw = np.mod(np.arctan2(mean_sin, mean_cos), 2 * np.pi)
    return yaw, is_valid


def grid_1d_angles(n_angles=16):
    """
    Evenly spaced candidate headings covering one full turn.
    :param n_angles: number of headings in the grid
    :return: n_angles, | angles starting at 0 with spacing 2*pi / n_angles
        (2*pi itself is excluded)
    """
    full_turn = 2 * np.pi
    return np.linspace(start=0, stop=full_turn, num=n_angles, endpoint=False)


def distance_angle_spherical_var(x, angles, y, a=(1., 5, .3)):
    """
    Variance of observing each point of interest for every (state, heading)
    choice, under the model
        y_hat ~ N(y, a0 + .5 exp{a1 |w - w'|} + exp{a2 ||x - y||} )
    where the state is [x, w] and w' is the heading from x toward y.
    NOTE(review): |w - w'| is whatever misc.distances.circular_1d returns
    for period 2*pi (presumably the absolute wrapped difference) -- confirm.
    :param x: k, 2 | sensor positions
    :param angles: n, | candidate headings
    :param y: m, 2 | points of interest
    :param a: 3, | [a0, a1, a2] model coefficients
    :return: M_inv: m, k, n | [i, t, j] = variance of observation of point i
        at state x[t], angle[j]
    """
    n_points = y.shape[0]
    n_states = x.shape[0]
    n_headings = angles.size

    # pairwise displacement from each sensor position to each point
    offsets = y[:, np.newaxis, :] - x[np.newaxis, ...]  # m, k, 2
    bearing = np.arctan2(offsets[..., 1], offsets[..., 0])  # m, k
    dist = np.linalg.norm(offsets, axis=2)  # m, k

    # heading error for every (point, state, heading) triple; circular_1d is
    # applied to a flat vector and the result restored to m, k, n
    heading_err = bearing[:, :, np.newaxis] - angles.reshape(1, 1, -1)
    heading_err = circular_1d(heading_err.ravel(), 2*np.pi)
    heading_err = heading_err.reshape(n_points, n_states, n_headings)

    angle_term = .5 * np.exp(a[1] * heading_err)
    range_term = np.exp(a[2] * dist[:, :, np.newaxis])
    return a[0] + angle_term + range_term


def face_movement_direction_cost(x, angles, sd=1., angle_last=None, sd_last=0.1):
    """
    Set cost(yaw_t) = |yaw_t - yaw_t^h|^2 / 2 sd^2
    where yaw_t is chosen angle at timestep t and
    yaw_t^h is that timestep's direction of motion.
    Use separate 'standard deviation' weighting for final timestep, since this
    direction is often specified by the user.
    :param x: k, 2 | [t] = position in plane at timestep t
    :param angles: n, | candidate headings
    :param sd: standard deviation weighting for timesteps 0..k-2
    :param angle_last: desired heading at the final timestep; if None, the
        final timestep uses the direction of the last motion step
    :param sd_last: standard deviation weighting for the final timestep
    :return: c: k, n | c[t] = linear cost of each of n choices at timestep t
    """
    k = x.shape[0]
    n = angles.size
    # the final entry already repeats the previous step's direction, so only
    # override it when the caller supplied a target heading
    motion_yaw = make_movement_direction_angles(x)
    if angle_last is not None:
        motion_yaw[-1] = angle_last
    dif_yaw = motion_yaw[:, np.newaxis] - angles
    dif_yaw = circular_1d(dif_yaw.ravel(), 2*np.pi).reshape(k, n)
    c = np.zeros((k, n))
    # quadratic penalty scaled by 1 / (2 sd^2), matching the formula above
    c[:-1] = dif_yaw[:-1] ** 2 / (2 * sd ** 2)
    c[-1] = dif_yaw[-1] ** 2 / (2 * sd_last ** 2)
    return c


def make_movement_direction_angles(x):
    """
    Make yaw angles corresponding to direction of motion.
    Entry t (t < k-1) is the heading of the step x[t] -> x[t+1]; the final
    entry repeats the previous one, since there is no step after it.
    A trajectory with fewer than two positions has no motion, so all
    returned angles are 0.
    :param x: k, 2
    :return: k,
    """
    x = np.asarray(x)
    k = x.shape[0]
    motion_yaw = np.zeros((k,))
    if k < 2:
        # no step to take a direction from (and no motion_yaw[-2] to copy)
        return motion_yaw
    steps = np.diff(x, axis=0)  # k-1, 2 | x[t+1] - x[t]
    motion_yaw[:-1] = np.arctan2(steps[:, 1], steps[:, 0])
    motion_yaw[-1] = motion_yaw[-2]
    return motion_yaw