import numpy as np
import scipy.integrate as itg


def solve_discrete_difference_eqns(discrete_state_model, x0, u_fcn, DT, T):
    r"""Simulate a discrete-time state model under state-feedback control.

    For each sample ``i`` (except the last), the operators dict returned by
    ``discrete_state_model.get_process_operators(x_i)`` is forwarded to the
    controller as keyword arguments: ``u_i = u_fcn(x_i, i*DT, **operators)``.
    The state is then advanced with ``discrete_state_model.step(x_i, u_i)``.

    :param discrete_state_model: object providing ``get_mn() -> (m, n)``,
        ``get_process_operators(x) -> dict`` and ``step(x, u) -> x_next``
    :param x0: m, | initial state (any array-like accepted by ``np.asarray``)
    :param u_fcn: (x_t, t, **operators) -> u_t \in n, |
    :param DT: sampling time
    :param T: final time; samples are taken at ``np.arange(0, T, DT)``
    :return:
        x: m, k | ith column is state at t = i*DT
        u: n, k | ith column is control applied at t = i*DT; the last column
           is left at zero because no step is taken from the final sample
    """
    m, n = discrete_state_model.get_mn()
    t = np.arange(0, T, DT)
    x = np.zeros((m, t.size))
    u = np.zeros((n, t.size))
    # np.asarray accepts lists/tuples/scalars as well as arrays; assigning into
    # x copies the values, so the caller's x0 is never aliased or mutated.
    x[:, 0] = np.asarray(x0)
    for i in range(t.size - 1):
        x_i = x[:, i]
        # Model-supplied operators are passed to the controller as kwargs.
        operators = discrete_state_model.get_process_operators(x_i)
        u[:, i] = u_fcn(x_i, i * DT, **operators)
        x[:, i + 1] = discrete_state_model.step(x_i, u[:, i])
    return x, u


def solve_cts_eqns(cts_state_model, x0, u_fcn, DT, T, cts_sim_model=None,
                   *, ivp_kwargs=None):
    r"""Simulate a continuous-time model under sampled, piecewise-constant control.

    At each sample instant ``t_i = i*DT`` (except the last), the controller is
    called as ``u_fcn(x_i, t_i, **operators)``. Here ``operators`` is the dict
    returned by ``cts_state_model.get_process_operators(x_i)``. The resulting
    control is held constant while ``scipy.integrate.solve_ivp`` integrates
    ``dx/dt = cts_sim_model.step(x, u_i)`` from ``t_i`` to ``t_{i+1}``.

    :param cts_state_model: model used for controller; provides ``get_mn()``
        and ``get_process_operators(x)``
    :param x0: m, | initial state (any array-like accepted by ``np.asarray``)
    :param u_fcn: (x_t, t, **operators) -> u_t \in n, |
    :param DT: sampling time
    :param T: final time; samples are taken at ``np.arange(0, T, DT)``
    :param cts_sim_model: true physical model used to advance simulation; its
        ``step(x, u)`` is used as the ODE right-hand side. Defaults to
        ``cts_state_model`` when None.
    :param ivp_kwargs: optional extra keyword arguments forwarded to
        ``solve_ivp`` (e.g. ``method``, ``rtol``, ``atol``, ``max_step``)
    :return:
        x: m, k | ith column is state at t = i*DT
        u: n, k | ith column is control held on [i*DT, (i+1)*DT); the last
           column is left at zero because no interval follows the final sample
    :raises RuntimeError: if ``solve_ivp`` reports failure on any interval
    """
    # Explicit None check: a model object could legitimately be falsy
    # (e.g. if it defines __len__ or __bool__).
    sim_model = cts_state_model if cts_sim_model is None else cts_sim_model
    solver_options = {} if ivp_kwargs is None else dict(ivp_kwargs)
    m, n = cts_state_model.get_mn()
    t = np.arange(0, T, DT)
    x = np.zeros((m, t.size))
    u = np.zeros((n, t.size))
    # Assigning into x copies the values; x0 itself is never mutated.
    x[:, 0] = np.asarray(x0)
    for i in range(t.size - 1):
        operators = cts_state_model.get_process_operators(x[:, i])
        u[:, i] = u_fcn(x[:, i], i * DT, **operators)

        # Bind the current control as a default so the closure does not depend
        # on the loop variable (it is only used within this iteration anyway).
        def rhs(_t, state, u_i=u[:, i]):
            return sim_model.step(state, u_i)

        # Integrate exactly over one sampling interval: the integrator's last
        # point is the next sample, so no overshoot or t_eval is required.
        t_start, t_end = i * DT, (i + 1) * DT
        sol = itg.solve_ivp(rhs, (t_start, t_end), x[:, i], **solver_options)
        if not sol.success:
            raise RuntimeError(
                f"solve_ivp failed on interval {i} [{t_start}, {t_end}]: "
                f"{sol.message}")
        x[:, i + 1] = sol.y[:, -1]
    return x, u