import numpy as np


class Camera(object):
    """Projective camera attached to a body frame.

    Holds the intrinsics ``K``, the fixed body->camera extrinsics ``Rt`` and a
    settable body->world pose, and composes them into world->camera
    transforms and world->image projection matrices.
    """

    def __init__(self, K, Rt, cam_cols=None, cam_rows=None):
        """
        Standard:
        ^ x
        |             z out from camera, towards world
        |
         --------> y

        Coordinates (each arrow is the matrix mapping right frame to left):
        [cam] <--Rt-- [body] <--world2body_Rt-- [world]
        where world2body_Rt is the inverse of body2world_Rt. Both start as
        identity, i.e. body and world frames coincide until
        set_body2world_Rt() is called.

        :param K: 3, 4 | intrinsic matrix (ndarray or array-like)
        :param Rt: 4, 4 | extrinsic matrix - body2cam (ndarray or array-like)
        :param cam_cols: stored unchanged on the instance
        :param cam_rows: stored unchanged on the instance
        """
        self.cam_cols = cam_cols
        self.cam_rows = cam_rows
        # asarray returns an existing ndarray unchanged (no copy) and lets
        # callers pass plain nested lists as well.
        self.K = np.asarray(K)
        self.Rt = np.asarray(Rt)
        self.P = self.K.dot(self.Rt)
        self.body2world_Rt = np.eye(4)
        self.world2body_Rt = np.eye(4)

    def set_body2world_Rt(self, Rt):
        """Set the body->world pose and cache its inverse.

        The inverse is built in closed form as [R^T | -R^T t], which equals
        the true matrix inverse only when the upper-left 3x3 block of ``Rt``
        is orthonormal (a rigid transform).

        :param Rt: 4, 4 | body2world matrix (ndarray or array-like)
        """
        body2world = np.asarray(Rt)
        rot_inv = body2world[:3, :3].T

        world2body = np.eye(4)
        world2body[:3, :3] = rot_inv
        world2body[:3, -1] = rot_inv.dot(-body2world[:3, -1])

        self.body2world_Rt = body2world
        self.world2body_Rt = world2body

    def get_world_projection_matrix(self):
        """Return the world->image projection matrix ``K . Rt . world2body_Rt``."""
        return self.P.dot(self.world2body_Rt)

    def get_world2cam_Rt(self):
        """Return the world->camera transform ``Rt . world2body_Rt``."""
        return self.Rt.dot(self.world2body_Rt)

    def project(self, x):
        """Project homogeneous world point(s) into the image.

        The result is left in homogeneous form: no division by the last
        coordinate is performed here.

        :param x: homogeneous world coordinates, multiplied on the right of
                  the world projection matrix
        :return: ``get_world_projection_matrix().dot(x)``
        """
        return self.get_world_projection_matrix().dot(x)

    def get_camera_center_world(self):
        """Return the camera center in world coordinates.

        Computed as ``-R^T t`` from the rotation ``R`` and translation ``t``
        of the world->camera transform.
        """
        world2cam = self.get_world2cam_Rt()
        return world2cam[:3, :3].T.dot(-world2cam[:3, -1])