#!/usr/bin/python3
"""Basic reactive control of the robot: move forward and avoid obstacles without getting stuck."""

# imports
import rclpy
import math as m
from rclpy.node import Node
from sensor_msgs.msg import PointCloud2
from kobuki_ros_interfaces.msg import WheelDropEvent, ButtonEvent
from geometry_msgs.msg import Twist
import sensor_msgs_py.point_cloud2
import time

# main class
class ReactiveMove(Node):
    """ROS 2 node that drives forward and rotates away from obstacles.

    The node caches the latest point cloud received on ``/scan_points`` and,
    on a periodic timer, publishes a ``Twist`` on ``cmd_vel`` computed from
    the closest point found in a rectangle in front of the robot.
    Zero velocity is published when the scan data is stale, when both wheels
    are reported as dropped, or when movement has been toggled off with a
    button press.
    """

    def __init__(self, name="reactive_move", timerFreq = 1/60.0):
        """Create the node, its parameters, subscriptions, publisher and timer.

        Args:
            name: Name given to the ROS node.
            timerFreq: Value passed to ``create_timer`` as the timer period;
                despite its name it is a duration (the default is 1/60),
                not a frequency.
        """
        super().__init__(name)  # Create the node

        # Initialize parameters
        self.declare_parameters(
            namespace='',
            parameters=[
                # Rectangle of vision of the robot
                ('RECT_WIDTH', 0.6),
                ('RECT_LENGTH', 0.6),
                # Distance range over which the forward speed ramps from 0 to max
                ('MIN_DIST_MOVE_FORWARD', 0.3),
                ('MAX_DIST_MOVE_FORWARD', 0.6),
                # Velocities
                ('ANGULAR_VELOCITY', 0.8),
                ('LINEAR_VELOCITY', 0.25)
            ])

        # Initialize state BEFORE registering callbacks so that no callback
        # can ever observe a partially constructed object.
        self._timerFreq = timerFreq
        self._previousScanTime = None   # time.monotonic() of the last scan, None until one arrives
        self._points = []               # points of the last received scan
        self._firstPointSeen = None     # (x, y) of the obstacle that started the current turn
        ## Safety
        self._leftWheelDropped = False
        self._rightWheelDropped = False
        self._canMove = True

        # Initialize a publisher
        self._velocity_publisher = self.create_publisher(Twist, 'cmd_vel', 10)

        # Initialize subscribers
        self.create_subscription(PointCloud2, '/scan_points', self.scan_callback, 10)
        self.create_subscription(WheelDropEvent, '/events/wheel_drop', self.wheel_callback, 10)
        self.create_subscription(ButtonEvent, '/events/button', self.btn_callback, 10)

        # Initialize a clock for the publisher
        self.create_timer(timerFreq, self.publish_velo)

    def scan_callback(self, scanMsg):
        """Save the received points and the time at which they arrived."""
        # Materialize the points: the cached scan is re-read on every timer
        # tick until the next scan arrives. NOTE(review): depending on the
        # sensor_msgs_py version, read_points may return a one-shot generator
        # that would be empty after the first tick - confirm for the target
        # distribution.
        self._points = list(sensor_msgs_py.point_cloud2.read_points(scanMsg))

        # Monotonic clock: the staleness check must not be affected by
        # wall-clock adjustments.
        self._previousScanTime = time.monotonic()

    def wheel_callback(self, msg):
        """Save the received wheel state (dropped or not).

        Any wheel identifier other than ``msg.LEFT`` is treated as the right
        wheel.
        """
        if(msg.wheel == msg.LEFT):
            self._leftWheelDropped = msg.state == msg.DROPPED
        else:
            self._rightWheelDropped = msg.state == msg.DROPPED

    def btn_callback(self, msg):
        """Toggle robot movement each time any button is pressed."""
        print(msg)
        if(msg.state == msg.PRESSED):
            self._canMove = not self._canMove
            print("pressed")

    def calcul_velo(self):
        """Calculate the velocity from the cached scan and the safety state.

        Returns:
            A ``Twist``. It is all zeros when no scan was ever received, when
            the last scan is more than 1 s old, when both wheels are dropped,
            or when movement is toggled off. Otherwise the robot goes straight
            at ``LINEAR_VELOCITY`` if the front rectangle is empty, or rotates
            (with a distance-dependent forward speed) if it is not.
        """
        # No scan has been received yet: there is nothing to base a decision on.
        if self._previousScanTime is None:
            return Twist()  # Return a zero velocity

        # If scan data received more than 1s ago, stop the robot
        if time.monotonic() - self._previousScanTime > 1:
            print("No data received for more than 1s")
            return Twist()  # Return a zero velocity

        # If both robot wheels are dropped (i.e. the robot is not on the floor), stop the robot movement
        if self._leftWheelDropped and self._rightWheelDropped:
            print("The robot is not on the floor")
            return Twist()  # Return a zero velocity

        # If stop btn was pressed
        if not self._canMove:
            print("The robot is stopped")
            return Twist()  # Return a zero velocity

        # Default case
        x, y = firstPointInFront(
            self._points,
            self.paramDouble('RECT_LENGTH'),
            self.paramDouble('RECT_WIDTH') / 2.0,
        )

        velocity = Twist()

        # Compare against None, not truthiness: a point exactly on the robot
        # axis has y == 0.0, which is falsy but is still an obstacle.
        if x is None:
            # Nothing in front: go straight and forget the previous obstacle.
            velocity.linear.x = self.paramDouble('LINEAR_VELOCITY')
            self._firstPointSeen = None
            return velocity

        # Remember the obstacle that started this avoidance manoeuvre so the
        # turning direction stays the same until the rectangle is clear.
        if self._firstPointSeen is None:
            self._firstPointSeen = (x, y)

        # Linear movement
        velocity.linear.x = self._forward_speed(x)

        # Rotation, based on the first point seen
        angular = self.paramDouble('ANGULAR_VELOCITY')
        if self._firstPointSeen[1] >= 0:
            # Point was in the left part of the rectangle -> keep turning right
            velocity.angular.z = -angular
        else:
            # Point was in the right part of the rectangle -> keep turning left
            velocity.angular.z = angular

        return velocity

    def _forward_speed(self, distance):
        """Return the forward speed for an obstacle at forward distance *distance*.

        The speed is 0 below ``MIN_DIST_MOVE_FORWARD`` and grows linearly up
        to ``LINEAR_VELOCITY`` at ``MAX_DIST_MOVE_FORWARD``.
        """
        min_dist = self.paramDouble('MIN_DIST_MOVE_FORWARD')
        max_dist = self.paramDouble('MAX_DIST_MOVE_FORWARD')

        # If the wall is too close, don't go forward, only rotate
        if distance < min_dist:
            return 0.0

        span = max_dist - min_dist
        if span <= 0:
            # Degenerate ramp (MAX <= MIN): do not advance towards a visible
            # obstacle instead of dividing by zero; rotation still clears it.
            return 0.0

        # Linear between 0 (at MIN_DIST_MOVE_FORWARD) and 1 (at MAX_DIST_MOVE_FORWARD)
        coefficient = min(1.0, max(0.0, (distance - min_dist) / span))

        # float(): guarantee a built-in float even when the scan yields NumPy
        # scalars (presumably what the Twist field setters expect - confirm).
        return float(coefficient * self.paramDouble('LINEAR_VELOCITY'))

    def publish_velo(self):
        """Publish the velocity on the timer, independently of scan callbacks.

        Returns:
            -1 if no scan has been received yet (nothing is published),
            0 after a velocity has been published.
        """
        # Make sure at least one scan callback has run
        if self._previousScanTime is None:
            return -1

        velocity = self.calcul_velo()

        self._velocity_publisher.publish(velocity)  # Move according to the calculated velocity

        return 0

    def paramDouble(self, name):
        """Return the ``double_value`` of the node parameter called *name*."""
        return self.get_parameter(name).get_parameter_value().double_value

# Module-level helper functions
def firstPointInFront(points:list, maxX, maxY):
    """Return the closest point lying in the rectangle in front of the robot.

    The rectangle covers ``0 < x < maxX`` and ``|y| < maxY`` (bounds
    excluded), i.e. a forward distance of maxX and a total lateral width of
    2*maxY. "Closest" is measured as the planar distance to the origin.

    Args:
        points: Iterable of indexable points whose first two fields are x and
            y. Any further fields (z, intensity, ...) are ignored.
        maxX: Forward extent of the rectangle.
        maxY: Half of the lateral extent of the rectangle.

    Returns:
        ``(x, y)`` of the closest point inside the rectangle, or
        ``(None, None)`` when the rectangle is empty.
    """
    # Closest point
    cx, cy = None, None
    closestDistance = m.inf

    for point in points:
        # Index instead of unpacking so that clouds carrying more (or fewer)
        # than three fields per point are accepted.
        x, y = point[0], point[1]

        # Points outside the rectangle are not candidates. NaN coordinates
        # fail these comparisons and are skipped as well.
        if not (0 < x < maxX and abs(y) < maxY):
            continue

        pointDistance = m.hypot(x, y)  # Point distance to the lidar
        if pointDistance < closestDistance:
            # Strict comparison: on a tie the first point encountered is kept.
            cx, cy = x, y
            closestDistance = pointDistance

    return (cx, cy)



# Main loop
def main():
    """Initialize rclpy, spin a ReactiveMove node, and clean up on exit.

    A KeyboardInterrupt (Ctrl-C) raised while spinning is treated as a normal
    stop request. The node is destroyed and rclpy is shut down whether the
    spin ends normally, is interrupted, or raises.
    """
    rclpy.init()
    rosNode = None
    try:
        rosNode = ReactiveMove()
        rclpy.spin(rosNode)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the node; fall through to cleanup.
        pass
    finally:
        if rosNode is not None:
            rosNode.destroy_node()
        # Only shut down while rclpy still reports the context as valid, to
        # avoid shutting it down a second time.
        if rclpy.ok():
            rclpy.shutdown()

# Run the node only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()