#!/usr/bin/python3
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image
import pyrealsense2 as rs
import signal, time, numpy as np
import sys, cv2
from cv_bridge import CvBridge


class CameraDriver(Node):
    """ROS 2 node publishing Intel RealSense images.

    Four ``sensor_msgs/Image`` topics are published, all encoded as ``bgr8``:

    * ``cam_rgb``      - colour image as delivered by the camera
    * ``cam_depth``    - depth image rendered through the JET colormap
    * ``cam_infra_1``  - infrared stream 1 rendered through the JET colormap
    * ``cam_infra_2``  - infrared stream 2 rendered through the JET colormap

    Acquisition and publication are decoupled: ``read_imgs`` grabs and
    converts one set of frames and caches the resulting messages, while a
    timer periodically calls ``publish_imgs`` to publish the cached messages.
    """

    def __init__(self, name="cam_driver", timerFreq = 1/60.0):
        """Create the node, its publishers, and start the camera pipeline.

        Args:
            name: ROS node name.
            timerFreq: despite its name, this is the *period* in seconds
                between two calls of ``publish_imgs`` (it is handed to
                ``create_timer`` unchanged). The name is kept for backward
                compatibility with existing callers.

        Exits the process with status 1 when the connected device exposes no
        sensor named 'RGB Camera'.
        """
        super().__init__(name) # Create the node

        # Initialize publishers
        self._rgb_publisher = self.create_publisher(Image, 'cam_rgb', 10)
        self._depth_publisher = self.create_publisher(Image, 'cam_depth', 10)
        self._infra_publisher_1 = self.create_publisher(Image, 'cam_infra_1',10)
        self._infra_publisher_2 = self.create_publisher(Image, 'cam_infra_2',10)

        # Cached messages; they stay None until read_imgs() succeeds once so
        # that a timer tick arriving earlier has nothing to publish instead of
        # raising AttributeError.
        self._rgb_image = None
        self._depth_image = None
        self._infra1_image = None
        self._infra2_image = None

        # One converter is enough for the lifetime of the node.
        self._bridge = CvBridge()

        # True only while the pipeline is running (see destroy_node).
        self._streaming = False

        # ------------------------------ Initialize camera ------------------------------ #

        ## Configure depth and color streams
        self._pipeline = rs.pipeline()
        self._config = rs.config()

        ## Get device product line for setting a supporting resolution
        pipeline_wrapper = rs.pipeline_wrapper(self._pipeline)
        pipeline_profile = self._config.resolve(pipeline_wrapper)
        device = pipeline_profile.get_device()
        device_product_line = str(device.get_info(rs.camera_info.product_line))

        print( f"Connect: {device_product_line}" )
        # Must start as False, otherwise the check below can never fail.
        found_rgb = False
        for s in device.sensors:
            sensor_name = s.get_info(rs.camera_info.name)
            print( "Name:" + sensor_name )
            if sensor_name == 'RGB Camera':
                found_rgb = True

        if not found_rgb:
            print("Depth camera with RGB sensor required !!!")
            # Non-zero status: a missing sensor is a failure, not a clean exit.
            sys.exit(1)

        ## Configure stream width, height, format and frequency
        self._config.enable_stream(rs.stream.color, width=848, height=480, format=rs.format.bgr8, framerate=60)
        self._config.enable_stream(rs.stream.depth, width=848, height=480, format=rs.format.z16, framerate=60)
        self._config.enable_stream(rs.stream.infrared, 1, width=848, height=480, format=rs.format.y8, framerate=60)
        self._config.enable_stream(rs.stream.infrared, 2, width=848, height=480, format=rs.format.y8, framerate=60)

        ## Start the acquisition
        self._pipeline.start(self._config)
        self._streaming = True

        # Initialize a clock for the publisher (created last so the node is
        # fully set up before any callback can be scheduled).
        self.create_timer(timerFreq, self.publish_imgs)

    def destroy_node(self):
        """Stop the camera pipeline (if running), then destroy the ROS node."""
        if self._streaming:
            self._streaming = False
            self._pipeline.stop()
        return super().destroy_node()

    def read_imgs(self):
        """Read one set of camera frames and convert them to ROS Image messages.

        Blocks in ``wait_for_frames`` until the camera delivers a frameset.
        The converted messages are cached on the instance for the
        ``publish_*`` methods.

        Returns:
            A tuple ``(rgb, depth, infra1, infra2)`` of ``Image`` messages, or
            ``None`` when at least one of the four frames is missing (the
            previously cached messages are then left untouched).
        """

        # Get frames
        frames = self._pipeline.wait_for_frames()

        # Split frames. first_or_default() hands back an empty (falsy) frame
        # when the stream is absent, whereas first() raises, which would make
        # the check below unreachable for the colour and depth streams.
        color_frame = frames.first_or_default(rs.stream.color)
        depth_frame = frames.first_or_default(rs.stream.depth)
        infra_frame_1 = frames.get_infrared_frame(1)
        infra_frame_2 = frames.get_infrared_frame(2)

        if not (depth_frame and color_frame and infra_frame_1 and infra_frame_2):
            print("eh oh le truc y marche pas")
            return None

        # Convert images to numpy arrays
        depth_image = np.asanyarray(depth_frame.get_data())
        color_image = np.asanyarray(color_frame.get_data())
        infra_image_1 = np.asanyarray(infra_frame_1.get_data())
        infra_image_2 = np.asanyarray(infra_frame_2.get_data())

        # Apply colormap on depth and IR images (image must be converted to 8-bit per pixel first)
        depth_colormap   = cv2.applyColorMap(cv2.convertScaleAbs(depth_image,   alpha=0.03), cv2.COLORMAP_JET)
        infra_colormap_1 = cv2.applyColorMap(cv2.convertScaleAbs(infra_image_1, alpha=1), cv2.COLORMAP_JET)
        infra_colormap_2 = cv2.applyColorMap(cv2.convertScaleAbs(infra_image_2, alpha=1), cv2.COLORMAP_JET)

        # All four messages of one acquisition share the same timestamp.
        stamp = self.get_clock().now().to_msg()

        # Convert to ROS2 Image type
        # - RGB Image
        rgb_msg = self._bridge.cv2_to_imgmsg(color_image, "bgr8")
        rgb_msg.header.stamp = stamp
        rgb_msg.header.frame_id = "image"
        # - Depth Image
        depth_msg = self._bridge.cv2_to_imgmsg(depth_colormap, "bgr8")
        depth_msg.header.stamp = stamp
        depth_msg.header.frame_id = "depth"
        # - IR1 Image
        infra1_msg = self._bridge.cv2_to_imgmsg(infra_colormap_1, "bgr8")
        infra1_msg.header.stamp = stamp
        infra1_msg.header.frame_id = "infrared_1"
        # - IR2 Image
        infra2_msg = self._bridge.cv2_to_imgmsg(infra_colormap_2, "bgr8")
        infra2_msg.header.stamp = stamp
        infra2_msg.header.frame_id = "infrared_2"

        # Swap the cache only once every message has been built.
        self._rgb_image = rgb_msg
        self._depth_image = depth_msg
        self._infra1_image = infra1_msg
        self._infra2_image = infra2_msg

        return self._rgb_image, self._depth_image, self._infra1_image, self._infra2_image

    def publish_imgs(self):
        """Timer callback: publish every cached image (RGB, depth, infrared)."""
        self.publish_rgb()
        self.publish_depth()
        self.publish_IR()

    def publish_rgb(self):
        """
        Publish RGB image

        Nothing is published while no image has been read yet. Always returns 0.
        """

        if self._rgb_image is not None:
            self._rgb_publisher.publish(self._rgb_image)

        return 0

    def publish_depth(self):
        """
        Publish depth colormap image

        Nothing is published while no image has been read yet. Always returns 0.
        """

        if self._depth_image is not None:
            self._depth_publisher.publish(self._depth_image)

        return 0

    def publish_IR(self):
        """
        Publish infrared colormap images (streams 1 and 2)

        Nothing is published while no image has been read yet. Always returns 0.
        """

        if self._infra1_image is not None:
            self._infra_publisher_1.publish(self._infra1_image)
        if self._infra2_image is not None:
            self._infra_publisher_2.publish(self._infra2_image)

        return 0

# Ctrl-C handling: module-level flag that the SIGINT handler below clears.
isOk = True


def signalInterruption(signum, frame):
    """Signal handler: report the interruption and set the global ``isOk`` to False.

    ``signum`` and ``frame`` are the standard signal-handler arguments; neither
    is used.
    """
    global isOk
    print("\nCtrl-c pressed")
    isOk = False

def main():
    """
    Main loop

    Installs the SIGINT handler, starts ROS and the camera node, then
    alternates between grabbing camera frames and servicing ROS callbacks
    until Ctrl-C is pressed or the ROS context is shut down. The node and the
    ROS context are released on the way out, even if the loop raises.
    """
    # Without this declaration the assignment below would create a local
    # variable, and the flag cleared by the signal handler would never be
    # seen by the loop.
    global isOk

    # Arm the flag before installing the handler so that a Ctrl-C received
    # during start-up is not overwritten afterwards.
    isOk = True
    signal.signal(signal.SIGINT, signalInterruption)

    rclpy.init()
    rosNode = CameraDriver(timerFreq=1.0/120)
    try:
        # rclpy.ok() also ends the loop if the context gets shut down by
        # something other than our own handler.
        while isOk and rclpy.ok():
            rosNode.read_imgs()
            rclpy.spin_once(rosNode, timeout_sec=0.001)
    finally:
        rosNode.destroy_node()
        # Guarded because the context may already have been shut down.
        if rclpy.ok():
            rclpy.shutdown()

# Run the driver only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()