"""
Object state defined by:

x_t: 13, | (p_t, q_t, v_t, H_t) full state vector
p_t: 3, | 3d position of center of mass [m]
q_t: 4, | quaternion representation for 3d orientation
v_t: 3, | velocity of center of mass [m/s]
H_t: 3, | angular momentum H_I^c [kg m^2 / s]
    about the center of mass, in the body frame
"""
import numpy as np
from misc import quaternion as qua


class FreeFloating(object):
    r"""
    Model 6dof (translation + rotation) motion imparted by thrust inputs
    in absence of other forces.

    Control-affine dynamics:

        x_dot = f(x) + B(x) u

    with the state blocks

        p_dot = v_t                                + 0
        q_dot = 0.5 * M(q) omega                   + 0
            - omega = (J_B^c)^-1 H_t
            - M(q) = qua.dot_mult_matrix(q), the quaternion
              multiplication matrix applied to omega
        v_dot = 0 + A(q)^T (f / m) u = A(q)^T b_part u
            - A(q) = qua.quaternion2rotation_matrix(q / |q|)
        H_dot = 0 + \sum_{i=1}^n (r_i \times f_i) u_i

    Approximately integrates rotation by normalizing the quaternion
    after each initial value problem (IVP) is solved.
    - Exact would be enforcing unit length throughout IVP
    - Smaller step sizes (timestep) => better accuracy
    """

    def __init__(self, mounted_thruster_model):
        """
        Build the constant parts of the dynamics from a thruster model.

        :param mounted_thruster_model: object exposing
            f  -- (3, n) force vectors, one column per input u_i
            r  -- (3, n) position vectors r_i matching the columns of f
            m  -- mass used to scale the forces (f / m)
            im -- (3, 3) inertia matrix J_B^c; must be invertible
        """
        # State dimension (not the mass): p(3) + q(4) + v(3) + H(3).
        self.m = 13
        # Constant linear drift: only p_dot = v is linear in the state.
        self.A = np.zeros((self.m, self.m))
        self.A[:3, 7:10] = np.eye(3)
        f = mounted_thruster_model.f
        r = mounted_thruster_model.r
        self.n = f.shape[1]
        # Input-matrix template; rows 7:10 depend on the orientation and are
        # filled per call in get_process_operators.
        self.B = np.zeros((self.m, self.n))
        self.b_part = f / mounted_thruster_model.m
        # Torque per unit input: column i is r_i x f_i.
        self.B[10:, :] = np.cross(r.T, f.T).T
        self.J_Bc_inv = np.linalg.inv(mounted_thruster_model.im)

    def step(self, x_t, u_t):
        """
        Evaluate the state derivative x_dot = f(x) + B(x) u.

        :param x_t: (13,) state [p, q, v, H]
        :param u_t: (n,) inputs
        :return: (13,) state derivative
        """
        operators = self.get_process_operators(x_t)
        fx, gx = operators['fx'], operators['gx']
        x_dot = fx + gx.dot(u_t)
        return x_dot

    def get_process_operators(self, x_t):
        """
        Return the operators of x_dot = f(x) + B(x) u at the current state.

        :param x_t: (13,) state [p, q, v, H]
        :return: dict with keys
            fx -- (13,) drift term f(x_t)
            gx -- (13, n) input matrix B(x_t)
            A  -- (13, 13) constant linear part used to build fx
            B  -- the same array as gx
        """
        # Work on a copy: returning self.B itself would let the next call
        # overwrite rows 7:10 of arrays callers already hold.
        B_t = self.B.copy()
        q = x_t[3:7]
        Aq = qua.quaternion2rotation_matrix(q / np.linalg.norm(q))
        B_t[7:10, :] = Aq.T.dot(self.b_part)
        omega = self.J_Bc_inv.dot(x_t[10:])
        # The raw (possibly non-unit) q is used here, whereas Aq uses q / |q|.
        fx_q_dot = 0.5 * qua.dot_mult_matrix(q).dot(omega)
        # NOTE(review): the drift of H is zero (self.A has no H rows). If H is
        # expressed in the body frame, torque-free rigid-body dynamics would
        # add -omega x H here; confirm this omission is intended.
        fx = self.A.dot(x_t)
        fx[3:7] = fx_q_dot
        return dict(fx=fx, gx=B_t, A=self.A, B=B_t)

    def reset_state(self, x_t):
        """
        Normalize the quaternion part of the state to unit length.

        Intended to be called outside of the IVP solver, between steps.

        :param x_t: (13,) float state array; modified in place
        :return: the same array x_t, with x_t[3:7] normalized
        """
        # For an ndarray this slice is a view, so the in-place division
        # updates x_t itself.
        q = x_t[3:7]
        # Floor the norm so a degenerate (zero) quaternion does not divide by 0.
        q /= max(np.linalg.norm(q), 1e-8)
        return x_t

    def get_mn(self):
        """Return (state dimension, input dimension)."""
        return self.m, self.n