from ctypes import c_ushort as _uint16_t

# Source subaddress: field number, and the tag byte that precedes its varint
# on the wire (the field number shifted left by three bits).
# NOTE(review): this looks like a protobuf-style tag with wire type 0
# (varint) in the low three bits -- confirm against the protocol definition.
SourceSubaddressField = 2
SourceSubaddressFieldEnc = SourceSubaddressField << 3

# Destination subaddress: same layout as the source subaddress above.
DestSubaddressField = 4
DestSubaddressFieldEnc = DestSubaddressField << 3


def _ntohs(buffer):
  """
  Read an unsigned 16 bit integer, most significant byte first, from the
  first two elements of the buffer.

  :param buffer: indexable sequence holding at least two int-convertible items
  :returns: the combined value, truncated to 16 bits
  """
  high, low = int(buffer[0]), int(buffer[1])
  combined = (high << 8) | low
  # Going through the ctypes unsigned short keeps only the low 16 bits.
  return _uint16_t(combined).value


def _write_uint16(buffer, val):
  """
  Store `val` as an unsigned, big-endian 16 bit integer in the first two
  bytes of the buffer.

  :param buffer: mutable byte buffer supporting slice assignment
  :param val:    integer in the range [0, 65535]; `int.to_bytes` raises
                 `OverflowError` for anything outside of it
  """
  encoded = val.to_bytes(2, 'big')
  buffer[:2] = encoded


def _compute_varint_size(value):
  """
  Compute the number of bytes needed to encode the value as a varint.

  :param value: the number to measure, or `None` for "no value"
  :returns: 0 for `None`, otherwise a byte count between 1 and 10
  """
  if value is None:
    return 0

  # Largest value that fits in 1, 2, ..., 9 bytes; every byte carries
  # seven payload bits.
  limits = (
    0x7f,
    0x3fff,
    0x1fffff,
    0xfffffff,
    0x7ffffffff,
    0x3ffffffffff,
    0x1ffffffffffff,
    0xffffffffffffff,
    0x7fffffffffffffff,
  )
  for size, limit in enumerate(limits, 1):
    if value <= limit:
      return size
  return 10


def _encode_varint(buffer, buffer_offset, value):
  """
  Encodes the value into the buffer, at the specified offset.

  The value is written seven bits at a time, least significant group first.
  Every byte except the last one has its high bit (0x80) set.

  :param buffer:        mutable byte buffer supporting item assignment
  :param buffer_offset: index of the first byte to write
  :param value:         non-negative integer to encode

  :returns: The offset into the buffer immediately after the varint encoded
  :raises ValueError: if the value is negative
  """
  if value < 0:
    # Right shifting a negative integer never reaches zero, so without this
    # check the loop below would overwrite the rest of the buffer and only
    # stop with an IndexError once it runs off the end.
    raise ValueError(
      'Cannot encode a negative value as a varint: {0}'.format(value))

  while value > 0x7f:
    # More groups follow: emit the low seven bits with the continuation bit.
    buffer[buffer_offset] = 0x80 | (value & 0x7f)
    value >>= 7
    buffer_offset += 1

  # Final group; the continuation bit stays clear.
  buffer[buffer_offset] = value & 0x7f
  return buffer_offset + 1


def _decode_varint(buffer, pos):
  """
  Decodes a varint from the buffer, starting at the specified position.

  :param buffer: indexable sequence of integer byte values
  :param pos:    index of the first byte of the varint

  :returns: (value, pos) where `pos` is the index immediately after the
            last byte consumed
  :raises ValueError: if the tenth byte still has its continuation bit set
  """
  value = 0
  # Ten groups of seven bits at most: shifts 0, 7, ..., 63.
  for shift in range(0, 64, 7):
    byte = buffer[pos]
    pos += 1
    value |= (byte & 0x7f) << shift
    if byte & 0x80 == 0:
      return (value, pos)
  raise ValueError('Too many bytes when decoding varint.')


def _read_subaddress_fields(buffer, buffer_size, buffer_offset):
  """
  Read up to two leading subaddress fields from a received message.

  A field is a tag byte (`SourceSubaddressFieldEnc` or
  `DestSubaddressFieldEnc`) followed by a varint. Scanning stops at the first
  byte which is not one of those tags, after two fields, or once the valid
  data has been used up. If the same tag shows up twice, the later value wins.

  :param buffer:        indexable byte buffer holding the message
  :param buffer_size:   number of valid bytes, counted from `buffer_offset`
  :param buffer_offset: index of the first byte to examine

  :return: (src_subaddress, dst_subaddress, buffer_offset), where a
           subaddress is `None` when its field is absent and `buffer_offset`
           is the index immediately after the last field consumed
  :raises ValueError: if a subaddress varint has too many bytes
  """
  src_subaddress = None
  dst_subaddress = None

  # The buffer may be larger than the message and reused between messages;
  # bytes at or past this index do not belong to it and must never be
  # mistaken for a tag.
  end_offset = buffer_offset + buffer_size

  for _ in range(2):
    if buffer_offset >= end_offset:
      break
    field_id = buffer[buffer_offset]
    if field_id == SourceSubaddressFieldEnc:
      src_subaddress, buffer_offset = _decode_varint(buffer, buffer_offset + 1)
    elif field_id == DestSubaddressFieldEnc:
      dst_subaddress, buffer_offset = _decode_varint(buffer, buffer_offset + 1)
    else:
      # Not a subaddress tag: the rest of the message starts here.
      break

  return src_subaddress, dst_subaddress, buffer_offset


class ModuleConnection:
  """
  Exchanges length-prefixed messages over an underlying connection object.

  A frame consists of a two byte big-endian length, an optional source
  subaddress field, an optional destination subaddress field and finally the
  serialized message. The length covers everything after the two byte prefix.

  The wrapped connection has to provide `recv(buffer)`, `send(buffer, length)`
  and a `timeout_seconds` attribute. Messages have to provide `ByteSize()`,
  `SerializeToString()` and `ParseFromString(data)`.
  NOTE(review): presumably these are protobuf messages -- confirm.
  """

  def __init__(self, connection):
    """
    :param connection: the connection used to move the raw bytes
    """
    self._connection = connection
    # Scratch buffers, reused for every frame received / sent.
    self._out_buffer = bytearray(2048)
    self._in_buffer = bytearray(2048)

  def __repr__(self):
    return 'ModuleConnection(connection: {0})'.format(repr(self._connection))

  def __recv(self, out_msg):
    """
    Receive one frame and parse its payload into `out_msg`.

    :returns: (src_subaddress, dst_subaddress, out_msg) on success,
              (`None`, `None`, `None`) if fewer bytes than the length prefix
              were received
    :raises RuntimeWarning: if the length prefix disagrees with the number
                            of bytes received
    """
    nothing = (None, None, None)
    prefix_size = 2
    data = self._in_buffer

    received = self._connection.recv(data)
    if received < prefix_size:
      return nothing

    declared = _ntohs(data)
    available = received - 2
    # Note: for now, the declared length should always match what arrived
    if declared != available:
      raise RuntimeWarning('Warning: received a message payload of size {0}; {1} was expected. Data is likely corrupt.'.format(declared, available))

    # Cannot trigger while the equality check above is in place.
    if received < (prefix_size + declared):
      return nothing

    src, dst, payload_start = _read_subaddress_fields(data, declared, prefix_size)
    out_msg.ParseFromString(data[payload_start:received])
    return src, dst, out_msg

  @property
  def timeout_seconds(self):
    """
    Timeout of the underlying connection, in seconds.
    """
    return self._connection.timeout_seconds

  @timeout_seconds.setter
  def timeout_seconds(self, val):
    self._connection.timeout_seconds = val

  @property
  def timeout_ms(self):
    """
    Timeout of the underlying connection, in milliseconds, truncated to an
    integer.
    """
    seconds = self._connection.timeout_seconds
    return int(seconds * 1000)

  def recv(self, out_msg):
    """
    Receive a message, discarding any subaddress information.

    :returns: out_msg on success, `None` otherwise
    """
    return self.__recv(out_msg)[2]

  def recv_with_subaddress(self, out_msg):
    """
    Receive a message along with its subaddresses.

    :returns: (subaddress_src, subaddress_dst, out_msg) on success, (`None`, `None`, `None`) otherwise
    """
    return self.__recv(out_msg)

  def send(self, out_msg, src_subaddr=None, dst_subaddr=None):
    """
    Frame the message and hand it to the underlying connection.

    :param out_msg:     the message to send
    :param src_subaddr: optional source subaddress
    :param dst_subaddr: optional destination subaddress

    :returns: `False` if the message serializes to zero bytes or the frame
              does not fit into the output buffer; otherwise whatever the
              underlying connection's `send` returns
    """
    payload_size = out_msg.ByteSize()
    if payload_size == 0:
      return False

    frame = self._out_buffer
    prefix_size = 2

    src_size = _compute_varint_size(src_subaddr)
    dst_size = _compute_varint_size(dst_subaddr)
    # A subaddress which is present costs one more byte for its tag.
    src_size += 1 if src_size > 0 else 0
    dst_size += 1 if dst_size > 0 else 0

    frame_size = payload_size + prefix_size + src_size + dst_size
    if frame_size > len(frame):
      return False

    _write_uint16(frame, payload_size + src_size + dst_size)

    cursor = prefix_size
    fields = (
      (SourceSubaddressFieldEnc, src_subaddr),
      (DestSubaddressFieldEnc, dst_subaddr),
    )
    for tag, subaddr in fields:
      if subaddr is not None:
        cursor = _encode_varint(frame, cursor, tag)
        cursor = _encode_varint(frame, cursor, subaddr)

    frame[cursor:frame_size] = out_msg.SerializeToString()
    return self._connection.send(frame, frame_size)
