#!/usr/bin/env python3

try:
    from typing import override
except:
    from typing_extensions import override

import time
import hebi
from os.path import join, dirname, realpath
import numpy as np
import numpy.typing as npt
from time import sleep
from scipy.spatial.transform import Rotation as R
from threading import Event

from hebi_haptic.arm_control import Floating, HEBIArmControl, synchronized


class HapticArmControl(HEBIArmControl):
    """HEBI arm controller that renders haptic forces at the end effector.

    On every control update it polls the hand-held IO board (module named
    ``handleIO``), runs the base-class update, and then writes
    ``self.ee_wrench`` as the sum of:

    * a spring that pulls the end effector back inside a spherical workspace,
    * a damper opposing end-effector velocity, and
    * an optional user force set via :meth:`set_force_feedback`.

    Gains and workspace geometry are class attributes so they can be
    overridden per subclass or per instance (assign a new tuple/array).
    """

    # Per-axis Cartesian gains, ordered x, y, z, roll, pitch, yaw.
    # Only the translational terms are active; rotational gains are zero.
    SPRING_GAINS = (1000.0, 1000.0, 1000.0, 0.0, 0.0, 0.0)
    DAMPING_GAINS = (2.5, 2.5, 2.5, 0.0, 0.0, 0.0)

    # Spherical workspace boundary. Inside the sphere the spring is inactive;
    # outside it pulls the end effector back toward WORKSPACE_CENTER.
    # Units follow self.ee_xyz (presumably metres -- TODO confirm).
    WORKSPACE_CENTER = (0.0, 0.0, 0.1)
    WORKSPACE_RADIUS = 0.5

    # Fixed mounting rotations for the "IO micro" (gumstick) board.
    # For the standard IO board use an identity mount offset and
    # R.from_euler('Y', -np.pi / 2) as the output correction instead.
    _IO_MOUNT_OFFSET = R.from_euler('YZ', [np.pi, -np.pi / 2])
    _IO_OUTPUT_CORRECTION = R.from_euler('XZ', [-np.pi / 2, np.pi / 2])

    def __init__(self, lookup: 'hebi.Lookup', example_config: 'hebi.config.HebiConfig'):
        """Connect to the arm and its ``handleIO`` board, then calibrate orientation.

        Blocks (polling every 0.1 s) until the IO board returns feedback.

        Raises:
            RuntimeError: if no ``handleIO`` module is found in the arm's family.
        """
        super().__init__(lookup, example_config)
        io_grp = lookup.get_group_from_names(example_config.families[0], 'handleIO')
        if io_grp is None:
            raise RuntimeError(
                f"IO board 'handleIO' not found in family {example_config.families[0]!r}"
            )
        self.io_grp = io_grp
        # Run the IO board at the arm's feedback rate (presumably so every
        # control tick has a matching IO sample).
        self.io_grp.feedback_frequency = self.arm._group.feedback_frequency
        while (io_fbk := self.io_grp.get_next_feedback()) is None:
            print('Looking for IO Board...')
            time.sleep(0.1)

        # Extra force rendered at the end effector (see set_force_feedback).
        self.force_fbk = np.zeros(3)
        self.force_fbk_in_ee_frame = True

        # Feedback object reused in-place by _update on every tick.
        self.io_fbk = io_fbk
        # Reference orientation: the board's pose at start-up is "zero".
        self.q_io_init = self._io_reference_orientation()

    def _io_reference_orientation(self):
        """Return the IO board's current orientation with the mounting offset applied."""
        q_raw = R.from_quat(self.io_fbk.orientation, scalar_first=True)
        return q_raw * self._IO_MOUNT_OFFSET

    def get_corrected_orientation(self):
        """Return the IO board's orientation relative to the calibrated pose.

        Returns:
            scipy ``Rotation`` mapping the handle frame into the reference frame.
        """
        q_now = R.from_quat(self.io_fbk.orientation, scalar_first=True)
        return self.q_io_init.inv() * q_now * self._IO_OUTPUT_CORRECTION

    @synchronized
    def calibrate(self):
        """Re-zero the handle orientation at the IO board's current pose."""
        self.q_io_init = self._io_reference_orientation()

    @override
    @synchronized
    def _update(self):
        """Per-tick update: refresh IO feedback, run the base update, apply forces."""
        # Refreshes self.io_fbk in place; on timeout the previous sample is kept.
        self.io_grp.get_next_feedback(reuse_fbk=self.io_fbk)
        super()._update()
        self.apply_forces()

    @synchronized
    def apply_forces(self):
        """Compute the end-effector wrench for this tick and store it in ``self.ee_wrench``.

        wrench = k_p * (workspace-boundary deflection)
               - k_d * (end-effector velocity)
               + user force feedback rotated into the base frame
        """
        k_p = np.asarray(self.SPRING_GAINS, dtype=float)
        k_d = np.asarray(self.DAMPING_GAINS, dtype=float)
        center = np.asarray(self.WORKSPACE_CENTER, dtype=float)

        xyz = self.ee_xyz
        spring_deflection = np.zeros(6)

        # Radial distance of the end effector from the workspace center.
        offset = xyz - center
        radius = np.linalg.norm(offset)
        radial_deflection = radius - self.WORKSPACE_RADIUS
        if radial_deflection > 0.0 and radius > 0.0:
            # Pull back toward the center along the unit radial direction.
            # The direction must be measured from the center (not the origin),
            # and the radius > 0 guard prevents a NaN from 0/0.
            spring_deflection[:3] = -radial_deflection * (offset / radius)

        spring_force = k_p * spring_deflection

        # Map joint velocities to an end-effector twist and oppose it.
        ee_vel = self.ee_jacobian @ self.arm.last_feedback.velocity
        damping_force = -k_d * ee_vel

        wrench = spring_force + damping_force
        wrench[:3] += self.get_force_feedback_in_base_frame()
        self.ee_wrench = wrench

    @synchronized
    def set_force_feedback(self, force, in_ee_frame=True):
        """Set an extra force to render at the end effector.

        Args:
            force: array-like with exactly 3 elements (flattened). Units are
                presumably newtons -- TODO confirm.
            in_ee_frame: if True, ``force`` is expressed in the handle frame and
                is rotated by :meth:`get_corrected_orientation` before use;
                otherwise it is used as-is.

        Raises:
            ValueError: if ``force`` does not contain exactly 3 elements.
        """
        force = np.asarray(force, dtype=float).flatten()
        if force.shape != (3,):
            raise ValueError(f'force must have exactly 3 elements, got {force.size}')
        self.force_fbk = force
        self.force_fbk_in_ee_frame = in_ee_frame

    @synchronized
    def clear_force_feedback(self):
        """Zero the extra force; the frame flag is left unchanged."""
        self.force_fbk[:] = 0

    def get_force_feedback_in_base_frame(self):
        """Return the user force-feedback vector expressed in the base frame.

        Returns a new array so callers cannot mutate the stored force.
        """
        if not self.force_fbk_in_ee_frame:
            return self.force_fbk.copy()
        rot = self.get_corrected_orientation()
        return rot.apply(self.force_fbk).flatten()
