import plotly.graph_objects as go
from misc import quaternion
from misc.matrix_building import Rt_3d
from vis import triangulated_surface as vtrs
from vis import triangulated_thruster_model as vttm


def _pose_transform(position, orientation, is_quaternion):
    """Build the ``Rt_3d`` transform for a single trajectory sample.

    :param position: (3,) | position of the model center for this sample
    :param orientation: (3,) exponential map, or (4,) quaternion if
        ``is_quaternion`` is True
    :param is_quaternion: selects how ``orientation`` is interpreted
    :return: whatever ``Rt_3d`` returns (presumably a rotation+translation
        transform — confirm in misc.matrix_building)
    """
    if is_quaternion:
        Aq = quaternion.quaternion2rotation_matrix(orientation)
    else:
        Aq = quaternion.exp_map2rotation_matrix(orientation)
    # NOTE(review): the rotation is transposed before building the transform;
    # presumably this matches the frame convention Rt_3d / the helpers expect —
    # confirm before changing.
    return Rt_3d(Aq.T, position)


def _scene_layout(center, half_width):
    """Return a fixed-range, cube-aspect 3d scene dict centered on ``center``.

    Each axis spans ``[c - half_width, c + half_width]`` with autorange off,
    so the view tracks the model without rescaling between frames.
    """
    def axis(c):
        c = float(c)  # plain floats keep the layout JSON-friendly
        return dict(range=[c - half_width, c + half_width], autorange=False)

    return dict(
        xaxis=axis(center[0]),
        yaxis=axis(center[1]),
        zaxis=axis(center[2]),
        aspectmode='cube',
    )


def loop_sim(model, x, q, u, dt, scene=None, is_quaternion=False,
             view_half_width=10, **kwargs):
    """Animate a triangulated model (plus thrust vectors) along a trajectory.

    One plotly frame is built per sample. Each frame draws the model mesh and its
    thrust vectors at that sample's pose. The view is a cube of half-width
    ``view_half_width`` centered on the model. The figure is shown with a
    'Play' button and is also returned.

    :param model: triangulated surface
    :param x: k, 3 | 3d position of center of model (indexed as ``x[i, j]``,
        so expected to be a numpy array)
    :param q: k, 3 | exponential map representation for 3d orientation
        - iff is_quaternion=True, instead treat as k, 4 quaternions
    :param u: k, n | control inputs, passed per sample as ``scales`` to
        ``make_thrust_vectors``
    :param dt: time between consecutive samples; frame ``i`` is titled
        ``t = i * dt`` seconds
    :param scene: accepted for compatibility; currently unused
    :param is_quaternion: interpret ``q`` rows as quaternions instead of
        exponential maps
    :param view_half_width: half the side length of the cubic view volume
        around the model center (default 10)
    :param kwargs: accepted for compatibility; currently unused
    :return: the plotly ``go.Figure`` that was shown
    :raises ValueError: if the trajectory has no samples
    """
    k = x.shape[0]
    if k == 0:
        # Without this the function fails later with an opaque IndexError.
        raise ValueError('loop_sim needs at least one sample to animate; x is empty')

    frame_data = []
    for i in range(k):
        Rt = _pose_transform(x[i], q[i], is_quaternion)
        traces = [vtrs.make_model_mesh(model, Rt)]
        traces.extend(vttm.make_thrust_vectors(model, scales=u[i], Rt=Rt))
        frame_data.append(traces)

    drawn_frames = [
        go.Frame(
            data=traces,
            layout=go.Layout(
                title_text='t = {:2.2f}s'.format(i * dt),
                scene=_scene_layout(x[i], view_half_width),
            ),
        )
        for i, traces in enumerate(frame_data)
    ]

    fig = go.Figure(
        data=frame_data[0],
        layout=go.Layout(
            title='t = {:2.2f}s'.format(0),
            autosize=False,
            width=1200, height=800,
            scene=_scene_layout(x[0], view_half_width),
            updatemenus=[dict(
                type="buttons",
                buttons=[dict(label='Play',
                              method='animate',
                              args=[None, {'frame': {'duration': 200, 'redraw': True}, }])])]
        ),
        frames=drawn_frames
    )
    fig.update_layout(transition_duration=500)
    fig.show()
    return fig