import numpy as np


def pt2line_min_dists(a, b0, b1):
    """
    Find the distance of each point to each line segment.

    :param a: n, d | n points in R^d
    :param b0: k, d | start points of k line segments
    :param b1: k, d | end points of k line segments
    :return: dists: n, k | [i, j] = distance of a[i] to segment b0[j]-b1[j]

    Each point is projected onto the infinite line through a segment. The
    projection parameter is clamped to [0, 1] so the closest point stays on
    the segment, and the Euclidean distance to that closest point is
    returned. A zero-length segment (b0[j] == b1[j]) yields the distance from
    a[i] to the point b0[j]. Integer inputs are computed in float64; floating
    inputs keep their precision.
    """
    a = np.asarray(a)
    b0 = np.asarray(b0)
    b1 = np.asarray(b1)
    # Integer arithmetic would truncate/NaN in the division below, so work in
    # floating point while preserving an existing float dtype (e.g. float32).
    dtype = np.result_type(a, b0, b1)
    if not np.issubdtype(dtype, np.floating):
        dtype = np.float64
    a = a.astype(dtype, copy=False)
    b0 = b0.astype(dtype, copy=False)
    b1 = b1.astype(dtype, copy=False)

    seg = b1 - b0  # k, d
    diff = a[:, np.newaxis, :] - b0[np.newaxis, :, :]  # n, k, d: a[i] - b0[j]
    dots = np.einsum("nkd,kd->nk", diff, seg)  # n, k
    seg_sq_len = np.einsum("kd,kd->k", seg, seg)  # k
    # Only exactly-degenerate segments need special handling: there dots == 0
    # and the closest point is b0 itself (t = 0). An absolute epsilon would
    # distort results for legitimately short segments at small scales.
    t = np.divide(dots, seg_sq_len, out=np.zeros_like(dots), where=seg_sq_len > 0)
    np.clip(t, 0, 1, out=t)  # keep the closest point on the segment
    # a - (b0 + t * seg) == (a - b0) - t * seg
    rej = diff - t[:, :, np.newaxis] * seg[np.newaxis, :, :]  # n, k, d
    return np.sqrt(np.einsum("nkd,nkd->nk", rej, rej))