#!/usr/bin/env python3

from dataclasses import dataclass, field
from functools import wraps
from typing import Protocol
import hebi
import hebi._internal.errors
from hebi.robot_model import FrameType
import numpy as np
from time import sleep
# from scipy.spatial.transform import Rotation as R, Slerp
from threading import Event, RLock, Thread

from typing import Union, get_args
import numpy.typing as npt

def synchronized(func):
    """
    Serialize calls to an instance method through the instance's lock.

    The decorated method is executed while holding ``self._lock``, so for a
    given instance only one thread at a time runs it. The instance must
    provide a ``_lock`` attribute usable as a context manager.
    """
    def locked_call(self, *args, **kwargs):
        with self._lock:
            result = func(self, *args, **kwargs)
        return result

    # Preserve the wrapped method's name/docstring on the wrapper.
    return wraps(func)(locked_call)


def get_union_types(union_type):
    """
    Retrieves the types within a Union type.

    Args:
      union_type: The Union type to inspect.

    Returns:
      A tuple containing the types within the Union, or None if the input is not a Union.
    """
    try:
        if getattr(union_type, '__origin__') is Union:
            return get_args(union_type)
    except AttributeError:
        pass
    return None


class EndEffectorController(Protocol):
    """Structural interface for an optional end-effector (gripper) driver.

    HEBIArmControl calls ``update()`` once per control cycle and ``send()``
    right after the arm command is sent. ``close``/``open``/``toggle`` are
    part of the interface but are not called anywhere in this file.
    """

    # Called every control cycle; the bool result is ignored by the caller.
    def update(self) -> bool: ...

    # Called after the arm command is sent; the bool result is ignored by the caller.
    def send(self) -> bool: ...

    # Presumably closes the gripper -- TODO confirm against an implementation.
    def close(self): ...

    # Presumably opens the gripper -- TODO confirm against an implementation.
    def open(self): ...

    # Presumably switches between open and closed -- TODO confirm.
    def toggle(self): ...


@dataclass
class Startup:
    # Initial controller state. On the first control cycle the controller
    # commands the starting position and transitions to Idle.
    ...


@dataclass
class Idle:
    # Resting state: LEDs are set to 'transparent', gravity compensation is
    # enabled, and new Moving/Floating requests are accepted.
    ...


@dataclass
class Floating:
    # State entered by toggle_floating(): the active goal is cancelled and the
    # default command lifetime is restored.

    # State to return to when floating is toggled off (defaults to a new Idle).
    return_state: 'HEBIArmControl.State' = field(default_factory=Idle)
    # When True, the end-effector wrench is mapped to joint efforts each cycle.
    allow_tool_force: bool = False


@dataclass
class Moving:
    # State in which the arm is executing a trajectory goal.

    # Goal handed to the arm when this state is entered.
    goal: hebi.arm.Goal
    # Not read by any logic visible in this file -- presumably marks a move
    # that ends at a dock; TODO confirm with callers.
    to_dock: bool
    # LED color applied to the arm while the move is active.
    color: 'str' = 'blue'


@dataclass
class OvertorqueFault:
    # Fault state entered when a joint's torque error exceeds the active limit.

    # Index of the offending joint (reported as counted from the base).
    joint: int
    # Torque error recorded for that joint, in Nm.
    torque: float
    # State restored by clear_torque_fault().
    return_state: 'HEBIArmControl.State'


@dataclass
class MStopTriggered:
    # Never constructed in this file; is_idle() treats it as "not idle".
    # Presumably represents a triggered motion-stop -- TODO confirm.

    # Presumably the state to resume once the stop is released -- TODO confirm.
    return_state: 'HEBIArmControl.State'


class HEBIArmControl:
    State = Startup | Idle | Moving | Floating | OvertorqueFault | MStopTriggered

    def __init__(self,
                 lookup: hebi.Lookup,
                 config: hebi.config.HebiConfig,
                 ee_controller: EndEffectorController | None = None,
                 logging=True):
        """Create the arm, validate the config and prepare the control thread.

        The background thread is only constructed here; call ``start()`` to
        run it.

        Args:
            lookup: HEBI lookup used to create the arm from ``config``.
            config: Arm configuration; must provide ``hrdf``, ``user_data``
                and ``gains``.
            ee_controller: Optional end-effector driver that is updated and
                sent alongside the arm.
            logging: When truthy, group logging into the ``logs`` directory
                starts immediately and is stopped when the control loop exits.

        Raises:
            RuntimeError: If ``config`` lacks an hrdf, a user_data section or
                a gains section.
        """

        self._state = Startup()
        self.arm = hebi.arm.create_from_config(config, lookup)

        if not config.hrdf:
            raise RuntimeError('Arm needs hrdf?')
        # Separate robot model for user-thread kinematics (see arm_fk / arm_ik).
        self.user_robot_model = hebi.robot_model.import_from_hrdf(config.hrdf)

        if config.user_data is None:
            raise RuntimeError('Config must have user_data section')

        self.user_data = config.user_data
        assert self.user_data is not None

        if config.gains is None:
            raise RuntimeError('Config must have gains section')

        self.gains = config.gains
        # Remembered so it can be restored when floating and at shutdown.
        self.default_command_lifetime = self.arm.group.command_lifetime

        # self.ik_seed = self.user_data['ik_seed_pos']

        self.logging = logging
        if self.logging:
            self.arm.group.start_log('logs', mkdirs=True)

        self.ee_ctrl = ee_controller

        self._shutdown = Event()
        # Re-entrant because synchronized methods call each other.
        self._lock = RLock()
        self._thread = Thread(None, self._run)

        # Fetch one feedback sample so last_feedback is populated below.
        self.arm.update()

        # Prefer the last commanded position; fall back to the measured one
        # when the command contains NaN.
        if not np.any(np.isnan(self.arm.last_feedback.position_command)):
            print('Starting from commanded position')
            self.starting_position = self.arm.last_feedback.position_command
        else:
            print('No position command, starting from feedback')
            print(self.arm.last_feedback.position_command)
            self.starting_position = self.arm.last_feedback.position

        if 'default' in self.gains:
            self.arm.load_gains(self.gains['default'])
        self.arm.pending_command.led.color = 'transparent'

        # 6-element wrench mapped through the Jacobian transpose in _update().
        self._ee_wrench = np.zeros(6)

        # Output buffers filled by _update(); contents are uninitialised
        # until the first control cycle.
        self.ee_xyz = np.empty(3)
        self.ee_rot = np.empty((3, 3))
        self.ee_jacobian = np.empty((6, self.arm.size))

        # Previous null-space direction, used by goto_cartesian() to keep a
        # consistent sign.
        self.ns_vec_prev = None

        # NOTE(review): the two values below are not read anywhere in this file.
        self.vel_step_comp = 10  # [Nm]
        self.vel_speed_comp = 60  # [Nm / (rad/sec)]
        # Defaults to infinity, which disables the overtorque check.
        self.base_torque_error_limit = float(self.user_data.get('base_torque_error_limit', np.inf))  # [Nm]
        self.joint_torque_error_limit = self.base_torque_error_limit
        self.torque_limit_return_rate = 5  # [Nm / sec]

    def write_user_log_state(self, v1=None, v2=None, v3=None, v4=None, v5=None, v6=None, v7=None, v8=None, v9=None):
        self.arm.group.log_user_state(v1, v2, v3, v4, v5, v6, v7, v8, v9)

    def _run(self):
        while not self._shutdown.is_set():
            self._update()
            self.send()

        if np.any(np.isnan(self.arm.last_feedback.position_command)):
            print('No position commands set at shutdown')
            print('Clearing torque command and 0 command lifetime')
            self.arm.group.command_lifetime = self.default_command_lifetime
            self.arm.pending_command.effort = None
            self.arm.pending_command.led.color = 'transparent'
            self.arm.send()

        if self.logging:
            self.arm.group.stop_log()

    @property
    def state(self):
        """Current controller state, read while holding the instance lock."""
        with self._lock:
            return self._state

    @state.setter
    def state(self, new_state):
        # Store the new state under the same lock used by the getter.
        with self._lock:
            self._state = new_state

    @property
    def running(self):
        return self._thread.is_alive()

    def start(self):
        if self.running:
            print('Controller already running, ignoring start request.')
            return

        self._shutdown.clear()
        self._thread.start()

    def stop(self):
        if not self.running:
            raise RuntimeError("Can't stop controller, not running")

        self._shutdown.set()
        self._thread.join()

    @property
    def gravcomp_plugin(self):
        if gc := self.arm.get_plugin_by_type(hebi.arm.GravCompEffortPlugin):
            return gc
        raise RuntimeError('Active arm does not have gravcomp plugin?')

    @synchronized
    def is_idle(self):
        if isinstance(self.state, (OvertorqueFault, MStopTriggered)):
            return False
        return self.arm.at_goal or self.arm._trajectory is None

    @property
    def ee_wrench(self):
        """End-effector wrench applied via the Jacobian transpose during moves."""
        with self._lock:
            return self._ee_wrench

    @ee_wrench.setter
    def ee_wrench(self, wrench: 'npt.NDArray[np.float64]'):
        # Swap in the new wrench under the instance lock.
        with self._lock:
            self._ee_wrench = wrench


    def _linear_move(self, goal: hebi.arm.Goal, xyz_from, xyz_to, orientation, ik_seed, num_midpoints=5, duration: float | None = None):
        dist: float = np.linalg.norm(xyz_from - xyz_to)  # type: ignore
        travel_time = dist * 10.0 if duration is None else duration  # seconds
        for i in range(num_midpoints):
            ratio = (1.0 + i) / (1.0 + num_midpoints)
            xyz = xyz_from * (1 - ratio) + xyz_to * ratio
            goal.add_waypoint(t=travel_time/(1.0 + num_midpoints),
                              position=self.arm_ik(ik_seed, xyz, orientation))
            # velocity=np.zeros(goal.dof_count),
            # acceleration=np.zeros(goal.dof_count))

        final_position = self.arm_ik(ik_seed, xyz_to, orientation)
        goal.add_waypoint(t=travel_time/(1.0 + num_midpoints),
                          position=final_position,
                          velocity=np.zeros(goal.dof_count),
                          acceleration=np.zeros(goal.dof_count))
        return final_position

    def set_color(self, color: str, indices=None):
        #print(f'Set Color: {color}')
        if indices:
            for idx in indices:
                self.arm.pending_command[idx].led.color = color
        else:
            self.arm.pending_command.led.color = color

    def request_transition(self, action: State, blocking=True):
        while True:
            with self._lock:
                can_preempt = isinstance(
                    action, Moving) and isinstance(self.state, Moving)
                if self.is_idle() or can_preempt:
                    break

            if not blocking:
                return False
            sleep(0.01)

        with self._lock:
            self._transition_to(action)

        return True

    @synchronized
    def _transition_to(self, next_state: State):
        """Leave the current state and enter ``next_state``.

        Runs the exit action of the current state, then the entry action of
        ``next_state`` (LED color, goal handling, command lifetime), then
        stores the new state. Does nothing if ``next_state`` compares equal
        to the current state.

        Startup and MStopTriggered have no entry action here; they are simply
        stored.
        """
        if self.state == next_state:
            return

        # Don't spam a bunch of Moving -> Moving transition prints during teleop
        if type(self.state) is not type(next_state):
            print(f'Transition: {self.state} -> {next_state}')

        # The "transition_from" section, maybe factor this out later
        match self.state:
            case OvertorqueFault(idx, value, _):
                # Leaving a fault: temporarily raise the limit to 1.5x base;
                # _update() decays it back toward the base value.
                self.joint_torque_error_limit = 1.5 * self.base_torque_error_limit
            case _:
                self.arm.group.command_lifetime = 0

        match next_state:
            case OvertorqueFault(idx, value, _):
                self.set_color('yellow')
                # Replace the active goal with one at the measured position.
                self._command_current_position()
                print(
                    f'Torque error limit exceeded for joint (from base) #{idx}, {value} Nm')

            case Floating(_):
                self.set_color('magenta')
                self.arm.cancel_goal()
                # Undo the lifetime of 0 set by the exit action above.
                self.arm.group.command_lifetime = self.default_command_lifetime

            case Idle():
                self.set_color('transparent')
                # With no active trajectory, command the measured position.
                if self.arm._trajectory is None:
                    print('Commanding current position')
                    self._command_current_position()
                gc = self.gravcomp_plugin
                if not gc.enabled:
                    print(f'Enabling gravcomp')
                    gc.enabled = True
                # NOTE(review): gripper_closed() is a stub that returns None,
                # so this branch never runs at present.
                if self.gripper_closed():
                    self.open_gripper()

            case Moving(goal, _, led_color):
                # print(f'Moving to {goal}, to dock? {to_dock}')
                self.set_color(led_color)
                self.arm.set_goal(goal)

        self.state = next_state

    @synchronized
    def _update(self):
        t_prev = self.arm._last_time
        self.arm.update()
        self.dt = self.arm._last_time - t_prev
        prev_xyz = self.ee_xyz.copy()
        self.arm.FK(self.arm.last_feedback.position, xyz_out=self.ee_xyz, orientation_out=self.ee_rot)
        self.ee_jacobian = self.arm.robot_model.get_jacobian_end_effector(self.arm.last_feedback.position)
        self.ee_vel = (self.ee_xyz - prev_xyz) / self.dt

        if self.ee_ctrl:
            self.ee_ctrl.update()

        if self.joint_torque_error_limit > self.base_torque_error_limit:
            # print(self.joint_torque_error_limit)
            self.joint_torque_error_limit -= self.torque_limit_return_rate * 0.01

        fbk = self.arm.last_feedback

        joint_torque_error = fbk.effort - fbk.effort_command

        if any(abs(joint_torque_error) > self.joint_torque_error_limit) and not isinstance(self.state, OvertorqueFault):
            print('Joint Torque Error Limit Exceeded')
            idx = int(np.argmax(joint_torque_error))
            recovery_state = Idle()
            self._transition_to(OvertorqueFault(
                idx, joint_torque_error[idx], recovery_state))

        match self.state:
            case Startup():
                g = hebi.arm.Goal(self.arm.size)
                g.add_waypoint(t=0.3, position=self.starting_position)
                self.arm.set_goal(g)
                self._transition_to(Idle())

            case Moving(_, to_dock, _):
                self.arm.pending_command.effort += self.ee_jacobian.T @ self.ee_wrench
                if self.arm.at_goal:
                    self._transition_to(Idle())

            case Floating(_, True):
                self.arm.pending_command.effort += self.ee_jacobian.T @ self.ee_wrench

    def send(self):
        self.arm.send()
        if self.ee_ctrl:
            self.ee_ctrl.send()

    def open_gripper(self):
        """Placeholder for opening the gripper; currently does nothing."""
        # TODO: implement
        pass

    def close_gripper(self):
        """Placeholder for closing the gripper; currently does nothing."""
        # TODO: implement
        pass

    def toggle_gripper(self):
        """Placeholder for toggling the gripper; currently does nothing."""
        # TODO: implement
        pass

    def gripper_closed(self):
        """Placeholder for querying the gripper state.

        Currently returns None, which the Idle entry action treats as "not
        closed".
        """
        # TODO: implement
        pass

    def _command_current_position(self):
        g = hebi.arm.Goal(self.arm.size)
        g.add_waypoint(t=0.5, position=self.arm.last_feedback.position)
        self.arm.set_goal(g)

    def arm_fk(self, positions, xyz_out=None, tip_axis_out=None, orientation_out=None):
        '''Uses a different RobotModel instance from the arm object.

        This function can be called from the user thread without worrying about locks'''
        tmp_frame = np.empty((4, 4))
        self.user_robot_model.get_end_effector(positions, tmp_frame)

        if tip_axis_out is not None:
            tip_axis_out[:] = tmp_frame[:3, 2]

        if orientation_out is not None:
            orientation_out[:] = tmp_frame[:3, :3]

        if xyz_out is not None:
            xyz_out[:] = tmp_frame[:3, 3]
            return xyz_out

        return tmp_frame[:3, 3]

    def arm_ik(self, seed, xyz, so3):
        xyz_objective = hebi.robot_model.PositionObjective(FrameType.EndEffector, xyz=xyz)
        so3_objective = hebi.robot_model.SO3Objective(FrameType.EndEffector, rotation=so3)
        try:
            return self.user_robot_model.solve_inverse_kinematics(seed, xyz_objective, so3_objective)
        except hebi._internal.errors.HEBI_Exception as e:
            print(f'Kinematics error: {e}')
            print(f'Using seed {seed}\nTarget xyz: {xyz}\nso3: {so3}')

    def arm_nullspace_vector(self, position):
        robot_model = self.user_robot_model
        j = robot_model.get_jacobian_end_effector(position)
        j_inv = np.linalg.pinv(j)
        filter_matrix = np.eye(robot_model.dof_count) - j_inv @ j
        u, s, v = np.linalg.svd(filter_matrix)
        return v.T[:, 0]

    def get_overlay_force(self, tip_force: float):
        tip_axis = np.empty(3)
        position = self.arm.last_feedback.position
        self.arm_fk(position, tip_axis_out=tip_axis)
        return self.ee_jacobian.T[:, :3] @ (tip_force * tip_axis)

    def move_to(self, xyz, rot, seed=None, duration=5.0, blocking=True):
        g = hebi.arm.Goal(self.arm.size)
        if seed is None:
            seed = self.arm.last_feedback.position_command

        g.add_waypoint(t=duration, position=self.arm_ik(seed, xyz, rot))
        if not self.request_transition(Moving(g, False), blocking):
            print('Could not start move, arm is not idle')
            return False

        if blocking:
            while not isinstance(self.state, Idle):
                sleep(0.01)

        return True

    def goto_joint(self, angles, time=5.0, color='blue', blocking=True):
        g = hebi.arm.Goal(self.arm.size)
        g.add_waypoint(t=time, position=angles)

        if not self.request_transition(Moving(g, False, color), blocking):
            print('Could not start move, arm is not idle')
            return False

        if blocking:
            while not isinstance(self.state, Idle):
                sleep(0.01)

        return True

    def goto_cartesian(self, xyz, rot, undulate_vel=0.0, time=0.3,
                       user_seed: 'npt.NDArray[np.float64] | None' = None,
                       seed_blend_ratio = 1.0,
                       blocking=True):
        # default to last commanded position
        seed = self.arm.last_feedback.position_command
        match self.state:
            # If moving towards a goal, seed from the end of that trajectory
            # This helps make null space control more predictable
            case Moving(g_now, _):
                seed = np.array(g_now._positions[-1])
            case Idle():
                pass
            case _:
                # For other states (error, docked, etc) shouldn't be able to move arm
                return False

        if user_seed is not None:
            seed += seed_blend_ratio * (user_seed - seed)

        new_seed = seed
        if undulate_vel != 0.0:
            ns_vec = self.arm_nullspace_vector(seed)
            if self.ns_vec_prev is not None:
                if np.dot(self.ns_vec_prev, ns_vec) < 0.0:
                    ns_vec *= -1

            new_seed += ns_vec * undulate_vel
            self.ns_vec_prev = ns_vec

        g = hebi.arm.Goal(self.arm.size)
        g.add_waypoint(t=time, position=self.arm_ik(new_seed, xyz, rot))
        if not self.request_transition(Moving(g, False, 'transparent'), blocking):
            print(f'Could not jog arm , arm is in {self.state} state')
            return False

        if blocking:
            while not isinstance(self.state, Idle):
                sleep(0.01)

        return True

    @synchronized
    def clear_torque_fault(self, blocking=True):
            match self.state:
                case OvertorqueFault(_, _, recovery_state):
                    self._transition_to(recovery_state)
                    return True
                case _:
                    print('Not in torque fault state')
                    return False

    def toggle_floating(self, blocking=True):
        match self.state:
            case Idle():
                self.request_transition(Floating(self.state), blocking)
            case Floating(return_state):
                self.request_transition(return_state, blocking)
            case _:
                print('Arm is not Idle/Docked, cannot toggle compliance')
                return False

        return True
