"""
Implements the Probabilistic Roadmap (PRM) algorithm for global path planning in Euclidean space.

The algorithm:
0) samples attainable configurations in the world,
1) builds a weighted graph whose edge weights are the costs of moving between configurations,
2) applies a shortest-path algorithm to find the lowest-weight path between the given graph vertices.

The version here does not maintain state, and cannot be iteratively updated.
"""
import numpy as np
from misc import projection as prj
import scipy.sparse as sp


def sample_near_obstacles_radially(obs, n_samples=4, mu_r=1., sig_r=.2, rng=None):
    """
    Around each obstacle sample n_samples positions in R^d.

    Each sample is placed along a uniformly random direction from its obstacle
    center (a normalized isotropic Gaussian draw) at a signed distance
    r ~ N(mu_r, sig_r^2). A negative r flips the sample to the opposite side
    of the center, so the actual distance is |r|.
    - a direction draw with (near) zero norm is replaced by the first basis
      vector to avoid dividing by zero (this happens with probability zero)
    :param obs: n, d | n obstacle centers in d dimensions (array-like)
    :param n_samples: number of samples drawn around every obstacle
    :param mu_r: mean of the sample-to-center distance
    :param sig_r: standard deviation of the sample-to-center distance
    :param rng: optional numpy Generator used for all random draws; when None
        the legacy global np.random state is used (original behavior)
    :return: x: n_samples*n, d | sampled positions; row s*n + k belongs to
        obstacle k
    """
    if rng is None:
        # Legacy global-state sampler: keeps draws identical for callers that
        # seed np.random directly.
        def normal(shape):
            return np.random.randn(*shape)
    else:
        def normal(shape):
            return rng.standard_normal(shape)

    obs = np.asarray(obs)
    n, d = obs.shape
    x = normal((n_samples, n, d))
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    zero_mask = norms < 1e-8
    norms[zero_mask] = 1.
    x /= norms
    # Degenerate direction draws fall back to a fixed unit vector.
    unit_vec = np.zeros((d,))
    unit_vec[0] = 1.
    x[zero_mask[..., 0]] = unit_vec
    r = normal(norms.shape) * sig_r + mu_r
    x *= r
    x += obs  # broadcasts (n, d) over the leading n_samples axis
    return x.reshape(n_samples * n, d)


def make_complete_graph(x):
    """
    Build the dense adjacency matrix of the complete graph on the points x.

    Every pair of points is joined by an edge weighted with the euclidean
    distance between them; the diagonal is zero.
    :param x: n, d | n points in R^d
    :return: g: n, n | symmetric adjacency matrix for an undirected weighted graph
        - [i, j] = distance(x[i], x[j])
    """
    offsets = x[:, np.newaxis] - x[np.newaxis]  # offsets[i, j] = x[i] - x[j]
    return np.linalg.norm(offsets, axis=-1)


def remove_colliding_edges(g, vertices, obs, obs_r):
    """
    Delete, in place, the edges of g that pass too close to an obstacle.

    The edge between vertices[i] and vertices[j] is removed when
    prj.pt2line_min_dists reports a distance below obs_r for at least one
    obstacle (NOTE(review): presumably a point-to-segment distance - confirm
    against misc.projection).
    :param g: n, n | dense adjacency matrix for undirected weighted graph
    :param vertices: n, d | n points in R^d
    :param obs: m, d | obstacle positions
    :param obs_r: m, | obstacle radii - used for removing graph edges
    :return:
        None; g is modified so colliding distances are zero (= no-edge)
    """
    n_vertices = vertices.shape[0]
    # Every unordered pair exactly once (row < col), in row-major order.
    rows, cols = np.triu_indices(n_vertices, k=1)
    dists = prj.pt2line_min_dists(obs, vertices[rows], vertices[cols])  # m, k
    blocked = (dists.T < obs_r).any(axis=1)  # k, | edge hits some obstacle
    g[rows[blocked], cols[blocked]] = 0
    # Mirror every zero into the transposed position so g stays symmetric.
    g[g.T == 0] = 0


def get_path_list(predecessors, i, j):
    """
    Retrieve shortest path i->j from predecessors list.

    Walks backwards from j following predecessors until i is reached.
    :param predecessors: n, | predecessor indices for graph with n vertices;
        negative entries mark vertices without a predecessor (as in
        scipy.sparse.csgraph output)
    :param i: index of starting point in graph
    :param j: index of end point in graph
    :return: path_inds | list of indices [i, ..., j]
    :raises ValueError: if j is not reachable from i through predecessors,
        or if predecessors contains a cycle
    """
    target = j
    n = len(predecessors)
    path_inds = [j]
    while j != i:
        # A simple path has at most n vertices; more means a cycle.
        if len(path_inds) > n:
            raise ValueError(
                f"predecessors contains a cycle; no path from {i} to {target}")
        j = predecessors[j]
        if j < 0:
            # Reached a vertex with no predecessor before reaching i.
            raise ValueError(
                f"no path from vertex {i} to vertex {target}: "
                f"vertex {path_inds[-1]} has no predecessor")
        path_inds.append(j)
    return path_inds[::-1]


def prm(x, y, obs, obs_r, sample_kwargs=None, graph_fcn=make_complete_graph):
    """
    Plan a path from x to y around obstacles with a Probabilistic Roadmap.

    Candidate vertices are sampled around the obstacles and stacked between
    the start (index 0) and goal (last index). ``graph_fcn`` builds the graph,
    and edges colliding with obstacles are removed. Dijkstra's algorithm then
    finds the lowest-weight path from start to goal.

    :param x: d, | start position in euclidean space
    :param y: d, | end position
    :param obs: m, d | obstacle positions
    :param obs_r: m, | obstacle radii - used for removing graph edges
    :param sample_kwargs: dict | keyword arguments forwarded to
        sample_near_obstacles_radially (e.g. n_samples, mu_r, sig_r)
    :param graph_fcn: n, d -> n, n | builds dense adjacency matrix
        for undirected weighted graph (0 = no edge)
    :return:
        path_vertices: k, d | vertices forming shortest path found
        - [0] = x, [-1] = y
    :raises ValueError: if the goal is not reachable from the start in the
        pruned roadmap
    """
    from scipy.sparse.csgraph import dijkstra

    sample_kwargs = sample_kwargs or {}
    vertices = sample_near_obstacles_radially(obs, **sample_kwargs)
    vertices = np.vstack((x, vertices, y))
    start, goal = 0, vertices.shape[0] - 1
    g = graph_fcn(vertices)
    remove_colliding_edges(g, vertices, obs, obs_r)
    dist, predecessors = dijkstra(
        csgraph=g, directed=False, indices=start, return_predecessors=True)
    # Unreachable vertices get infinite distance; fail clearly instead of
    # walking an invalid predecessor chain.
    if not np.isfinite(dist[goal]):
        raise ValueError(
            "no collision-free path from start to goal in the sampled "
            "roadmap; try sampling more vertices")
    path_inds = get_path_list(predecessors, start, goal)
    return vertices[path_inds]