#! /usr/bin/env python3
import time
import struct
import socket
from hebi_tools.hebi_proto import FirmwareInfo

# TODOs:
# - Set more config parameters (feedback frequency?).
# - Check other code samples for anything else we are missing.
# - Report the serial number, etc.
# - Check the CRC of feedback packets!
# - Better constants for parameters and ports.
# - Check for the startup 'wh' message on telnet? 'wh,2,0' means a timeout since the last connection.
# - Use hex writes for float parameters.


class BotaConfigConnection:
    """Text command connection to a Bota sensor's configuration port (TCP 23).

    Requests are newline-terminated, comma-separated lines:
    ``ra,<param>,<subparam>,0`` reads a parameter and
    ``wa,<param>,<subparam>,<value>`` writes one. The sensor answers with
    ``<cmd>,<status>,<value>``, where status ``"0"`` means success.

    Parameters used by the setters below:
      * ``1``: mode -- read via subparam ``1``, written via subparam ``2``;
        ``"1"`` = config, ``"2"`` = run.
      * ``3``/``1``: app mode -- ``"1"`` = wrench only, ``"2"`` = IMU + wrench.
      * ``2``/``1``..``6``: wrench offsets Fx, Fy, Fz, Tx, Ty, Tz.
    """

    # TODO: PARAM LISTS HERE!
    # TODO: initial read on sensor start after first boot!

    # TCP port of the sensor's telnet-style config interface.
    CONFIG_PORT = 23
    # Maximum number of bytes requested per recv() while waiting for a response line.
    _RECV_CHUNK = 100

    def __init__(self, ip):
        self.bota_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.bota_conn.connect((ip, self.CONFIG_PORT))
        except OSError:
            # Release the socket when the sensor is unreachable instead of leaking it.
            self.bota_conn.close()
            raise

    def _parse_response(self, cmd, data: bytes):
        """Validate a ``<cmd>,<status>,<value>`` response line and return ``<value>``.

        Raises:
            Exception: the echoed command differs from ``cmd``, the status is
                not ``"0"``, or the line is too short to carry a value.
        """
        text = data.decode("utf-8").strip()
        fields = text.split(",")
        if fields[0] != cmd:
            raise Exception(f"Got back wrong command: {fields[0]}")
        if len(fields) > 1 and fields[1] != "0":
            raise Exception(f"Error in response: {fields[1]}")
        if len(fields) < 3:
            # Previously surfaced as a bare IndexError.
            raise Exception(f"Malformed response: {text!r}")
        return fields[2]

    def _recv_line(self) -> bytes:
        """Read from the socket until the received data ends with a newline.

        Raises:
            ConnectionError: the sensor closed the connection before a full line
                arrived. Without this check, recv() returning b"" loops forever.
        """
        data = b""
        while not data.endswith(b"\n"):
            chunk = self.bota_conn.recv(self._RECV_CHUNK)
            if not chunk:
                raise ConnectionError("Bota config connection closed by sensor")
            data += chunk
        return data

    def _exchange(self, cmd: str, *fields: str) -> str:
        """Send ``cmd,field,...`` as one line and return the parsed response value."""
        line = ",".join((cmd,) + fields) + "\n"
        # sendall: send() may write only part of the buffer.
        self.bota_conn.sendall(line.encode("utf-8"))
        return self._parse_response(cmd, self._recv_line())

    def _read_param(self, param: str, subparam: str) -> str:
        """Return the current value of ``param``/``subparam`` as a string."""
        return self._exchange("ra", param, subparam, "0")

    def _set_param(self, param: str, subparam: str, value: str) -> bool:
        """Write ``value``; return True if the sensor echoes the same string back."""
        return self._exchange("wa", param, subparam, value) == value

    def _ensure_param(self, param: str, read_subparam: str, write_subparam: str, value: str) -> bool:
        """Write ``value`` unless reading the parameter already returns it."""
        if self._read_param(param, read_subparam) == value:
            return True
        # If it doesn't match, try writing
        return self._set_param(param, write_subparam, value)

    def set_mode_run(self) -> bool:
        """Put the sensor into run mode (mode ``"2"``)."""
        return self._ensure_param("1", "1", "2", "2")

    def set_mode_config(self) -> bool:
        """Put the sensor into config mode (mode ``"1"``)."""
        return self._ensure_param("1", "1", "2", "1")

    # Only in config mode!
    def set_app_mode_imu_wrench(self) -> bool:
        """Select IMU + wrench feedback (app mode ``"2"``)."""
        return self._ensure_param("3", "1", "1", "2")

    # Only in config mode!
    def set_app_mode_wrench(self) -> bool:
        """Select wrench-only feedback (app mode ``"1"``)."""
        return self._ensure_param("3", "1", "1", "1")

    def set_wrench_offsets(self, fx, fy, fz, tx, ty, tz) -> bool:
        """Add the given deltas to the sensor's stored wrench offsets (param 2, subparams 1-6).

        All six current offsets are read before any is written. Always returns
        True: the writes are not verified, presumably because the decimal string
        echoed back need not match the one sent.
        """
        # TODO - do hex writes instead, then actually check success (via _set_param's result).
        subparams = ("1", "2", "3", "4", "5", "6")
        deltas = (fx, fy, fz, tx, ty, tz)
        current = [float(self._read_param("2", sub)) for sub in subparams]
        for sub, delta, cur in zip(subparams, deltas, current):
            self._set_param("2", sub, str(delta + cur))
        return True

    def shutdown(self):
        """Return the sensor to config mode, then close the connection (even if that fails)."""
        try:
            self.set_mode_config()
        finally:
            self.bota_conn.close()


## BASIC FEEDBACK (I/O pin mapping):
# A1-A6: wrench (force x/y/z, then torque x/y/z)
# B1: timestamp (us)
# B2: status (bitfield; 1-4 used)

## TARE INTERFACE:
# B7: number of errors during tare
# B8: number of tares done
# E8: trigger a tare on a rising edge (0 -> 1); F3-F8 give the desired
#     force/torque readings (force x/y/z, then torque x/y/z) after the tare


def get_float_pin_value(proto_in) -> float:
    """Return a pin's value as a float, whichever numeric field is set.

    Prefers ``float_value``, falls back to ``int_value`` converted to float,
    and returns 0.0 when neither field is set (previously the int ``0``,
    contradicting the ``float`` return annotation).
    """
    if proto_in.HasField("float_value"):
        return proto_in.float_value
    if proto_in.HasField("int_value"):
        return float(proto_in.int_value)
    return 0.0


class Bota:
    """Bridge a Bota force/torque sensor into device-handler responses.

    The sensor is configured over a ``BotaConfigConnection`` (IMU + wrench run
    mode). ``run`` receives its UDP feedback packets on port 30302 and caches
    the values. ``device_handler`` copies the cached values into responses and
    services tare requests.
    NOTE(review): ``run`` blocks, so ``device_handler`` is presumably called
    from another thread; the shared attributes are not locked -- confirm.
    """

    # Feedback payload after the 1-byte header (little-endian): uint16 status,
    # 6 x float32 wrench (force xyz, torque xyz), uint32 timestamp, float32 temperature.
    _WRENCH_FMT = struct.Struct("<H6fIf")
    # Extra payload of wrench + IMU (0xAB) packets: 3 x float32 accel, 3 x float32 gyro.
    _IMU_FMT = struct.Struct("<6f")
    # Full packet sizes, including the 1-byte header and the trailing 2-byte CRC.
    _WRENCH_PACKET_LEN = 37
    _WRENCH_IMU_PACKET_LEN = 61

    def __init__(self, ip):
        # Number of feedback packets to drop after (re-)entering run mode.
        self.enter_run_mode_packet_countdown = 1
        self.force = [0.0, 0.0, 0.0]
        self.torque = [0.0, 0.0, 0.0]
        self.accel = [0.0, 0.0, 0.0]
        self.gyro = [0.0, 0.0, 0.0]
        self.bota_status = -1
        self.timestamp = -1
        self.temp = -1
        self.tare_error_count = 0
        self.tare_count = 0
        # Last value seen on the E8 tare trigger, for rising-edge detection.
        self.tare_command_pin = 0
        self.shutdown = False
        self.ip = ip
        # If we haven't heard from the bota in 5 seconds, try to reconnect
        self.REFRESH_DEVICE_TIMEOUT_SECONDS = 5
        self.conn = None
        self.disconnected = True
        self._init_connection()

        # Set up feedback listening socket:
        self.listener = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.listener.bind(("0.0.0.0", 30302))
        self.listener.settimeout(self.REFRESH_DEVICE_TIMEOUT_SECONDS)

    def _close_connection(self):
        """Close the current config connection, if any, without changing the sensor mode."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.bota_conn.close()
        self.conn = None

    def _init_connection(self):
        """(Re)open the config connection and put the sensor into IMU + wrench run mode.

        On failure the half-open connection is closed, ``conn`` is None and
        ``disconnected`` is True, so ``run`` retries later.
        """
        # Drop any previous connection so reconnects don't leak sockets.
        self._close_connection()
        conn = None
        try:
            print("Connecting")
            conn = BotaConfigConnection(self.ip)
            conn.set_mode_config()
            conn.set_app_mode_imu_wrench()
            conn.set_mode_run()
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
            if conn is not None:
                conn.bota_conn.close()
            self.conn = None
            self.disconnected = True
            print("Waiting for sensor...\n")
        else:
            self.conn = conn
            self.disconnected = False

    def parse_feedback(self, data: bytes):
        """Decode one UDP feedback packet into the cached sensor state.

        Accepts wrench-only packets (37 bytes, header 0xAA) and wrench + IMU
        packets (61 bytes, header 0xAB). Anything else is reported and ignored.
        While ``enter_run_mode_packet_countdown`` is positive, packets are
        dropped and the counter is decremented. The force/torque/accel/gyro
        lists are updated in place.
        """
        if len(data) < 1:
            print("Empty feedback packet!")
            return

        if self.enter_run_mode_packet_countdown > 0:
            self.enter_run_mode_packet_countdown -= 1
            return

        header = data[0]
        if len(data) == self._WRENCH_PACKET_LEN and header == 0xAA:
            wrench_only = True
        elif len(data) == self._WRENCH_IMU_PACKET_LEN and header == 0xAB:
            wrench_only = False
        else:
            print(f"Unknown or malformed feedback packet! {len(data)} / {header}")
            return

        # TODO: CHECK CRC! It is the last two bytes (little-endian uint16) and is not verified yet.
        status, fx, fy, fz, tx, ty, tz, timestamp, temp = self._WRENCH_FMT.unpack_from(data, 1)
        self.bota_status = status
        self.force[:] = (fx, fy, fz)
        self.torque[:] = (tx, ty, tz)
        self.timestamp = timestamp
        self.temp = temp

        if not wrench_only:
            ax, ay, az, gx, gy, gz = self._IMU_FMT.unpack_from(data, 1 + self._WRENCH_FMT.size)
            self.accel[:] = (ax, ay, az)
            self.gyro[:] = (gx, gy, gz)

    def run(self):
        """Receive feedback until ``shutdown`` is set (or Ctrl-C), reconnecting after timeouts.

        On exit the sensor is returned to config mode if still connected, and
        the config connection is closed in every case.
        """
        while not self.shutdown:
            if self.disconnected:
                time.sleep(1)
                self._init_connection()
                continue

            # Try to get packet!
            try:
                # 70 bytes covers both wrench-only (37) and wrench + IMU (61) packets.
                msg = self.listener.recv(70)
                self.parse_feedback(msg)
            except socket.timeout:
                print("Disconnection detected")
                self.enter_run_mode_packet_countdown = 1
                self.disconnected = True
            except KeyboardInterrupt:
                self.shutdown = True
                break

        try:
            if not self.disconnected and self.conn is not None:
                self.conn.shutdown()
        finally:
            # Also closes a stale connection left behind by a timeout.
            self._close_connection()

    def _tare(self, f_pins):
        """Shift the sensor's wrench offsets so the current reading becomes the F3-F8 targets.

        ``f_pins`` is the request's ``io_command.f`` pin bank. Increments
        ``tare_count`` on success. Increments ``tare_error_count`` on failure,
        including when no config connection is available or a config command
        raises.
        """
        conn = self.conn
        if conn is None or self.disconnected:
            # Previously an AttributeError on None.
            print("Tare requested while sensor is disconnected")
            self.tare_error_count += 1
            return

        targets = [
            get_float_pin_value(pin)
            for pin in (f_pins.pin3, f_pins.pin4, f_pins.pin5, f_pins.pin6, f_pins.pin7, f_pins.pin8)
        ]
        try:
            conn.set_mode_config()
            # Per-axis offset change: desired reading minus current reading.
            current = self.force + self.torque
            success = conn.set_wrench_offsets(*(t - c for t, c in zip(targets, current)))
            conn.set_mode_run()
        except Exception as err:
            # Report through the tare error counter instead of escaping device_handler.
            print(f"Tare failed: {err}")
            success = False
        self.enter_run_mode_packet_countdown = 1
        if success:
            self.tare_count += 1
        else:
            self.tare_error_count += 1

    def device_handler(self, device, request, response):
        """Fill ``response`` for a device ``request``; return True if it should be sent.

        Echoes the request's echo fields and applies name/family settings to
        ``device``. Fills each requested info section and clears the rest.
        Feedback is only filled while connected. A rising edge (0 -> 1) on E8
        triggers a tare. Returns True when at least one section was filled.
        """
        response.sender_id = device.sender_id
        response.echo.tx_time = request.echo.tx_time
        response.echo.payload = request.echo.payload
        response.echo.sequence_number = request.echo.sequence_number

        return_response = False

        if request.HasField("settings"):
            if request.settings.HasField("name"):
                if request.settings.name.HasField("name"):
                    device.name = request.settings.name.name
                if request.settings.name.HasField("family"):
                    device.family = request.settings.name.family

        if request.request_settings:
            response.settings.name.name = device.name
            response.settings.name.family = device.family
            return_response = True
        else:
            response.settings.ClearField("name")

        if request.request_firmware_info:
            response.firmware_info.type = device.fw_type
            response.firmware_info.revision = device.fw_rev
            response.firmware_info.mode = FirmwareInfo.APPLICATION
            return_response = True
        else:
            response.ClearField("firmware_info")

        if request.request_ethernet_info:
            response.ethernet_info.mac_address = device.mac_address
            response.ethernet_info.ip_address = device.ip_address
            response.ethernet_info.netmask = device.netmask
            return_response = True
        else:
            response.ClearField("ethernet_info")

        if request.request_hardware_info:
            response.hardware_info.serial_number = device.hw_serial_number
            response.hardware_info.mechanical_type = device.hw_mechanical_type
            response.hardware_info.mechanical_revision = device.hw_mechanical_revision
            response.hardware_info.electrical_type = device.hw_electrical_type
            response.hardware_info.electrical_revision = device.hw_electrical_revision
            return_response = True
        else:
            response.ClearField("hardware_info")

        if request.request_feedback and not self.disconnected:
            # A1-A6: wrench as raw I/O pins.
            response.feedback.io_feedback.a.pin1.float_value = self.force[0]
            response.feedback.io_feedback.a.pin2.float_value = self.force[1]
            response.feedback.io_feedback.a.pin3.float_value = self.force[2]
            response.feedback.io_feedback.a.pin4.float_value = self.torque[0]
            response.feedback.io_feedback.a.pin5.float_value = self.torque[1]
            response.feedback.io_feedback.a.pin6.float_value = self.torque[2]

            response.feedback.wrench.force.x = self.force[0]
            response.feedback.wrench.force.y = self.force[1]
            response.feedback.wrench.force.z = self.force[2]
            response.feedback.wrench.torque.x = self.torque[0]
            response.feedback.wrench.torque.y = self.torque[1]
            response.feedback.wrench.torque.z = self.torque[2]

            response.feedback.io_feedback.b.pin1.int_value = self.timestamp
            response.feedback.io_feedback.b.pin2.int_value = self.bota_status
            response.feedback.io_feedback.b.pin7.int_value = self.tare_error_count
            response.feedback.io_feedback.b.pin8.int_value = self.tare_count

            response.feedback.accel.x = self.accel[0]
            response.feedback.accel.y = self.accel[1]
            response.feedback.accel.z = self.accel[2]
            response.feedback.gyro.x = self.gyro[0]
            response.feedback.gyro.y = self.gyro[1]
            response.feedback.gyro.z = self.gyro[2]

            response.feedback.ambient_temperature = self.temp

            # Enable tare: report the tare inputs (E8, F3-F8) as zero, presumably
            # so clients see these pins as present -- confirm.
            response.feedback.io_feedback.e.pin8.int_value = 0
            response.feedback.io_feedback.f.pin3.float_value = 0
            response.feedback.io_feedback.f.pin4.float_value = 0
            response.feedback.io_feedback.f.pin5.float_value = 0
            response.feedback.io_feedback.f.pin6.float_value = 0
            response.feedback.io_feedback.f.pin7.float_value = 0
            response.feedback.io_feedback.f.pin8.float_value = 0

            return_response = True
        else:
            response.ClearField("feedback")

        ## TARE INTERFACE:
        # E8 triggers a tare on a rising edge (0 -> 1); F3-F8 hold the desired
        # force/torque readings after the tare.
        if request.command.io_command.e.pin8.HasField("int_value"):
            prev_value = self.tare_command_pin
            self.tare_command_pin = request.command.io_command.e.pin8.int_value
            if prev_value == 0 and self.tare_command_pin == 1:
                self._tare(request.command.io_command.f)

        return return_response
