r"""
Use the Newton-Euler equations and simplifications resulting from:

1) Object motion occurs only in 2d plane
- cross product formulas are simplified
- torques restricted to the z-axis (following right-handed axes, plane is xy)
2) Moment of inertia is a multiple of identity matrix
- the gyroscopic term \omega \times (I \omega) vanishes

Object state defined by:

x_t: 6, | (p_t, v_t, \theta_t, \omega_t)
p_t: 2, | 2d position of center of mass [m]
v_t: 2, | velocity of center of mass [m/s]
\theta_t: yaw [rad]
\omega_t: yaw rate [rad/s]

Motion of the rigid body is determined by n thrusters; thruster i applies force f_i scaled by its control input u_i
"""
import numpy as np
from scipy import linalg as la
from misc import matrix_building as mb


class NoDragCts(object):
    r"""
    Continuous-time planar rigid-body model without drag:
        x_dot = Ac x + Bc(x) u

    Bc = ( 0_{2,k}   =  ( 0_{2,k}
           B_t            rot(\theta_t)f/m
           0_{1,k}        0_{1,k}
           c'      )      c'      )
    where c'u = 1/i_m \sum_{i=1}^n (-r_{i,2} ; r_{i,1})'f_i u_i
    The interpretation is that Bc applies forces for linear momentum
    in the global frame using current rotation, and torques in the local frame
    since no flips result from rotations in the plane.
    """

    def __init__(self, mounted_thruster_model):
        """
        :param mounted_thruster_model: object exposing
            f: (2, n) array, column i is the force of thruster i (rotated by
               the current yaw in get_process_operators, so presumably given
               in the body frame)
            r: (2, n) array, column i is the mount position of thruster i
               (presumably relative to the center of mass -- confirm)
            m: mass [kg]
            im: scalar moment of inertia about the z-axis
        """
        # NOTE: self.m is the state dimension (p, v, theta, omega), not the mass.
        self.m = 6
        self.Ac = np.zeros((self.m, self.m))
        self.Ac[:2, 2:4] = np.eye(2)  # p_dot = v
        self.Ac[4, -1] = 1.  # theta_dot = omega
        f = mounted_thruster_model.f
        r = mounted_thruster_model.r
        self.n = f.shape[1]
        # Per-thruster force divided by mass; the float literal keeps this a
        # true division even for an integer mass under Python 2.
        self.b_part = 1. / mounted_thruster_model.m * f
        # Template for Bc: the velocity rows 2:4 are refilled on every call to
        # get_process_operators; the last row (omega_dot) is constant.
        self.Bc_t = np.zeros((self.m, self.n))
        # z-component of r x f for each thruster, divided by the inertia.
        self.Bc_t[-1, :] = (-r[1, :] * f[0, :] + r[0, :] * f[1, :]) / \
            mounted_thruster_model.im

    def step(self, x_t, u_t):
        """
        Evaluate the continuous dynamics.

        :param x_t: state (p, v, theta, omega), length 6
        :param u_t: thruster inputs, length n
        :return: x_dot = A x_t + B(x_t) u_t
        """
        operators = self.get_process_operators(x_t)
        A, B_t = operators['A'], operators['B']
        x_dot = A.dot(x_t) + B_t.dot(u_t)
        return x_dot

    def get_process_operators(self, x_t):
        """
        Return operators which may be based on current state.

        :param x_t: state (p, v, theta, omega); only the yaw x_t[4] is read
        :return: dict{A, B(x_t)}. A is the model's shared Ac array and must
            not be modified by callers. B is a new array on every call, so a
            B returned earlier is not overwritten by later calls.
        """
        # Rotate the per-thruster accelerations by the current yaw so linear
        # forces act in the global frame.
        self.Bc_t[2:4, :] = mb.rot_2d(x_t[4]).dot(self.b_part)
        # The template is reused across calls; hand out an independent copy.
        return dict(A=self.Ac, B=self.Bc_t.copy())

    def get_mn(self):
        """Return (state dimension, number of thrusters)."""
        return self.m, self.n


class AccelApproxDragCts(object):
    """
    Simulate drag by exponential decay to velocity, angular velocity.

    Wraps a continuous model with the 6-state layout (p, v, theta, omega) and
    replaces its kinematic couplings by p_dot = alpha * v and
    theta_dot = alpha * omega. The wrapped model itself is left unmodified.

    NOTE(review): this scales how velocity feeds position rather than adding
    decay terms such as v_dot = -k v, omega_dot = -k omega -- confirm that this
    approximation is what "exponential decay" is meant to describe.
    """
    def __init__(self, no_drag_cts_model, alpha):
        """
        :param no_drag_cts_model: continuous model exposing
            get_process_operators(x_t) -> dict(A, B) and get_mn()
        :param alpha: in (0, 1) | 1 == no drag model
        """
        self.no_drag_cts_model = no_drag_cts_model
        self.alpha = alpha

    def step(self, x_t, u_t):
        """
        Evaluate the continuous dynamics.

        :return: x_dot = A x_t + B(x_t) u_t
        """
        operators = self.get_process_operators(x_t)
        A, B_t = operators['A'], operators['B']
        x_dot = A.dot(x_t) + B_t.dot(u_t)
        return x_dot

    def get_process_operators(self, x_t):
        """
        Return the wrapped model's operators with the alpha-scaled A.

        :param x_t: current state, forwarded to the wrapped model
        :return: dict{A, B(x_t)}; A is a fresh array, B is the wrapped model's
        """
        operators = self.no_drag_cts_model.get_process_operators(x_t)
        # Copy (as float) so the wrapped model's A is never mutated.
        A = np.array(operators['A'], dtype=float)
        A[:2, 2:4] = np.eye(2) * self.alpha
        A[4, -1] = 1. * self.alpha
        return dict(A=A, B=operators['B'])

    def get_mn(self):
        """Return (state dimension, input dimension) of the wrapped model."""
        return self.no_drag_cts_model.get_mn()


class ZeroOrderHoldDiscrete(object):
    """
    Discretize a continuous-time model with a zero-order hold (ZOH).

    The control u_t is assumed constant over each sampling interval. For
        x_dot(t) = A x(t) + B u(t)
    the discrete update is
        x_{t+1} = A_d x_t + B_d u_t,
    with
        A_d = expm{A dt}
        B_d = (itg_{tau=0}^{dt} expm{A tau} dtau) B
    Both factors are read off one matrix exponential of the augmented matrix
    [[A, I], [0, 0]] * dt.

    This is exact for linear time-invariant systems. If the wrapped model's
    operators depend on the state, they are evaluated once at x_t and held
    over the whole interval, so the result is then an approximation.
    """

    def __init__(self, cts_model, dt):
        """
        :param cts_model: continuous model exposing
            get_process_operators(x_t) -> dict(A, B) and get_mn()
        :param dt: sampling period [s]
        """
        self.cts_model, self.dt = cts_model, dt
        # Placeholders; replaced by the most recent get_process_operators call.
        self.Ad_t, self.Bd_t = np.zeros(()), np.zeros(())

    def step(self, x_t, u_t):
        """
        Advance the state by one sampling period.

        :return: x_{t+1} = A_d x_t + B_d u_t
        """
        ops = self.get_process_operators(x_t)
        return ops['A'].dot(x_t) + ops['B'].dot(u_t)

    def get_process_operators(self, x_t):
        """
        Discretize the continuous operators evaluated at x_t.

        Also caches the results on self.Ad_t / self.Bd_t.
        :return: dict{A: A_d, B: B_d}
        """
        cts_ops = self.cts_model.get_process_operators(x_t)
        a_cts = cts_ops['A']
        dim = a_cts.shape[0]
        zero_block = np.zeros_like(a_cts)
        augmented = np.block([[a_cts, np.eye(dim)],
                              [zero_block, zero_block]])
        # Top row of expm(augmented * dt): [A_d, integral of expm(A tau)].
        exp_aug = la.expm(augmented * self.dt)
        self.Ad_t = exp_aug[:dim, :dim]
        self.Bd_t = exp_aug[:dim, dim:].dot(cts_ops['B'])
        return dict(A=self.Ad_t, B=self.Bd_t)

    def get_mn(self):
        """Return (state dimension, input dimension) of the wrapped model."""
        return self.cts_model.get_mn()