#!/usr/bin/env python3

import math
import sys

from hebi_tools.util import create_module_connection, new_root_message

class Units:
  """Bundle describing how second-based counters are converted and formatted.

  str_rep:          display name of the unit, e.g. "Seconds" or "Hours".
  abbrev:           short label used in CSV headers, e.g. "s" or "hr".
  convert_from_sec: factor multiplied onto raw second counts.
  prec:             format-spec suffix spliced into '{0' + prec + '}'
                    (e.g. ":.2f"; '' means default str formatting).
  """
  def __init__(self, str_rep, abbrev, convert_from_sec, prec):
    (self.str_rep, self.abbrev, self.convert_from_sec, self.prec) = (
      str_rep, abbrev, convert_from_sec, prec)
  
def printThresholds(threshes, left_padding = 0):
  """Print a row of threshold labels followed by a row of '|' tick marks.

  Each entry is centred in a 12-character column (extra space goes to the
  right). Both rows are indented by left_padding plus half a column, so with the
  same left_padding each tick lands between two adjacent 12-character cells
  printed by printBuckets. Text wider than a column is printed unpadded.
  """
  col = 12
  indent = ' ' * (left_padding + col // 2)

  def centred(text):
    # Clamp to zero so over-wide labels get no padding at all.
    gap = max(col - len(text), 0)
    return ' ' * (gap // 2) + text + ' ' * (gap - gap // 2)

  print(indent + ''.join(centred(str(t)) for t in threshes))
  print(indent + ''.join(centred('|') for _ in threshes))

def printBuckets(units, bucket_vals, left_padding = 0):
  """Print bucket values on one line, each centred in a 12-character column.

  Every value is formatted with '{0' + units.prec + '}'. When the column has an
  odd amount of spare space, the extra space goes to the right. Values wider
  than a column are printed unpadded. The line is prefixed by left_padding
  spaces.
  """
  col = 12
  cells = []
  for value in bucket_vals:
    text = ('{0' + units.prec + '}').format(value)
    gap = max(col - len(text), 0)
    left = gap // 2
    cells.append(' ' * left + text + ' ' * (gap - left))
  print(' ' * left_padding + ''.join(cells))

def printBucketMatrix(units, x_threshes, y_threshes, buckets):
  """Print a 2-D bucket matrix stored row-major in the flat list `buckets`.

  Each row holds len(x_threshes) + 1 buckets. After every row except the last,
  the matching y threshold is printed as '<y> --- '. The final row (the buckets
  above the last y threshold) has no label beneath it. Everything is indented
  by a fixed 10-character margin.
  """
  margin = 10
  row_len = len(x_threshes) + 1
  printThresholds(x_threshes, left_padding = margin)
  for row, y in enumerate(y_threshes):
    first = row * row_len
    printBuckets(units, buckets[first:first + row_len], left_padding = margin)
    print('{0} --- '.format(y))
  printBuckets(units, buckets[len(y_threshes) * row_len:], left_padding = margin)

def exportThresholds(export_file, threshes, left_padding = 0, units = None):
  """Write one CSV row of bucket thresholds, ending with a final 'inf' cell.

  Args:
    export_file: writable text stream.
    threshes: iterable of threshold values; each is written as '<t>, '.
    left_padding: number of extra bare ',' separators (empty cells) inserted
      after the label cell and before the thresholds.
    units: when not None, the first cell is 'Thresholds (<units>)'; otherwise
      the first cell is empty.
  """
  # Compare against the None singleton by identity (PEP 8), not with !=.
  label = 'Thresholds ({0}), '.format(units) if units is not None else ', '
  cells = ''.join('{0}, '.format(t) for t in threshes)
  export_file.write(label + ',' * left_padding + cells + 'inf\n')

def exportBuckets(export_file, units, values, left_padding = 0, write_label = True):
  """Write one CSV row of bucket values, terminated by a newline.

  Args:
    export_file: writable text stream.
    units: object providing `abbrev` (used for the 'Time (<abbrev>)' label) and
      `prec` (format-spec suffix applied to every value).
    values: bucket values; each is written as '<value>, '.
    left_padding: number of bare ',' separators written before the values.
    write_label: when true, the row starts with a 'Time (<abbrev>)' cell.
  """
  if write_label:
    export_file.write('Time (' + units.abbrev + '), ')
  export_file.write(',' * left_padding)
  export_file.write(''.join(('{0' + units.prec + '}, ').format(v) for v in values))
  export_file.write('\n')

def exportBucketMatrix(export_file, units, x_threshes, y_threshes, buckets, x_units, y_units):
  """Write a 2-D bucket matrix (row-major in the flat list `buckets`) as CSV.

  Layout: a header row that repeats x_units once per x threshold, a threshold
  row (two leading empty cells, then the x thresholds and 'inf'), and one row
  per y threshold plus a final 'inf' row. Each of those rows starts with the
  y_units label and the y threshold and is followed by len(x_threshes) + 1
  bucket values.
  """
  row_len = len(x_threshes) + 1
  export_file.write(',,' + (x_units + ', ') * len(x_threshes) + '\n')
  exportThresholds(export_file, x_threshes, left_padding = 1)
  for row, y in enumerate(y_threshes):
    first = row * row_len
    export_file.write('{0}, {1}, '.format(y_units, y))
    exportBuckets(export_file, units, buckets[first:first + row_len], left_padding = 0, write_label = False)
  export_file.write('{0}, inf, '.format(y_units))
  exportBuckets(export_file, units, buckets[len(y_threshes) * row_len:], left_padding = 0, write_label = False)

def _combineBuckets(buckets, extra_buckets, scale):
  """Return `buckets` summed element-wise with `extra_buckets`, times `scale`.

  The result always has len(buckets) entries, so it stays aligned with the
  primary record's thresholds. Entries missing from extra_buckets (e.g. the
  field was never populated there) count as 0, and surplus entries are
  ignored. Pass extra_buckets=None to only scale `buckets`.
  """
  extra = list(extra_buckets) if extra_buckets is not None else []
  return [(value + (extra[i] if i < len(extra) else 0)) * scale
          for i, value in enumerate(buckets)]

def _emitHistogram(export_file, units, histogram, extra_histogram, export_header, threshold_units, print_header):
  """Report a 1-D histogram (`thresholds` + `buckets`) as CSV or on the console.

  export_file is a writable stream for CSV output, or None to print instead.
  extra_histogram (or None) supplies buckets added onto histogram's buckets.
  """
  extra = extra_histogram.buckets if extra_histogram is not None else None
  buckets = _combineBuckets(histogram.buckets, extra, units.convert_from_sec)
  if export_file is not None:
    export_file.write(export_header)
    exportThresholds(export_file, histogram.thresholds, units=threshold_units)
    exportBuckets(export_file, units, buckets)
    export_file.write('\n')
  else:
    print(print_header)
    printThresholds(histogram.thresholds)
    printBuckets(units, buckets)
    print('')

def _emitRuntimeData(export_file, units, rtd, rtd2):
  """Report every populated field of `rtd` to `export_file` (CSV) or stdout if it is None."""
  exporting = export_file is not None

  if rtd.HasField('reboot_count'):
    reboot_count = rtd.reboot_count + (rtd2.reboot_count if rtd2 is not None else 0)
    if exporting:
      export_file.write('Reboot Count, {0}\n\n'.format(reboot_count))
    else:
      print('reboot count: {0}'.format(reboot_count))
      print('')

  if rtd.HasField('startup_position'):
    # Only use the latest record; startup position is never combined with rtd2.
    revs = rtd.startup_position.revolutions
    offset = rtd.startup_position.offset
    if exporting:
      export_file.write('Startup Position\nRevolutions, {0}\nOffset (radians), {1:.6f}\n\n'.format(revs, offset))
    else:
      print('startup position: {0}, {1:.6f}'.format(revs, offset))
      print('')

  # Duration counters: summed with rtd2 (if any), then converted to the display unit.
  for field, label in (('commanded_seconds', 'Commanded'), ('total_seconds', 'Total')):
    if not rtd.HasField(field):
      continue
    raw = getattr(rtd, field) + (getattr(rtd2, field) if rtd2 is not None else 0)
    value = raw * units.convert_from_sec
    if exporting:
      export_file.write('{0} {1}, {2}\n\n'.format(label, units.str_rep, value))
    else:
      print(('{0} {1}: {2' + units.prec + '}').format(label.lower(), units.str_rep.lower(), value))
      print('')

  if rtd.HasField('electrical_power'):
    _emitHistogram(export_file, units, rtd.electrical_power,
                   rtd2.electrical_power if rtd2 is not None else None,
                   'Electrical Power, thresholds represent max value in bucket\n',
                   'W', 'got electrical power:')

  if rtd.HasField('mechanical_power'):
    power = rtd.mechanical_power
    extra = rtd2.mechanical_power.buckets if rtd2 is not None else None
    buckets = _combineBuckets(power.buckets, extra, units.convert_from_sec)
    # NOTE(review): y_thresholds is deliberately passed as the matrix x axis
    # (labelled torque) and x_thresholds as the y axis (labelled speed); confirm
    # this swap against the runtime-data message definition.
    if exporting:
      export_file.write('Mechanical power matrix, Time (' + units.abbrev + ') spent in each bucket\n')
      exportBucketMatrix(export_file, units, power.y_thresholds, power.x_thresholds, buckets, x_units = "max torque (Nm)", y_units = "max speed (rad/s)")
      export_file.write('\n')
    else:
      print('got mechanical power matrix (torque on x axis, speed on y):')
      printBucketMatrix(units, power.y_thresholds, power.x_thresholds, buckets)
      print('')

  if rtd.HasField('winding_temperature'):
    _emitHistogram(export_file, units, rtd.winding_temperature,
                   rtd2.winding_temperature if rtd2 is not None else None,
                   'Winding Temperature, thresholds represent max value in bucket\n',
                   'deg C', 'got winding temperature:')

def processData(export_mode, units: Units, export_filename, rtd, rtd2):
  """Report one runtime-data record, either on the console or as a CSV file.

  Args:
    export_mode: when truthy, write CSV to export_filename instead of printing.
    units: Units used to scale second-based counters/buckets and format them.
    export_filename: CSV path; only used when export_mode is truthy.
    rtd: runtime-data message; only fields for which rtd.HasField(...) is true
      are reported.
    rtd2: optional second record (or None) whose counters and buckets are
      added onto rtd's. startup_position is always taken from rtd alone.
      Missing buckets in rtd2 are treated as 0.
  """
  if not export_mode:
    _emitRuntimeData(None, units, rtd, rtd2)
    return
  print('Writing data to: {0}'.format(export_filename))
  # `with` guarantees the CSV is flushed and closed, even if reporting raises.
  with open(export_filename, "w") as export_file:
    _emitRuntimeData(export_file, units, rtd, rtd2)

# Command-line entry point.
#
# Accepted argument forms (argv[1] is always the module's IP address):
#   get_runtime_data <dest_ip>
#   get_runtime_data <dest_ip> -h
#   get_runtime_data <dest_ip> -x <base filename>
#   get_runtime_data <dest_ip> -h -x <base filename>
#   get_runtime_data <dest_ip> -x <base filename> -h
# -h reports durations in hours instead of seconds; -x writes CSV files
# (<base>_trip.csv and <base>_lifetime.csv) instead of printing.
argc = len(sys.argv)

export_mode = False
export_filename = ''
valid_args = False
hours = False
if argc == 2:
  valid_args = True
elif argc == 3:
  if sys.argv[2] == '-h':
    valid_args = True
    hours = True
elif argc == 4:
  if sys.argv[2] == '-x':
    valid_args = True
    export_mode = True
    export_filename = sys.argv[3]
elif argc == 5:
  if sys.argv[2] == '-h' and sys.argv[3] == '-x':
    valid_args = True
    hours = True
    export_mode = True
    export_filename = sys.argv[4]
  elif sys.argv[2] == '-x' and sys.argv[4] == '-h':
    valid_args = True
    hours = True
    export_mode = True
    export_filename = sys.argv[3]

if not valid_args:
  print('Usage:')
  print('to display on screen:')
  print('  get_runtime_data <dest_ip>')
  print('to output to csv:')
  print('  get_runtime_data <dest_ip> -x <base filename to export to>')
  print('add -h after <dest_ip> to use Hours instead of Seconds for data, e.g.')
  print('  get_runtime_data 10.10.10.100 -h -x my_log')
  # sys.exit, not the site-module exit(): the latter is meant for the
  # interactive shell and is not defined when Python runs with -S.
  sys.exit(1)

ip_addr = sys.argv[1]
connection = create_module_connection(ip_addr)

message_out = new_root_message()
message_in = new_root_message()

# NOTE: This is required to get the settings back in feedback
message_out.request_settings = True
message_out.request_runtime_data = True

# Send the message to the module
print('Sending message...')
if not connection.send(message_out):
  print('ERROR: Failed when sending message. Exiting.')
  sys.exit(1)

# Try to receive a reply; presumably bounded by connection.timeout_ms, which is
# what the error message below reports.
if not connection.recv(message_in):
  print('ERROR: Did not get a response from the module after {0} milliseconds.'.format(connection.timeout_ms))
  sys.exit(1)

if not message_in.settings.HasField('runtime_data'):
  print('ERROR: Returned message did not contain runtime data.')
  sys.exit(1)

print('')
print('Command succeeded as expected. Received runtime data from module:')

rtd = message_in.settings.runtime_data

if hours:
  units = Units("Hours", "hr", 1.0/3600.0, ":.2f")
else:
  units = Units("Seconds", "s", 1, "")

# "Trip" report: runtime_data on its own.
processData(export_mode, units, export_filename + "_trip.csv", rtd, None)

stored_rtd = message_in.settings.stored_runtime_data

# "Lifetime" report: stored_runtime_data with the trip data (rtd) added on top.
processData(export_mode, units, export_filename + "_lifetime.csv", stored_rtd, rtd)
