import numpy as np


class TriangulatedSurface(object):
    """
    Solid body of constant density whose boundary is a triangulated surface.

    Attributes:
    rho: constant density of the solid object [kg/m^3]
    vertices: n, 3 | vertices of 3d object [m]
    triangles: m, 3 | vertex indices of the triangles composing the surface
    m: total mass [kg]
    center: 3 | center of mass [m]
    im: 3, 3 | inertia matrix relative to center of mass [kg m^2]

    m, center and im are initialized to None; this class only stores them
    and does not compute them itself.
    """

    def __init__(self, rho, vertices=None, triangles=None):
        """
        :param rho: constant density of the solid object [kg/m^3]
        :param vertices: n, 3 | vertices of 3d object [m]
        :param triangles: m, 3 | triangles composing the object's surface
            The vertices of each triangle are assumed to be consistently
            ordered; for a closed body, all surface normals pointing outward
            is sufficient.
        """
        self.rho = rho
        self.vertices = vertices
        self.triangles = triangles
        # Mass properties; left unset (None) until assigned by the caller.
        self.m = None
        self.center = None
        self.im = None


def calculate_mass_properties(vertices, triangles, rho, about_center=False):
    """
    Calculate mass properties of a solid object via surface integrals,
    following the supplemental material for:
    M. Bacher, et al., "Spin-It: Optimizing Moment of
    Inertia for Spinnable Objects", ACM Trans. Graphics, 2014.

    The volume integrals are rewritten as sums over the surface triangles
    (divergence theorem). The triangles must therefore be consistently
    ordered; for a closed body, all surface normals pointing outward is
    sufficient. Reversing the orientation flips the sign of every result.

    :param vertices: n, 3 | vertices of 3d object [m]
    :param triangles: m, 3 | vertex indices of the triangles composing the
        object's surface
    :param rho: constant density of the solid object [kg/m^3]
    :param about_center: if False (default), the inertia matrix is taken
        about the origin of the vertex coordinate frame; if True, it is
        shifted to the center of mass with the parallel axis theorem
    :return: (m, center, im)
        m: total mass [kg]
        center: 3 | center of mass [m]
        im: 3, 3 | inertia matrix [kg m^2], about the coordinate origin or
            the center of mass depending on about_center
        If the enclosed volume is zero (e.g. no triangles), m is 0 and
        center is not finite.
    :raises ValueError: if triangles cannot be viewed as an (m, 3) array
    """
    vertices = np.asarray(vertices, dtype=float)
    triangles = np.asarray(triangles, dtype=np.intp)
    if triangles.size == 0:
        # An empty list has shape (0,); give it the (0, 3) shape used below.
        triangles = triangles.reshape(0, 3)
    if triangles.ndim != 2 or triangles.shape[1] != 3:
        raise ValueError(
            "triangles must have shape (m, 3), got %s" % (triangles.shape,))

    # Corner coordinates of every triangle, each of shape (m, 3).
    a = vertices[triangles[:, 0]]
    b = vertices[triangles[:, 1]]
    c = vertices[triangles[:, 2]]
    # Unnormalized triangle normal; its length is twice the triangle area.
    n = np.cross(b - a, c - a)

    # Auxiliary polynomials from the paper, evaluated element-wise per axis.
    h1 = a + b + c
    h2 = a**2 + b * (a + b)
    h3 = h2 + c * h1
    h4 = a**3 + b * h2 + c * h3
    h5 = h3 + a * (h1 + a)
    h6 = h3 + b * (h1 + b)
    h7 = h3 + c * (h1 + c)
    # Cyclic axis shift (x, y, z) -> (y, z, x) pairs each axis with the next
    # one, which yields the mixed products xy, yz, zx.
    shift = [1, 2, 0]
    h8 = a[:, shift] * h5 + b[:, shift] * h6 + c[:, shift] * h7

    # Volume integrals (scaled by rho below) of:
    # s[0]: 1 | s[1:4]: x, y, z | s[4:7]: xy, yz, zx | s[7:10]: x^2, y^2, z^2
    s = np.empty(10)
    s[0] = np.sum(n[:, 0] * h1[:, 0]) / 6
    s[1:4] = np.sum(n * h3, axis=0) / 24
    s[4:7] = np.sum(n * h8, axis=0) / 120
    s[7:] = np.sum(n * h4, axis=0) / 60
    s *= rho

    m = s[0]
    center = s[1:4] / m
    # Inertia tensor about the coordinate origin.
    im = np.array([
        [s[8] + s[9], -s[4], -s[6]],
        [-s[4], s[7] + s[9], -s[5]],
        [-s[6], -s[5], s[7] + s[8]],
    ])
    if about_center:
        # Parallel axis theorem: I_c = I_o - m * (|c|^2 E - c c^T)
        im = im - m * (np.dot(center, center) * np.eye(3)
                       - np.outer(center, center))
    return m, center, im