"""
Haptic controller (TouchX) for virtual mobile IO.
"""

import time
import threading
import numpy as np
from scipy.spatial.transform import Rotation as R

from hebi_tools.hebi_proto import FirmwareInfo
from dataclasses import dataclass, field
from collections import deque


@dataclass
class TouchXState:
    """Latest TouchX sample plus the history needed to estimate velocities.

    Each ``prev_*`` field keeps the value from the preceding update. Velocities are
    finite differences against those values, smoothed through the 5-sample deques at
    the end of the class. The force field holds the most recently commanded force.
    """

    # Position [x, y, z] of the current and previous update.
    position: np.ndarray = field(default_factory=lambda: np.zeros(3))
    prev_position: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # Filtered linear velocity [vx, vy, vz].
    linear_velocity: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # Orientation quaternion in [x, y, z, w] order, defaulting to the identity rotation.
    quaternion: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0, 1.0]))
    prev_quaternion: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0, 1.0]))
    # Filtered angular velocity vector [wx, wy, wz]. The controller computes it as the
    # rotation vector of the change since the previous sample divided by dt. It is not
    # a set of roll/pitch/yaw angles.
    angular_velocity: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # [stylus_button, extra_button] pressed flags.
    button_states: list = field(default_factory=lambda: [False, False])
    # Commanded force [fx, fy, fz].
    force: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # Six joint values, plus their previous values and filtered velocities.
    joints: list = field(default_factory=lambda: [0.0] * 6)
    prev_joints: list = field(default_factory=lambda: [0.0] * 6)
    joint_velocities: list = field(default_factory=lambda: [0.0] * 6)
    # Timestamp of the previous update (defaults to time.time() at construction).
    prev_time: float = field(default_factory=time.time)

    # Moving-average windows (5 samples each) used to smooth the velocity estimates.
    linear_velocity_buffer: deque = field(default_factory=lambda: deque(maxlen=5))
    angular_velocity_buffer: deque = field(default_factory=lambda: deque(maxlen=5))
    # One window per joint.
    joint_velocities_buffer: list = field(default_factory=lambda: [deque(maxlen=5) for _ in range(6)])

    def __repr__(self):
        # Only the pose, buttons, force and joints are shown; buffers and velocities are omitted.
        shown = ("button_states", "position", "prev_position", "quaternion", "prev_quaternion", "force", "joints")
        body = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"TouchXState({body})"


try:
    from pyOpenHaptics.hd_device import HapticDevice
    import pyOpenHaptics.hd as hd
    from pyOpenHaptics.hd_callback import hd_callback
except Exception as e:
    print(f"Touch X haptics will not work properly, Open Haptics packages not found: {e}")
    raise RuntimeError("Haptic device cannot be initialized because required packages/dependencies are not available.")


# Controller whose _hd_update_impl the hd_update callback below forwards to.
# TouchX.initialize_haptic_device sets it; TouchX.cleanup_haptic_device resets it to None.
# While it is None, the callback does nothing.
_haptic_controller_instance = None


def moving_average_filter(buffer, new_value):
    """Push ``new_value`` into ``buffer`` and return the mean of the buffered samples.

    Callers in this module pass ``deque(maxlen=5)``, so the buffer's ``maxlen`` sets the
    averaging window. NumPy arrays are stored as copies, so later in-place changes by the
    caller cannot rewrite history. They are averaged element-wise across samples
    (``axis=0``). Scalars are averaged directly. If the buffer is still empty after the
    append (e.g. a deque with ``maxlen=0``), ``new_value`` is returned unchanged.
    """
    is_array = isinstance(new_value, np.ndarray)
    buffer.append(new_value.copy() if is_array else new_value)

    if not buffer:
        return new_value
    if is_array:
        return np.mean(buffer, axis=0)
    return np.mean(buffer)


@hd_callback
def hd_update():
    """OpenHaptics callback: forward each update to the registered ``TouchX`` controller.

    The module-level reference is read exactly once. ``TouchX.cleanup_haptic_device`` can
    reset it to ``None``, possibly from a different thread than the one running this
    callback. Re-reading the global after the ``None`` check could then raise
    ``AttributeError`` inside the haptic callback.
    """
    controller = _haptic_controller_instance
    if controller is not None:
        controller._hd_update_impl()


class TouchX:
    """Expose a TouchX haptic device as a virtual HEBI mobile IO device.

    The module-level ``hd_update`` callback forwards to ``_hd_update_impl``. That method
    samples the device into ``device_state`` and applies the commanded force.
    ``device_handler`` answers HEBI requests from that state and stores new force
    commands in it. Shared access to ``device_state`` goes through ``device_lock``,
    since the callback and the request handler presumably run on different threads.

    Device-frame vectors are mapped to the reported frame with ``_COORD_TRANSFORM``
    ((x, y, z) -> (x, -z, y)). Commanded forces are mapped back with its transpose.
    """

    # Device frame -> reported frame: (x, y, z) -> (x, -z, y). The matrix is orthonormal,
    # so its inverse is its transpose.
    _COORD_TRANSFORM = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])

    # Time steps at or below _MIN_DT are treated as invalid (zero, negative or too
    # small to divide by) and replaced with _FALLBACK_DT.
    _MIN_DT = 1e-4
    _FALLBACK_DT = 1e-3

    def __init__(self):
        """Create a disconnected controller; call ``run`` or ``initialize_haptic_device`` to start."""
        self.shutdown = False
        self.disconnected = True
        self.device_state = None
        self.device_lock = threading.Lock()
        self.haptic_device = None
        # The state object that has already received its first sample. Velocities are
        # only differentiated against samples recorded into that same object.
        self._seeded_state = None

    # ------------------------------------------------------------------ device sampling

    @classmethod
    def _pose_from_transform(cls, transform):
        """Convert an OpenHaptics transform into ``(position, rotation)`` in the reported frame.

        ``transform`` is transposed before use, i.e. treated as a column-major 4x4
        (presumably OpenGL-style; confirm against the OpenHaptics docs). The position
        keeps the device's units. The rotation is conjugated by the frame map:
        ``T @ R @ T.T``.
        """
        matrix = np.array(transform).T
        position = cls._COORD_TRANSFORM @ matrix[:3, 3]
        rotation = cls._COORD_TRANSFORM @ matrix[:3, :3] @ cls._COORD_TRANSFORM.T
        return position, rotation

    @staticmethod
    def _device_force(force):
        """Map a reported-frame force [fx, fy, fz] into the device frame: [fx, fz, -fy]."""
        return [force[0], force[2], -force[1]]

    def _hd_update_impl(self):
        """Sample the device, apply the commanded force and update ``device_state``.

        Called from the module-level ``hd_update`` callback. It is a no-op while no state
        is attached. Velocities are finite differences passed through
        ``moving_average_filter``. The first sample recorded into a given state object
        only seeds the ``prev_*`` fields. This way the filters start from zero rather
        than from a spike computed against the default zero/identity values.
        """
        with self.device_lock:
            state = self.device_state
            if state is None:
                return

            # Monotonic clock: wall-clock jumps (NTP, manual changes) would otherwise
            # yield negative or huge intervals and corrupt the velocity estimates.
            now = time.monotonic()
            dt = now - state.prev_time
            state.prev_time = now
            if dt <= self._MIN_DT:
                dt = self._FALLBACK_DT

            position, rotation = self._pose_from_transform(hd.get_transform())
            quaternion = R.from_matrix(rotation).as_quat()  # [x, y, z, w]

            buttons = hd.get_buttons()
            joints = hd.get_joints()
            gimbals = hd.get_gimbals()
            current_joints = [joints[0], joints[1], joints[2], gimbals[0], gimbals[1], gimbals[2]]

            # Apply the most recently commanded force, mapped back into the device frame.
            hd.set_force(self._device_force(state.force))

            if self._seeded_state is state:
                state.prev_position = state.position.copy()
                state.prev_quaternion = state.quaternion.copy()
                state.prev_joints = list(state.joints)
            else:
                # First sample for this state: no real previous sample exists yet.
                state.prev_position = position.copy()
                state.prev_quaternion = quaternion.copy()
                state.prev_joints = list(current_joints)
                self._seeded_state = state

            state.position = position
            state.quaternion = quaternion
            state.button_states[0] = bool(buttons & 1)  # Bit 0: stylus button
            state.button_states[1] = bool(buttons & 2)  # Bit 1: extra button
            state.joints = current_joints

            state.joint_velocities = [
                moving_average_filter(window, (cur - prev) / dt)
                for window, cur, prev in zip(state.joint_velocities_buffer, state.joints, state.prev_joints)
            ]

            raw_linear_velocity = (state.position - state.prev_position) / dt
            state.linear_velocity = moving_average_filter(state.linear_velocity_buffer, raw_linear_velocity)

            # Relative rotation since the previous sample (expressed in the previous
            # sample's body frame) as a rotation vector, divided by dt.
            prev_rotation = R.from_quat(state.prev_quaternion).as_matrix()
            delta_rotvec = R.from_matrix(prev_rotation.T @ rotation).as_rotvec()
            state.angular_velocity = moving_average_filter(state.angular_velocity_buffer, delta_rotvec / dt)

    # ------------------------------------------------------------------ lifecycle

    def initialize_haptic_device(self):
        """Open the device and register this controller with the callback.

        Returns:
            True on success. On failure, the error is printed, the controller is
            unregistered from the callback, its state is cleared, and False is returned.
        """
        global _haptic_controller_instance
        print("Initializing haptic device...")

        try:
            with self.device_lock:
                self.device_state = TouchXState()
            # Register before opening the device so the callback finds this controller.
            _haptic_controller_instance = self
            self.haptic_device = HapticDevice(device_name="Default Device", callback=hd_update)
            time.sleep(0.2)  # Give haptic device time to initialize
            self.disconnected = False
            print("Haptic device initialized successfully.")
            return True

        except Exception as e:
            print(f"Haptic device initialization error: {e}")
            # Do not leave a half-initialized controller reachable from the callback.
            _haptic_controller_instance = None
            with self.device_lock:
                self.device_state = None
            self.disconnected = True
            return False

    def cleanup_haptic_device(self):
        """Unregister from the callback, close the device (once) and drop the state."""
        global _haptic_controller_instance

        # Unregister first so the callback stops forwarding to this controller.
        _haptic_controller_instance = None

        # Detach before closing so a repeated cleanup does not close the device twice.
        device, self.haptic_device = self.haptic_device, None
        if device is not None:
            try:
                device.close()
                print("Haptic device closed.")
            except Exception as e:
                print(f"Error closing haptic device: {e}")

        with self.device_lock:
            self.device_state = None
            self._seeded_state = None
        self.disconnected = True

    def run(self):
        """Initialize the device and idle until ``shutdown`` is set.

        Device updates happen in the haptic callback, so this loop only keeps the
        controller alive. The device is cleaned up on every exit path, including Ctrl-C.
        """
        print("Starting haptic device controller...")

        if not self.initialize_haptic_device():
            print("Failed to initialize haptic device.")
            return

        try:
            while not self.shutdown:
                if self.disconnected:
                    # NOTE(review): nothing in this class re-initializes the device here.
                    # This branch only waits for other code to clear ``disconnected``.
                    print("Haptic device disconnected. Attempting to reconnect...")
                    time.sleep(1.0)
                    continue

                time.sleep(0.01)  # Updates run in the callback; just avoid busy-waiting.

        except KeyboardInterrupt:
            print("\nShutting down haptic controller...")
            self.shutdown = True
        except Exception as e:
            print(f"Error in haptic controller: {e}")
        finally:
            self.cleanup_haptic_device()

    # ------------------------------------------------------------------ HEBI request handling

    def _snapshot_feedback(self):
        """Return consistent copies of the feedback fields, or None if no state is attached.

        Everything is copied while holding ``device_lock``. The callback reassigns these
        fields on every update, so reading them after releasing the lock could mix
        components from different samples.
        """
        with self.device_lock:
            state = self.device_state
            if state is None:
                return None
            return {
                "position": np.array(state.position, dtype=float),
                "quaternion": np.array(state.quaternion, dtype=float),
                "button_states": list(state.button_states),
                "joints": list(state.joints),
                "force": np.array(state.force, dtype=float),
                "angular_velocity": np.array(state.angular_velocity, dtype=float),
                "linear_velocity": np.array(state.linear_velocity, dtype=float),
            }

    @staticmethod
    def _fill_feedback(feedback, snap):
        """Write a ``_snapshot_feedback`` result into a feedback message."""
        mobile = feedback.mobile_feedback
        io = feedback.io_feedback

        # Position is scaled by 1e-3 (presumably device millimetres -> metres; confirm).
        position = snap["position"] * 1e-3
        mobile.ar_position.x = position[0]
        mobile.ar_position.y = position[1]
        mobile.ar_position.z = position[2]

        quaternion = snap["quaternion"]  # [x, y, z, w]
        mobile.ar_orientation.x = quaternion[0]
        mobile.ar_orientation.y = quaternion[1]
        mobile.ar_orientation.z = quaternion[2]
        mobile.ar_orientation.w = quaternion[3]

        io.b.pin1.int_value = int(snap["button_states"][0])
        io.b.pin2.int_value = int(snap["button_states"][1])

        # a.pin1..pin6: the three joints followed by the three gimbal values.
        for pin, value in enumerate(snap["joints"], start=1):
            getattr(io.a, f"pin{pin}").float_value = value

        force = snap["force"]
        for pin, value in enumerate(force, start=1):
            getattr(io.f, f"pin{pin}").float_value = value
        feedback.wrench.force.x = force[0]
        feedback.wrench.force.y = force[1]
        feedback.wrench.force.z = force[2]

        angular_velocity = snap["angular_velocity"]
        feedback.gyro.x = angular_velocity[0]
        feedback.gyro.y = angular_velocity[1]
        feedback.gyro.z = angular_velocity[2]

        # NOTE(review): accel carries the filtered linear *velocity*, in the unscaled
        # position units (ar_position is scaled by 1e-3). Confirm clients expect this.
        linear_velocity = snap["linear_velocity"]
        feedback.accel.x = linear_velocity[0]
        feedback.accel.y = linear_velocity[1]
        feedback.accel.z = linear_velocity[2]

    @staticmethod
    def _sanitize_force(fx, fy, fz):
        """Return [fx, fy, fz] as a float array, replacing NaN/inf components with 0.0.

        Non-finite values must never reach ``hd.set_force``.
        """
        return np.array([v if np.isfinite(v) else 0.0 for v in (fx, fy, fz)], dtype=float)

    @classmethod
    def _commanded_force(cls, command):
        """Extract the commanded force from a command message.

        The wrench takes precedence over the ``f`` IO bank. Missing pins count as 0.0.
        A command with neither yields zero force. Torque is ignored. Both sources are
        sanitized; previously only the IO path was NaN-checked.
        """
        if command.HasField("wrench"):
            f = command.wrench.force
            raw = (f.x, f.y, f.z)
        elif command.HasField("io_command") and command.io_command.HasField("f"):
            bank = command.io_command.f
            raw = tuple(
                getattr(bank, pin).float_value if bank.HasField(pin) else 0.0
                for pin in ("pin1", "pin2", "pin3")
            )
        else:
            raw = (0.0, 0.0, 0.0)
        return cls._sanitize_force(*raw)

    def device_handler(self, device, request, response):
        """Fill ``response`` for an incoming ``request`` addressed to the virtual ``device``.

        Echo fields are always copied. Name/family settings carried by the request are
        applied to ``device``. Settings, firmware, ethernet and hardware info are filled
        in when requested and cleared otherwise. While connected, feedback is filled from
        a consistent snapshot of ``device_state``. A command replaces the force that the
        callback applies to the device on its next update.

        Returns:
            True if any requested section (settings, info or feedback) was produced.
        """
        response.sender_id = device.sender_id
        response.echo.tx_time = request.echo.tx_time
        response.echo.payload = request.echo.payload
        response.echo.sequence_number = request.echo.sequence_number

        return_response = False

        if request.HasField("settings") and request.settings.HasField("name"):
            name_settings = request.settings.name
            if name_settings.HasField("name"):
                device.name = name_settings.name
            if name_settings.HasField("family"):
                device.family = name_settings.family

        if request.request_settings:
            response.settings.name.name = device.name
            response.settings.name.family = device.family
            return_response = True
        else:
            response.settings.ClearField("name")

        if request.request_firmware_info:
            response.firmware_info.type = device.fw_type
            response.firmware_info.revision = device.fw_rev
            response.firmware_info.mode = FirmwareInfo.APPLICATION
            return_response = True
        else:
            response.ClearField("firmware_info")

        if request.request_ethernet_info:
            response.ethernet_info.mac_address = device.mac_address
            response.ethernet_info.ip_address = device.ip_address
            response.ethernet_info.netmask = device.netmask
            return_response = True
        else:
            response.ClearField("ethernet_info")

        if request.request_hardware_info:
            response.hardware_info.serial_number = device.hw_serial_number
            response.hardware_info.mechanical_type = device.hw_mechanical_type
            response.hardware_info.mechanical_revision = device.hw_mechanical_revision
            response.hardware_info.electrical_type = device.hw_electrical_type
            response.hardware_info.electrical_revision = device.hw_electrical_revision
            return_response = True
        else:
            response.ClearField("hardware_info")

        if request.request_feedback and not self.disconnected:
            snapshot = self._snapshot_feedback()
            if snapshot is not None:
                self._fill_feedback(response.feedback, snapshot)
            return_response = True
        else:
            response.ClearField("feedback")

        if request.HasField("command") and not self.disconnected:
            force = self._commanded_force(request.command)
            with self.device_lock:
                if self.device_state is not None:
                    self.device_state.force = force

        return return_response
