#include "ts_state_machine.h"

#include "main.h"
#include "shunt_monitoring.h"
#include "stm32h7xx_hal.h"
#include "stm32h7xx_hal_gpio.h"
#include "status_led.h"
#include <stdint.h>

/* State-machine handle: current/target state plus latched error info.
 * Written by the update functions in this file as well as by
 * ts_sm_handle_ams_in() and ts_sm_set_error_source(). */
TSStateHandle ts_state;

/* HAL_GetTick() value recorded on each transition into TS_DISCHARGE;
 * ts_sm_update_discharge() measures DISCHARGE_DURATION from here. */
static uint32_t discharge_begin_timestamp = 0;

/* HAL_GetTick() value recorded when TS_CHARGING_CHECK is entered; bounds the
 * check by MAX_CHARGING_CHECK_DURATION. */
static uint32_t charging_check_timestamp = 0;

/* HAL_GetTick() value at which the precharge voltage threshold was met;
 * 0 is used as "not reached yet". */
static uint32_t precharge_95_reached_timestamp = 0;

/**
 * Reset the tractive-system state machine.
 *
 * Current and target state both become TS_INACTIVE, and all latched error
 * information (sources and kind) is cleared.
 */
void ts_sm_init() {
  ts_state.current_state = TS_INACTIVE;
  ts_state.target_state = TS_INACTIVE;
  ts_state.error_source = 0;
  // Also clear the error kind. It starts at 0 via static zero-initialisation,
  // so a re-init should not leave a stale kind for status_led_state().
  ts_state.error_type = 0;
}

/**
 * Advance the tractive-system state machine by one step.
 *
 * Any latched error source forces TS_ERROR before dispatching. The handler
 * for the current state returns the next state, and then the relay outputs
 * and the status LED are driven from that result.
 */
void ts_sm_update() {
  if (ts_state.error_source) {
    ts_state.current_state = TS_ERROR;
  }

  switch (ts_state.current_state) {
  case TS_INACTIVE:
    ts_state.current_state = ts_sm_update_inactive();
    break;
  case TS_ACTIVE:
    ts_state.current_state = ts_sm_update_active();
    break;
  case TS_PRECHARGE:
    ts_state.current_state = ts_sm_update_precharge();
    break;
  case TS_DISCHARGE:
    ts_state.current_state = ts_sm_update_discharge();
    break;
  case TS_ERROR:
    ts_state.current_state = ts_sm_update_error();
    break;
  case TS_CHARGING_CHECK:
    ts_state.current_state = ts_sm_update_charging_check();
    break;
  case TS_CHARGING:
    ts_state.current_state = ts_sm_update_charging();
    break;
  default:
    // Out-of-range state value (e.g. corrupted memory): fail safe into
    // TS_ERROR so the relays are opened below instead of the machine being
    // stuck in an unhandled state.
    ts_state.current_state = TS_ERROR;
    break;
  }

  ts_sm_set_relay_positions(ts_state.current_state);
  status_led_state(ts_state.current_state, (TSErrorKind) ts_state.error_type);
}

TSState ts_sm_update_inactive() {
  if (ts_state.target_state == TS_ACTIVE) {
    if (sdc_closed) {
      precharge_95_reached_timestamp = 0;
      return TS_PRECHARGE;
    } else {
      return TS_DISCHARGE;
    }
  } else if (ts_state.target_state == TS_CHARGING) {
    if (sdc_closed) {
      charging_check_timestamp = HAL_GetTick();
      return TS_CHARGING_CHECK;
    } else {
      return TS_DISCHARGE;
    }
  }

  return TS_INACTIVE;
}

/**
 * TS_ACTIVE handler.
 *
 * Stays active while the target is not TS_INACTIVE and the shutdown circuit
 * is closed. Otherwise it records the discharge start time and returns
 * TS_DISCHARGE.
 */
TSState ts_sm_update_active() {
  if (ts_state.target_state != TS_INACTIVE && sdc_closed) {
    return TS_ACTIVE;
  }

  discharge_begin_timestamp = HAL_GetTick();
  return TS_DISCHARGE;
}

/**
 * TS_PRECHARGE handler.
 *
 * If the target drops to TS_INACTIVE or the shutdown circuit opens, the
 * handler aborts to TS_DISCHARGE. Otherwise it transitions to TS_ACTIVE once
 * the vehicle-side voltage has stayed continuously above
 * MIN_VEHICLE_SIDE_VOLTAGE and above 95 % of the battery voltage for at least
 * PRECHARGE_95_DURATION ticks.
 *
 * NOTE(review): there is no precharge timeout. A precharge that never meets
 * the threshold stays in this state indefinitely.
 */
TSState ts_sm_update_precharge() {
  if (ts_state.target_state == TS_INACTIVE || !sdc_closed) {
    precharge_95_reached_timestamp = 0;
    discharge_begin_timestamp = HAL_GetTick();
    return TS_DISCHARGE;
  }

  if (shunt_data.voltage_veh > MIN_VEHICLE_SIDE_VOLTAGE &&
      shunt_data.voltage_veh > 0.95 * shunt_data.voltage_bat) {
    uint32_t now = HAL_GetTick();
    if (precharge_95_reached_timestamp == 0) {
      precharge_95_reached_timestamp = now;
    } else if ((now - precharge_95_reached_timestamp) >= PRECHARGE_95_DURATION) {
      precharge_95_reached_timestamp = 0;
      return TS_ACTIVE;
    }
  } else {
    // Threshold lost: restart the hold window. Otherwise a later re-crossing
    // could close AIR+ without the voltage having been held for the full
    // PRECHARGE_95_DURATION.
    precharge_95_reached_timestamp = 0;
  }

  return TS_PRECHARGE;
}

/**
 * TS_DISCHARGE handler.
 *
 * Remains in TS_DISCHARGE until DISCHARGE_DURATION ticks have elapsed since
 * discharge_begin_timestamp, then returns TS_INACTIVE. The unsigned
 * subtraction stays correct across a HAL tick wrap-around.
 */
TSState ts_sm_update_discharge() {
  uint32_t elapsed = HAL_GetTick() - discharge_begin_timestamp;
  return (elapsed >= DISCHARGE_DURATION) ? TS_INACTIVE : TS_DISCHARGE;
}

/**
 * TS_ERROR handler.
 *
 * Drives AMS_NERROR low (GPIO_PIN_RESET) while in the error state. Once no
 * error source has been latched for more than NO_ERROR_TIME consecutive
 * ticks, it releases AMS_NERROR (GPIO_PIN_SET) and returns TS_INACTIVE.
 */
TSState ts_sm_update_error() {
  // Start of the current error-free window; 0 means "no window running".
  static uint32_t no_error_since = 0;

  if (ts_state.error_source != 0) {
    // An error is (again) present: discard any partially elapsed error-free
    // window so that the full NO_ERROR_TIME must pass after it clears.
    no_error_since = 0;
  } else {
    uint32_t now = HAL_GetTick();
    if (no_error_since == 0) {
      no_error_since = now;
    } else if (now - no_error_since > NO_ERROR_TIME) {
      no_error_since = 0;
      HAL_GPIO_WritePin(AMS_NERROR_GPIO_Port, AMS_NERROR_Pin, GPIO_PIN_SET);
      return TS_INACTIVE;
    }
  }

  HAL_GPIO_WritePin(AMS_NERROR_GPIO_Port, AMS_NERROR_Pin, GPIO_PIN_RESET);
  return TS_ERROR;
}

/**
 * TS_CHARGING_CHECK handler.
 *
 * - Aborts to TS_DISCHARGE if the target drops to TS_INACTIVE or the
 *   shutdown circuit opens.
 * - Proceeds to TS_CHARGING once the vehicle-side voltage exceeds the
 *   battery voltage.
 * - Gives up with TS_ERROR after MAX_CHARGING_CHECK_DURATION ticks.
 */
TSState ts_sm_update_charging_check() {
  const int abort_check = ts_state.target_state == TS_INACTIVE || !sdc_closed;
  if (abort_check) {
    discharge_begin_timestamp = HAL_GetTick();
    return TS_DISCHARGE;
  }

  if (shunt_data.voltage_veh > shunt_data.voltage_bat) {
    return TS_CHARGING;
  }

  uint32_t waited = HAL_GetTick() - charging_check_timestamp;
  return (waited > MAX_CHARGING_CHECK_DURATION) ? TS_ERROR : TS_CHARGING_CHECK;
}

/**
 * TS_CHARGING handler.
 *
 * Aborts to TS_DISCHARGE if the target drops to TS_INACTIVE or the shutdown
 * circuit opens. A negative shunt current is reported as TS_ERROR; presumably
 * negative means current flowing out of the battery (TODO confirm the sign
 * convention). Otherwise the machine keeps charging.
 */
TSState ts_sm_update_charging() {
  const int abort_charging = ts_state.target_state == TS_INACTIVE || !sdc_closed;
  if (abort_charging) {
    discharge_begin_timestamp = HAL_GetTick();
    return TS_DISCHARGE;
  }

  return (shunt_data.current < 0) ? TS_ERROR : TS_CHARGING;
}

/**
 * Drive the three relays (AIR-, AIR+, precharge) to the pattern for @p state.
 *
 * - TS_PRECHARGE / TS_CHARGING_CHECK: AIR- and precharge closed, AIR+ open.
 * - TS_ACTIVE / TS_CHARGING: all three relays closed.
 * - TS_INACTIVE / TS_DISCHARGE / TS_ERROR, and any other value: all open.
 */
void ts_sm_set_relay_positions(TSState state) {
  switch (state) {
  case TS_ACTIVE:
  case TS_CHARGING:
    ts_sm_set_relay_position(RELAY_NEG, 1);
    ts_sm_set_relay_position(RELAY_POS, 1);
    ts_sm_set_relay_position(RELAY_PRECHARGE, 1);
    // TODO: Open precharge relay after a while
    break;
  case TS_PRECHARGE:
  case TS_CHARGING_CHECK:
    ts_sm_set_relay_position(RELAY_NEG, 1);
    ts_sm_set_relay_position(RELAY_POS, 0);
    ts_sm_set_relay_position(RELAY_PRECHARGE, 1);
    break;
  case TS_INACTIVE:
  case TS_DISCHARGE:
  case TS_ERROR:
  default:
    // Unknown/out-of-range states fail safe with every relay open instead
    // of leaving the previous relay positions in place.
    ts_sm_set_relay_position(RELAY_NEG, 0);
    ts_sm_set_relay_position(RELAY_POS, 0);
    ts_sm_set_relay_position(RELAY_PRECHARGE, 0);
    break;
  }
}

/**
 * Command a single relay open (closed == 0) or closed (closed != 0).
 *
 * Before writing the control pin, the last commanded position of that relay
 * is passed to ts_sm_check_close_wait(). That call may block so that
 * successive closings are spaced out. Unknown relay values are ignored.
 */
void ts_sm_set_relay_position(Relay relay, int closed) {
  // Last commanded position per relay (0 = open), updated by
  // ts_sm_check_close_wait().
  static int neg_closed = 0;
  static int pos_closed = 0;
  static int precharge_closed = 0;

  const GPIO_PinState level = closed ? GPIO_PIN_SET : GPIO_PIN_RESET;

  if (relay == RELAY_NEG) {
    ts_sm_check_close_wait(&neg_closed, closed);
    HAL_GPIO_WritePin(NEG_AIR_CTRL_GPIO_Port, NEG_AIR_CTRL_Pin, level);
  } else if (relay == RELAY_POS) {
    ts_sm_check_close_wait(&pos_closed, closed);
    HAL_GPIO_WritePin(POS_AIR_CTRL_GPIO_Port, POS_AIR_CTRL_Pin, level);
  } else if (relay == RELAY_PRECHARGE) {
    ts_sm_check_close_wait(&precharge_closed, closed);
    HAL_GPIO_WritePin(PRECHARGE_CTRL_GPIO_Port, PRECHARGE_CTRL_Pin, level);
  }
}

/**
 * Record a relay's new commanded position and, on an open->closed edge,
 * block until at least RELAY_CLOSE_WAIT ticks have passed since the previous
 * closing.
 *
 * The last-close timestamp is shared by all callers, so closings of
 * different relays are spaced apart as well. Opening never waits.
 *
 * @param is_closed    last commanded position; updated to @p should_close
 * @param should_close requested position (non-zero = closed)
 */
void ts_sm_check_close_wait(int *is_closed, int should_close) {
  static uint32_t last_close_timestamp = 0;

  if (should_close == *is_closed) {
    return; // commanded position unchanged
  }
  *is_closed = should_close;

  if (!should_close) {
    return; // opening is never delayed
  }

  const uint32_t since_last_close = HAL_GetTick() - last_close_timestamp;
  if (since_last_close < RELAY_CLOSE_WAIT) {
    HAL_Delay(RELAY_CLOSE_WAIT - since_last_close);
  }
  last_close_timestamp = HAL_GetTick();
}

/**
 * Update the target state from an incoming AMS_IN payload.
 *
 * Bit 0 of data[0] selects TS_ACTIVE (set) or TS_INACTIVE (clear). All other
 * bits are ignored. @p data must point to at least one byte; presumably it is
 * a received CAN frame payload (TODO confirm with the caller).
 */
void ts_sm_handle_ams_in(const uint8_t *data) {
  const int activate_requested = (data[0] & 0x01) != 0;
  ts_state.target_state = activate_requested ? TS_ACTIVE : TS_INACTIVE;
}

/**
 * Latch or clear one error source.
 *
 * @param source      bit mask for this source within ts_state.error_source
 * @param error_type  kind recorded for the status LED when latching
 * @param is_errored  true to latch @p source, false to clear it
 *
 * While any source remains latched, ts_state.error_type holds the kind of
 * the most recently latched error. Once the last source clears, it is reset
 * to 0, its zero-initialised start value.
 */
void ts_sm_set_error_source(TSErrorSource source, TSErrorKind error_type, bool is_errored) {
  if (is_errored) {
    ts_state.error_source |= source;
    ts_state.error_type = error_type;
  } else {
    ts_state.error_source &= ~source;
    // Clearing one source must not overwrite the kind reported for another
    // source that is still latched. Storing ~error_type here also produced an
    // out-of-range TSErrorKind for status_led_state().
    if (ts_state.error_source == 0) {
      ts_state.error_type = 0;
    }
  }
}