#!/usr/bin/env python3

import os, sys, uuid, time

from hebi_tools.util import create_module_server, new_root_message
from hebi_tools.hebi_proto import FirmwareInfo

import typing

if typing.TYPE_CHECKING:
    from hebi_tools.hebi_proto import RootMessage

import numpy as np

# Helper functions taken from
# https://stackoverflow.com/a/13294427
import socket
import struct


def ip2int(addr):
    """Return the dotted-quad IPv4 string *addr* as a 32-bit unsigned integer (network order)."""
    packed = socket.inet_aton(addr)
    (value,) = struct.unpack("!I", packed)
    return value


def int2ip(addr):
    """Return the 32-bit unsigned integer *addr* as a dotted-quad IPv4 string."""
    packed = struct.pack("!I", addr)
    return socket.inet_ntoa(packed)


# end stack overflow functions


def _handle_settings(device: "Device", request):
    """Copy name/family overrides from a settings request onto *device*.

    Only sub-fields actually present on the message are applied, so a request
    that carries just a new family leaves the device name unchanged.
    """
    settings = request.settings
    if not settings.HasField("name"):
        return

    name_msg = settings.name
    if name_msg.HasField("name"):
        device.name = name_msg.name
    if name_msg.HasField("family"):
        device.family = name_msg.family


def _handle_command(device: "Device", request):
    """Apply the actuator command carried by *request* to *device*.

    Position is applied only when its offset contains no NaN; in that case the
    revolution count is copied along with it. Velocity and effort are applied
    whenever they are present on the message.
    """
    actuator_cmd = request.command.actuator_command

    if actuator_cmd.HasField("position"):
        target = actuator_cmd.position
        # A NaN offset is skipped rather than written into the device state.
        if not np.any(np.isnan(target.offset)):
            device.position = target.offset
            device.revolutions = target.revolutions

    if actuator_cmd.HasField("velocity"):
        device.velocity = actuator_cmd.velocity

    if actuator_cmd.HasField("effort"):
        device.effort = actuator_cmd.effort


def default_device_handler(device: "Device", request: "RootMessage", response: "RootMessage") -> bool:
    """Fill *response* for *request* from the state held by *device*.

    Echoes the request's echo block, applies any settings or command carried by
    the request to the device, and fills in every information section the
    request asked for.

    When invoked through ``Device.handle_request`` the same response message is
    reused for every request, so each optional section that was not requested
    this time is cleared to avoid re-sending stale data from an earlier reply.

    Args:
        device: The virtual device whose state is read (and updated by
            settings/command requests).
        request: The incoming message.
        response: The outgoing message; modified in place.

    Returns:
        True if the request carried a command or asked for settings, firmware,
        ethernet, hardware info or feedback; False otherwise.
    """
    response.sender_id = device.sender_id
    response.echo.tx_time = request.echo.tx_time
    response.echo.payload = request.echo.payload
    response.echo.sequence_number = request.echo.sequence_number

    return_response = False

    # Apply incoming settings first so a combined "set + request settings"
    # message reports the new values.
    if request.HasField("settings"):
        _handle_settings(device, request)

    if request.request_settings:
        response.settings.name.name = device.name
        response.settings.name.family = device.family
        return_response = True
    else:
        response.settings.ClearField("name")

    if request.request_firmware_info:
        response.firmware_info.type = device.fw_type
        response.firmware_info.revision = device.fw_rev
        response.firmware_info.mode = FirmwareInfo.APPLICATION
        return_response = True
    else:
        response.ClearField("firmware_info")

    if request.request_ethernet_info:
        response.ethernet_info.mac_address = device.mac_address
        response.ethernet_info.ip_address = device.ip_address
        response.ethernet_info.netmask = device.netmask
        return_response = True
    else:
        response.ClearField("ethernet_info")

    if request.request_hardware_info:
        response.hardware_info.serial_number = device.hw_serial_number
        response.hardware_info.mechanical_type = device.hw_mechanical_type
        response.hardware_info.mechanical_revision = device.hw_mechanical_revision
        response.hardware_info.electrical_type = device.hw_electrical_type
        response.hardware_info.electrical_revision = device.hw_electrical_revision
        return_response = True
    else:
        response.ClearField("hardware_info")

    if request.HasField("command"):
        _handle_command(device, request)
        return_response = True

    if request.request_feedback:
        response.feedback.actuator_feedback.position.revolutions = device.revolutions
        response.feedback.actuator_feedback.position.offset = device.position
        response.feedback.actuator_feedback.velocity = device.velocity
        response.feedback.actuator_feedback.effort = device.effort

        # The current command state is reported alongside feedback.
        response.command.actuator_command.position.revolutions = device.revolutions
        response.command.actuator_command.position.offset = device.position
        response.command.actuator_command.velocity = device.velocity
        response.command.actuator_command.effort = device.effort

        # Orientation is stored as [w, x, y, z].
        response.feedback.orientation.w = device.orientation[0]
        response.feedback.orientation.x = device.orientation[1]
        response.feedback.orientation.y = device.orientation[2]
        response.feedback.orientation.z = device.orientation[3]
        return_response = True
    else:
        response.ClearField("feedback")
        # The command section is only populated together with feedback; clear
        # it too so a reused response does not carry a stale command.
        response.ClearField("command")

    return return_response


class Device:
    """
    Represents a single device in the virtual group.

    A device holds simulated actuator state (position, velocity, effort and an
    orientation quaternion stored as [w, x, y, z]), its identity (name, family,
    sender id, subaddress id), network information and firmware/hardware
    descriptors. Incoming requests are passed to every registered request
    handler, which fill in the device's single, reused response message.

    Callables in the ``_position_watchers``, ``_velocity_watchers``,
    ``_effort_watchers`` and ``_orientation_watchers`` lists are called with the
    new value whenever the corresponding property is assigned.
    """

    def __init__(self, family, name, subaddress_id, device_handler=default_device_handler):
        self._family = family
        self._name = name
        self._subaddress_id = subaddress_id
        # Allocated once and reused for every request; handlers overwrite or
        # clear its sections on each call.
        self._response_message = new_root_message()
        self._request_handlers = list()
        self._request_handlers.append(device_handler)

        self._position = 0
        self._velocity = 0
        self._effort = 0

        self.revolutions = 0

        # Identity quaternion, ordered [w, x, y, z].
        self._orientation = [1, 0, 0, 0]

        self._position_watchers = []
        self._velocity_watchers = []
        self._effort_watchers = []
        self._orientation_watchers = []

        # Random id truncated to 64 bits.
        sender_id = uuid.uuid4().int & 0xFFFFFFFFFFFFFFFF
        self._sender_id = sender_id
        self._mac_address = uuid.getnode()

        # TODO: Add valid IP info
        self._ip_address = ip2int("10.11.12.13")
        self._netmask = ip2int("255.255.255.0")

        # Hardware info
        self._fw_type = "HEBI Virtual Device Interface"
        self._fw_rev = "0.0"
        self._hw_serial_number = "0"
        self._hw_mechanical_type = "Virtual Device"
        self._hw_mechanical_revision = "0.0"
        self._hw_electrical_type = "Virtual Device"
        self._hw_electrical_revision = "0.0"

    def __repr__(self):
        return f"Device(name: '{self._name}', family: '{self._family}', subaddress: {self._subaddress_id})"

    def add_request_handler(self, handler):
        """Register *handler* to run after the existing handlers.

        A handler is called as ``handler(device, request, response)``; a bool
        return value votes on whether a response should be sent.
        """
        self._request_handlers.append(handler)

    def handle_request(self, request: "RootMessage"):
        """Run every request handler, in registration order, on *request*.

        All handlers run regardless of earlier results. Returns True if any
        handler returned True; non-bool return values are ignored.
        """
        response = self._response_message
        return_response = False
        for entry in self._request_handlers:
            ret = entry(self, request, response)
            if type(ret) is bool:
                return_response = return_response or ret

        return return_response

    @property
    def response_message(self):
        return self._response_message

    @staticmethod
    def _to_float(value, error_message):
        """Return ``float(value)``, raising ``TypeError(error_message)`` if it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError) as err:
            raise TypeError(error_message) from err

    @staticmethod
    def _notify(watchers, value):
        """Call each watcher with *value*; watcher exceptions propagate unchanged."""
        for func in watchers:
            func(value)

    # Actuator state fields

    @property
    def position(self):
        return self._position

    @position.setter
    def position(self, value):
        # Validation is kept separate from the watcher calls so that an error
        # raised by a watcher is not misreported as a bad position value.
        self._position = self._to_float(value, "Position should be numerical (radians)")
        self._notify(self._position_watchers, self._position)

    @property
    def velocity(self):
        return self._velocity

    @velocity.setter
    def velocity(self, value):
        self._velocity = self._to_float(value, "Velocity should be numerical (radians/second)")
        self._notify(self._velocity_watchers, self._velocity)

    @property
    def effort(self):
        return self._effort

    @effort.setter
    def effort(self, value):
        self._effort = self._to_float(value, "Effort should be numerical")
        self._notify(self._effort_watchers, self._effort)

    @property
    def orientation(self):
        return self._orientation

    @orientation.setter
    def orientation(self, value):
        try:
            quaternion = [float(x) for x in value]
        except (TypeError, ValueError, OverflowError) as err:
            raise TypeError("Orientation should be numerical (w, x, y, z quaternion)") from err
        self._orientation = quaternion
        self._notify(self._orientation_watchers, self._orientation)

    # Identity fields

    @property
    def family(self):
        return self._family

    @family.setter
    def family(self, value):
        self._family = value

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, value):
        self._name = value

    @property
    def sender_id(self):
        return self._sender_id

    @sender_id.setter
    def sender_id(self, value):
        self._sender_id = value

    # Network fields

    @property
    def subaddress_id(self):
        return self._subaddress_id

    @subaddress_id.setter
    def subaddress_id(self, value):
        self._subaddress_id = value

    @property
    def mac_address(self):
        return self._mac_address

    @mac_address.setter
    def mac_address(self, value):
        self._mac_address = value

    @property
    def ip_address(self):
        return self._ip_address

    @ip_address.setter
    def ip_address(self, value):
        self._ip_address = value

    @property
    def netmask(self):
        return self._netmask

    @netmask.setter
    def netmask(self, value):
        self._netmask = value

    # Firmware and hardware fields

    @property
    def fw_type(self):
        return self._fw_type

    @fw_type.setter
    def fw_type(self, value):
        self._fw_type = value

    @property
    def fw_rev(self):
        return self._fw_rev

    @fw_rev.setter
    def fw_rev(self, value):
        self._fw_rev = value

    @property
    def hw_serial_number(self):
        return self._hw_serial_number

    @hw_serial_number.setter
    def hw_serial_number(self, value):
        self._hw_serial_number = value

    @property
    def hw_mechanical_type(self):
        return self._hw_mechanical_type

    @hw_mechanical_type.setter
    def hw_mechanical_type(self, value):
        self._hw_mechanical_type = value

    @property
    def hw_mechanical_revision(self):
        return self._hw_mechanical_revision

    @hw_mechanical_revision.setter
    def hw_mechanical_revision(self, value):
        self._hw_mechanical_revision = value

    @property
    def hw_electrical_type(self):
        return self._hw_electrical_type

    @hw_electrical_type.setter
    def hw_electrical_type(self, value):
        self._hw_electrical_type = value

    @property
    def hw_electrical_revision(self):
        return self._hw_electrical_revision

    @hw_electrical_revision.setter
    def hw_electrical_revision(self, value):
        self._hw_electrical_revision = value


class VirtualDeviceServer:
    """
    Serves a collection of virtual ``Device`` objects over a module server.

    Each device is addressed by its ``subaddress_id``. A request addressed to
    the broadcast subaddress (``0x3FFF``) is handled by every device.
    Requests for unknown subaddresses are ignored.
    """

    # Subaddress that targets every device served by this instance.
    _BROADCAST_SUBADDRESS = 0x3FFF

    def __init__(self, devices, port=None):
        self._devices_dict = dict()
        self._enabled = False
        self._port = port

        # A later device with the same subaddress_id replaces an earlier one.
        for device in devices:
            self._devices_dict[device.subaddress_id] = device

    # ordered by insertion order (seems reasonable for most cases)
    def __getitem__(self, idx):
        """Return the *idx*-th device in insertion order (list-style indexing)."""
        key = [*self._devices_dict][idx]
        return self._devices_dict[key]

    def __len__(self):
        return len(self._devices_dict)

    def stop(self):
        """Ask :meth:`run` to exit.

        The flag is checked at the top of each loop iteration, so the loop ends
        once the current receive call returns.
        """
        self._enabled = False

    @staticmethod
    def _now_us():
        """Monotonic clock in whole microseconds (integer math, no float rounding)."""
        return time.monotonic_ns() // 1000

    def run(self):
        """Receive requests and answer them until :meth:`stop` is called."""
        request = new_root_message()
        server = create_module_server(self._port)
        start_time = self._now_us()

        self._enabled = True

        while self._enabled:
            incoming_sub_src, incoming_sub_dst, msg = server.recv_with_subaddress(request)
            if not msg:
                continue

            rx_time = self._now_us() - start_time

            # Replies go out from the subaddress the request targeted, back to
            # the subaddress it came from.
            self._handle_message(
                server, request, incoming_sub_dst, incoming_sub_src, rx_time, start_time
            )

    def _handle_message(self, server, request, sub_src, sub_dst, rx_time, start_time):
        """Dispatch one received *request* and send any resulting responses.

        Args:
            server: Object providing ``send(message, subaddress, destination)``.
            request: The parsed request message.
            sub_src: Subaddress the request was addressed to (reply source).
            sub_dst: Subaddress the request came from (reply destination).
            rx_time: Receive timestamp in microseconds relative to *start_time*.
            start_time: Reference time in microseconds for ``tx_time``.
        """
        if sub_src == self._BROADCAST_SUBADDRESS:
            devices = list(self._devices_dict.values())
            # Every device must process the request: collect all results before
            # combining them, so no handler is skipped by short-circuiting.
            results = [device.handle_request(request) for device in devices]
            return_response = any(results)
            for device in devices:
                # NOTE(review): broadcast replies are sent with no explicit
                # destination subaddress, as in the original implementation.
                self._finish_and_send(
                    server, request, device, None, return_response, rx_time, start_time
                )
        elif sub_src in self._devices_dict:
            device = self._devices_dict[sub_src]
            return_response = device.handle_request(request)
            self._finish_and_send(
                server, request, device, sub_dst, return_response, rx_time, start_time
            )
        # Requests for subaddresses not served here are ignored.

    def _finish_and_send(self, server, request, device, sub_dst, return_response, rx_time, start_time):
        """Stamp *device*'s response with rx/tx times and send it if warranted."""
        response = device.response_message
        response.rx_time = rx_time
        response.tx_time = self._now_us() - start_time
        if request.HasField("echo") or return_response:
            server.send(response, device.subaddress_id, sub_dst)


def create_server(family, names, port=None, device_handler=default_device_handler):
    """Build a ``VirtualDeviceServer`` hosting one virtual ``Device`` per name.

    Args:
        family: A single family string shared by every device, or a sequence
            with one family per name.
        names: A single device name, or a sequence of names.
        port: Passed through to ``VirtualDeviceServer``.
        device_handler: One request handler used for every device, or a
            list/tuple with one handler per device.

    Returns:
        A ``VirtualDeviceServer`` whose devices get subaddress ids
        ``1..len(names)`` in order.

    Raises:
        ValueError: If ``family`` or a handler list/tuple has a different
            length than ``names``.
    """
    if isinstance(names, str):
        names = [names]
    if isinstance(family, str):
        family = [family] * len(names)

    if len(names) != len(family):
        raise ValueError("family and names of different size")

    # A list or tuple supplies one handler per device; anything else is a
    # single handler shared by all devices.
    if isinstance(device_handler, (list, tuple)):
        handlers = list(device_handler)
        if len(handlers) != len(names):
            raise ValueError("device_handler and names of different size")
    else:
        handlers = [device_handler] * len(names)

    # Subaddress ids start at 1. TODO: Should we start ids at 0?
    devices = [
        Device(fam, name, sub_id, device_handler=handler)
        for sub_id, (fam, name, handler) in enumerate(zip(family, names, handlers), start=1)
    ]

    return VirtualDeviceServer(devices, port)
