#! /usr/bin/env python3
import time
import importlib
import threading
import argparse
import platform
import numpy as np
from dataclasses import dataclass, field
from scipy.spatial.transform import Rotation as R

import yaml

from main import create_server, default_device_handler

from hebi_tools.hebi_proto import FirmwareInfo

import inputs

# Optional Touch X support. If pyOpenHaptics cannot be imported, a warning is
# printed, HAPTIC_AVAILABLE is set to False (main() then refuses the 'haptic'
# device), and minimal stand-ins are defined so the rest of this module can
# still be imported and its definitions created.
try:
    from pyOpenHaptics.hd_device import HapticDevice
    import pyOpenHaptics.hd as hd
    from pyOpenHaptics.hd_callback import hd_callback
except Exception as e:
    print(f'Touch X haptics will not work properly, Open Haptics packages not found: {e}')
    HAPTIC_AVAILABLE = False

    class HapticDevice:
        """Empty placeholder used when pyOpenHaptics is unavailable."""

    class hd:
        """Empty placeholder namespace used when pyOpenHaptics is unavailable."""

    def hd_callback(func):
        """Identity decorator standing in for pyOpenHaptics' ``hd_callback``."""
        return func
else:
    HAPTIC_AVAILABLE = True


# Host OS name as reported by platform.system(); Gamepad.run enables an extra
# polling step when this is 'Windows'.
SYSTEM = platform.system()


class Gamepad:
    """Expose a physical gamepad as a mobileIO-style virtual module.

    :meth:`run` (started on a background thread by ``main``) polls gamepad
    number ``gamepad_idx`` through the ``inputs`` package and records the
    latest value of every control listed in the keymap. :meth:`device_handler`
    copies those values into the IO feedback of each response: axes go to
    ``io_feedback.a.pin1..8`` as floats, buttons to ``io_feedback.b.pin1..8``
    as ints.

    The YAML keymap maps an ``inputs`` event code to a dict with keys
    ``type`` (``'axis'`` or ``'button'``), ``idx`` (index into the 8-entry
    state lists) and, for axes, ``scale`` (the raw event value is divided by
    it).
    """

    # NOTE(review): not referenced inside this class; presumably documents the
    # raw axis/trigger ranges to use as keymap ``scale`` values — confirm.
    AXIS_MAX = pow(2, 15)
    TRIGGER_MAX = 256

    def __init__(self, gamepad_idx, config_file):
        """Load the keymap and try to attach to gamepad number ``gamepad_idx``.

        If that gamepad is not present yet, the instance starts disconnected
        and :meth:`run` keeps retrying.

        Args:
            gamepad_idx: index into ``inputs.devices.gamepads``.
            config_file: path of the YAML keymap (see the class docstring).

        Raises:
            OSError: if ``config_file`` cannot be opened.
            ValueError: if ``config_file`` is not valid YAML.
        """
        self.gamepad_idx = gamepad_idx
        self.axis_states, self.button_states = self._neutral_states()
        self.keymap = self.load_config(config_file)
        self.shutdown = False
        try:
            self.gamepad = inputs.devices.gamepads[self.gamepad_idx]
            self.disconnected = False
        except Exception:
            # Presumably an IndexError when fewer gamepads are attached; any
            # failure just means "not connected yet" and run() will retry.
            self.gamepad = None
            self.disconnected = True
            print('Waiting for gamepad...')

    @staticmethod
    def _neutral_states():
        """Return fresh ``(axis_states, button_states)`` lists with every channel at rest."""
        return [0.0] * 8, [0] * 8

    def load_config(self, config_file):
        """Parse the YAML keymap stored at ``config_file``.

        Returns:
            dict: the keymap, or ``{}`` when the file is empty.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if the file is not valid YAML. Failing here, at
                construction time, is preferable to handing ``None`` to
                :meth:`run`, which would crash on its first event.
        """
        with open(config_file) as cf:
            try:
                keymap = yaml.safe_load(cf)
            except yaml.YAMLError as e:
                raise ValueError(f'Invalid key configuration file {config_file!r}: {e}') from e
        # yaml.safe_load returns None for an empty document.
        return keymap if keymap is not None else {}

    def run(self):
        """Poll the gamepad until :attr:`shutdown` is set, updating the channel states.

        While disconnected, the ``inputs`` module is reloaded and gamepad
        ``gamepad_idx`` looked up again, about every 0.1 s. On a read error
        the states are reset to neutral and reconnection starts over.
        """
        while not self.shutdown:
            if self.disconnected:
                try:
                    # Reload ``inputs`` so its module-level device list is
                    # rebuilt (presumably how a newly plugged pad is found).
                    importlib.reload(inputs)
                    time.sleep(0.1)
                    self.gamepad = inputs.devices.gamepads[self.gamepad_idx]
                except KeyboardInterrupt:
                    self.shutdown = True
                    break
                except Exception:
                    # Gamepad not available yet; retry on the next iteration.
                    continue
                if self.gamepad is None:
                    # Never fall through to polling a missing device.
                    continue
                self.disconnected = False
                print('Connected to gamepad!')

            try:
                if SYSTEM == 'Windows':
                    # NOTE(review): private, name-mangled inputs.GamePad method;
                    # presumably refreshes device state on Windows — confirm.
                    self.gamepad._GamePad__check_state()
                events = self.gamepad._do_iter()
            except (OSError, inputs.UnpluggedError):
                print('Lost connection to gamepad...')
                self.disconnected = True
                self.axis_states, self.button_states = self._neutral_states()
                continue
            except KeyboardInterrupt:
                self.shutdown = True
                break

            for event in events or ():
                self._apply_event(event)

    def _apply_event(self, event):
        """Store one ``inputs`` event into the axis/button channel it is mapped to, if any."""
        mapping = self.keymap.get(event.code)
        if not mapping:
            return
        evt_type = mapping['type']
        idx = mapping['idx']
        if evt_type == 'axis':
            self.axis_states[idx] = float(event.state) / float(mapping['scale'])
        elif evt_type == 'button':
            self.button_states[idx] = event.state

    def device_handler(self, device, request, response):
        """Fill ``response`` for one incoming ``request`` addressed to ``device``.

        Echo data is always copied. Name/family settings in the request are
        applied to ``device``. Each optional section (settings, firmware,
        ethernet, hardware info, feedback) is filled when requested and
        cleared otherwise. Feedback is only sent while the gamepad is
        connected.

        Returns:
            bool: True if any requested section was filled, i.e. the response
            should be sent.
        """
        response.sender_id = device.sender_id
        response.echo.tx_time = request.echo.tx_time
        response.echo.payload = request.echo.payload
        response.echo.sequence_number = request.echo.sequence_number

        return_response = False

        if request.HasField('settings'):
            if request.settings.HasField('name'):
                if request.settings.name.HasField('name'):
                    device.name = request.settings.name.name
                if request.settings.name.HasField('family'):
                    device.family = request.settings.name.family

        if request.request_settings:
            response.settings.name.name = device.name
            response.settings.name.family = device.family
            return_response = True
        else:
            response.settings.ClearField('name')

        if request.request_firmware_info:
            response.firmware_info.type = device.fw_type
            response.firmware_info.revision = device.fw_rev
            response.firmware_info.mode = FirmwareInfo.APPLICATION
            return_response = True
        else:
            response.ClearField('firmware_info')

        if request.request_ethernet_info:
            response.ethernet_info.mac_address = device.mac_address
            response.ethernet_info.ip_address = device.ip_address
            response.ethernet_info.netmask = device.netmask
            return_response = True
        else:
            response.ClearField('ethernet_info')

        if request.request_hardware_info:
            response.hardware_info.serial_number = device.hw_serial_number
            response.hardware_info.mechanical_type = device.hw_mechanical_type
            response.hardware_info.mechanical_revision = device.hw_mechanical_revision
            response.hardware_info.electrical_type = device.hw_electrical_type
            response.hardware_info.electrical_revision = device.hw_electrical_revision
            return_response = True
        else:
            response.ClearField('hardware_info')

        if request.request_feedback and not self.disconnected:
            # Copy the lists first: run() replaces them on disconnect, and the
            # response should be built from one consistent snapshot.
            axes = list(self.axis_states)
            buttons = list(self.button_states)
            io_fb = response.feedback.io_feedback
            for pin, value in enumerate(axes, start=1):
                getattr(io_fb.a, f'pin{pin}').float_value = value
            for pin, value in enumerate(buttons, start=1):
                getattr(io_fb.b, f'pin{pin}').int_value = value
            return_response = True
        else:
            response.ClearField('feedback')

        return return_response

@dataclass
class DeviceState:
    """Most recent haptic-device sample together with the force to command.

    TouchX updates the pose, buttons, joints and gimbals from the device
    callback and sets ``force`` from incoming commands.
    """

    # [stylus_button, extra_button] as booleans
    button_states: list = field(default_factory=lambda: [False, False])
    # [x, y, z] position
    position: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # orientation quaternion, scalar-last [x, y, z, w]; identity by default
    quaternion: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0, 1.0]))
    # [fx, fy, fz] force
    force: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # three joint values
    joints: list = field(default_factory=lambda: [0.0] * 3)
    # three gimbal values
    gimbals: list = field(default_factory=lambda: [0.0] * 3)

    def __repr__(self):
        # Fields are formatted with str() (numpy arrays print without "array(...)").
        shown = ("button_states", "position", "quaternion", "force", "joints", "gimbals")
        body = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"DeviceState({body})"


# Controller that the haptic-device callback below forwards to. The callback
# takes no arguments, so it reaches the controller through this module-level
# name; TouchX.haptic_thread_function sets it on start-up and resets it to None
# when it exits.
_haptic_controller_instance = None

@hd_callback
def hd_update():
    """Device callback: forward one update to the active TouchX controller, if any."""
    # Read the global exactly once. The haptic thread may reset it to None at
    # any moment; re-reading it between the check and the call would then raise
    # AttributeError inside the device callback.
    controller = _haptic_controller_instance
    if controller is not None:
        controller._hd_update_impl()


class TouchX:
    """Expose a Touch X haptic device (via pyOpenHaptics) as a mobileIO-style virtual module.

    :meth:`run` starts :meth:`haptic_thread_function` on its own thread. That
    thread opens the device with the module-level ``hd_update`` callback,
    which forwards each call to :meth:`_hd_update_impl`. The update reads the
    device pose, buttons, joints and gimbals into ``device_state`` and
    applies ``device_state.force``.

    :meth:`device_handler` reports that state in responses and takes force
    commands from ``io_command.f.pin1..3``. Once threads are running, every
    read or write of ``device_state`` goes through ``device_lock``.
    """

    # Maps a device-frame vector (x, y, z) to the output frame (x, -z, y).
    # Applied as T @ R @ T.T to re-express the device rotation in that frame.
    _COORD_TRANSFORM = np.array([[1, 0, 0],
                                 [0, 0, -1],
                                 [0, 1, 0]])

    def __init__(self):
        self.shutdown = False              # set by the owner to stop run()
        self.disconnected = True           # owned by the haptic thread
        self.device_state = None           # DeviceState while the device is open
        self.device_lock = threading.Lock()
        self.haptic_device = None
        self.shutdown_event = threading.Event()  # tells the haptic thread to exit

    def _hd_update_impl(self):
        """Sample the device and apply the commanded force; invoked through ``hd_update``.

        Position and rotation are converted from the device frame to the
        output frame (x, y, z) -> (x, -z, y). The force stored in
        ``device_state`` (output frame) is converted back with the inverse
        mapping before being applied.
        """
        with self.device_lock:
            state = self.device_state
            if state is None:
                return

            # NOTE(review): transposed because the transform appears to be
            # column-major (translation is then read from [:3, 3]) — confirm.
            transform_matrix = np.array(hd.get_transform()).T
            position = np.array([transform_matrix[0, 3],
                                 -transform_matrix[2, 3],
                                 transform_matrix[1, 3]])
            rotation = self._COORD_TRANSFORM @ transform_matrix[:3, :3] @ self._COORD_TRANSFORM.T
            quaternion = R.from_matrix(rotation).as_quat()  # scalar-last [x, y, z, w]

            buttons = hd.get_buttons()
            joints = hd.get_joints()
            gimbals = hd.get_gimbals()

            # Inverse of the position mapping: output (fx, fy, fz) -> device (fx, fz, -fy).
            hd.set_force([state.force[0], state.force[2], -state.force[1]])

            state.position = position
            state.quaternion = quaternion
            state.button_states[0] = bool(buttons & 1)  # Bit 0: stylus button
            state.button_states[1] = bool(buttons & 2)  # Bit 1: extra button
            state.joints = [joints[0], joints[1], joints[2]]
            state.gimbals = [gimbals[0], gimbals[1], gimbals[2]]

    def haptic_thread_function(self):
        """Open the haptic device and keep it open until ``shutdown_event`` is set.

        Runs on the thread started by :meth:`run`. It registers this instance
        as the target of ``hd_update`` and marks the controller connected once
        the device is open. On exit (normal or error) it clears the callback
        target, closes the device, drops ``device_state`` and marks the
        controller disconnected. Errors are printed, not raised.
        """
        global _haptic_controller_instance
        print("Initializing haptic device in separate thread...")

        try:
            # hd_update forwards to whichever controller this global names.
            _haptic_controller_instance = self

            with self.device_lock:
                self.device_state = DeviceState()
            self.haptic_device = HapticDevice(device_name="Default Device", callback=hd_update)
            time.sleep(0.2)  # Give haptic device time to initialize
            self.disconnected = False
            print("Haptic device initialized and running in background thread.")

            # The device is driven by its callback; this thread only waits.
            self.shutdown_event.wait()

        except Exception as e:
            print(f"Haptic thread error: {e}")
            self.disconnected = True
        finally:
            _haptic_controller_instance = None

            if self.haptic_device is not None:
                try:
                    self.haptic_device.close()
                    print("Haptic device closed.")
                except Exception as e:
                    print(f"Error closing haptic device: {e}")
                # Forget the closed handle so a later run never closes it twice.
                self.haptic_device = None

            with self.device_lock:
                self.device_state = None
            self.disconnected = True

    def run(self):
        """Start the haptic thread and supervise it until :attr:`shutdown` is set.

        Meant to run on its own thread. ``disconnected`` is owned by the
        haptic thread; this loop only reports while it is set. On exit the
        haptic thread is signalled through ``shutdown_event`` and joined with
        a 5 s timeout.
        """
        print("Starting haptic device controller...")

        self.shutdown_event.clear()

        haptic_thread = threading.Thread(target=self.haptic_thread_function, daemon=False)
        haptic_thread.start()
        time.sleep(0.5)  # Give haptic thread time to initialize

        with self.device_lock:
            device_initialized = self.device_state is not None

        if not device_initialized:
            # Deliberately do not force disconnected=True here: the haptic
            # thread may finish initialising just after this check, and
            # overwriting its flag would leave the controller stuck waiting.
            print("Failed to initialize haptic device.")

        try:
            while not self.shutdown:
                if self.disconnected:
                    print('Waiting for haptic device...')
                    time.sleep(1.0)
                    continue
                # Device updates happen in the callback; just idle here.
                time.sleep(0.01)

        except KeyboardInterrupt:
            print("\nShutting down haptic controller...")
            self.shutdown = True
        except Exception as e:
            print(f"Error in haptic controller: {e}")
        finally:
            self.shutdown_event.set()
            haptic_thread.join(timeout=5.0)

            if haptic_thread.is_alive():
                print("Warning: Haptic thread did not shutdown cleanly")
            else:
                print("Haptic thread closed successfully.")

    def _snapshot_state(self):
        """Return a private copy of ``device_state`` taken under ``device_lock``, or None.

        Copying under the lock guarantees that all reported values come from
        the same device update.
        """
        with self.device_lock:
            state = self.device_state
            if state is None:
                return None
            return DeviceState(
                button_states=list(state.button_states),
                position=np.array(state.position),
                quaternion=np.array(state.quaternion),
                force=np.array(state.force),
                joints=list(state.joints),
                gimbals=list(state.gimbals),
            )

    @staticmethod
    def _read_force_pin(f_bank, pin_name):
        """Return ``f_bank.<pin_name>.float_value``, or 0.0 if the pin is unset or not finite.

        NaN and +/-inf are both mapped to 0.0 so that no non-finite force can
        reach the device.
        """
        if not f_bank.HasField(pin_name):
            return 0.0
        value = getattr(f_bank, pin_name).float_value
        return value if np.isfinite(value) else 0.0

    def device_handler(self, device, request, response):
        """Fill ``response`` for one incoming ``request`` and apply any force command.

        Echo data is always copied. Name/family settings in the request are
        applied to ``device``. Each optional section is filled when requested
        and cleared otherwise.

        Feedback (only while connected):
            * ``mobile_feedback.ar_position``: position scaled by 1e-3.
            * ``mobile_feedback.ar_orientation``: the [x, y, z, w] quaternion.
            * ``io_feedback.b.pin1..2``: the two buttons as ints.
            * ``io_feedback.a.pin1..3``: joints.
            * ``io_feedback.a.pin4..6``: gimbals.
            * ``io_feedback.f.pin1..3``: the currently commanded force.

        A ``command`` (only while connected) sets the force from
        ``io_command.f.pin1..3``. Missing or non-finite pins count as 0.0.

        Returns:
            bool: True if any requested section was filled.
        """
        response.sender_id = device.sender_id
        response.echo.tx_time = request.echo.tx_time
        response.echo.payload = request.echo.payload
        response.echo.sequence_number = request.echo.sequence_number

        return_response = False

        if request.HasField('settings'):
            if request.settings.HasField('name'):
                if request.settings.name.HasField('name'):
                    device.name = request.settings.name.name
                if request.settings.name.HasField('family'):
                    device.family = request.settings.name.family

        if request.request_settings:
            response.settings.name.name = device.name
            response.settings.name.family = device.family
            return_response = True
        else:
            response.settings.ClearField('name')

        if request.request_firmware_info:
            response.firmware_info.type = device.fw_type
            response.firmware_info.revision = device.fw_rev
            response.firmware_info.mode = FirmwareInfo.APPLICATION
            return_response = True
        else:
            response.ClearField('firmware_info')

        if request.request_ethernet_info:
            response.ethernet_info.mac_address = device.mac_address
            response.ethernet_info.ip_address = device.ip_address
            response.ethernet_info.netmask = device.netmask
            return_response = True
        else:
            response.ClearField('ethernet_info')

        if request.request_hardware_info:
            response.hardware_info.serial_number = device.hw_serial_number
            response.hardware_info.mechanical_type = device.hw_mechanical_type
            response.hardware_info.mechanical_revision = device.hw_mechanical_revision
            response.hardware_info.electrical_type = device.hw_electrical_type
            response.hardware_info.electrical_revision = device.hw_electrical_revision
            return_response = True
        else:
            response.ClearField('hardware_info')

        if request.request_feedback and not self.disconnected:
            state = self._snapshot_state()
            if state is not None:
                mobile_fb = response.feedback.mobile_feedback
                io_fb = response.feedback.io_feedback

                # NOTE(review): 1e-3 presumably converts device millimetres to metres — confirm.
                mobile_fb.ar_position.x = state.position[0] * 1e-3
                mobile_fb.ar_position.y = state.position[1] * 1e-3
                mobile_fb.ar_position.z = state.position[2] * 1e-3
                mobile_fb.ar_orientation.x = state.quaternion[0]
                mobile_fb.ar_orientation.y = state.quaternion[1]
                mobile_fb.ar_orientation.z = state.quaternion[2]
                mobile_fb.ar_orientation.w = state.quaternion[3]

                io_fb.b.pin1.int_value = int(state.button_states[0])
                io_fb.b.pin2.int_value = int(state.button_states[1])

                io_fb.a.pin1.float_value = state.joints[0]
                io_fb.a.pin2.float_value = state.joints[1]
                io_fb.a.pin3.float_value = state.joints[2]
                io_fb.a.pin4.float_value = state.gimbals[0]
                io_fb.a.pin5.float_value = state.gimbals[1]
                io_fb.a.pin6.float_value = state.gimbals[2]

                io_fb.f.pin1.float_value = state.force[0]
                io_fb.f.pin2.float_value = state.force[1]
                io_fb.f.pin3.float_value = state.force[2]

            return_response = True
        else:
            response.ClearField('feedback')

        if request.HasField('command') and not self.disconnected:
            # A command without io_command.f resets the force to zero.
            force = [0.0, 0.0, 0.0]
            if request.command.HasField('io_command'):
                io_cmd = request.command.io_command
                if io_cmd.HasField('f'):
                    force = [self._read_force_pin(io_cmd.f, pin) for pin in ('pin1', 'pin2', 'pin3')]

            with self.device_lock:
                if self.device_state is not None:
                    self.device_state.force = np.array(force)

        return return_response

#@Gooey
def main():
    """Command-line entry point: serve a gamepad or Touch X device as a virtual mobileIO module.

    Parses the command line and builds the requested controller. The
    controller runs on a background thread while ``create_server`` serves it
    under ``--family`` / ``--name`` on ``--port``. On exit the controller is
    asked to shut down and its thread is joined (10 s timeout).

    Raises:
        ValueError: port outside 1024-49151, or missing / non-YAML gamepad
            key configuration.
        NotImplementedError: for ``--device keyboard``.
    """
    parser = argparse.ArgumentParser(description='Use other input device as mobileIO')
    parser.add_argument('--device', choices=['haptic', 'gamepad', 'keyboard'], default='gamepad',
            help='desired control device type')
    parser.add_argument('--key_config', default=None,
            help='path to YAML file that maps IO events to mobileIO channels')
    parser.add_argument('--family', default='HEBI',
            help='Family for the input device and imitation modules')
    parser.add_argument('--name', default=None,
            help='Name for the input device and imitation modules')
    parser.add_argument('--port', type=int, default=16665,
            help='Optional non-default port number to run on (allowing for multiple devices using the hebi-gateway-client)')
    args = parser.parse_args()

    # Validate the port before any hardware is probed.
    if not 1024 <= args.port <= 49151:
        raise ValueError("Port must be valid number between 1024 and 49151.")

    if args.device == 'haptic':
        if not HAPTIC_AVAILABLE:
            print("Error: Haptic device selected but required packages (pyOpenHaptics) are not available.")
            print("Please install the necessary haptic packages or choose a different device type.")
            return
        device_controller = TouchX()
        default_name = 'hapticIO'
    elif args.device == 'gamepad':
        if args.key_config is None:
            raise ValueError("Key configuration file must be provided for gamepad control.")
        if not args.key_config.lower().endswith(('.yaml', '.yml')):
            raise ValueError("Key configuration file must be a YAML file.")
        device_controller = Gamepad(0, args.key_config)
        default_name = 'gamepadIO'
    else:  # 'keyboard', the only remaining argparse choice
        raise NotImplementedError("Keyboard control is not implemented yet.")

    if args.name is None:
        args.name = default_name

    device_thread = threading.Thread(target=device_controller.run)
    group = None

    try:
        group = create_server(args.family,
                              [args.name],
                              args.port,
                              device_handler=[device_controller.device_handler])

        device_thread.start()

        print('Starting Virtual Module')
        group.run()
    except KeyboardInterrupt:
        print("\nShutting down...")
    finally:
        # Both controller run() loops poll this flag.
        device_controller.shutdown = True

        if group is not None:
            # NOTE(review): private attribute; presumably stops the group's loop — confirm.
            group._enabled = False

        if device_thread.is_alive():
            device_thread.join(timeout=10.0)
            if device_thread.is_alive():
                print("Warning: Device thread did not shutdown cleanly")
            else:
                print("Device thread shutdown successfully")


if __name__ == '__main__':
    main()
