#include "ClockSync.h"

#include "AMS_CAN.h"

#include "stm32f412rx.h"
#include "stm32f4xx_hal.h"
#include "stm32f4xx_hal_can.h"

#include <stdint.h>

// Current state of the clock-sync state machine; the module starts out in the
// first frequency-hopping stage.
ClockSyncState clock_sync_state = CLOCK_SYNC_FREQ_HOPPING_STAGE1;

// Bookkeeping written by the CAN frame handlers. Objects with static storage
// duration are zero-initialised by the C standard, so no initialisers needed.
static uint32_t last_clock_sync_frame_time;  // HAL tick of last clock-sync frame
static uint32_t last_master_heartbeat_time;  // HAL tick of last master heartbeat
static uint32_t master_heartbeat_counter;    // heartbeats received (wraps)

// Frequency-hopping search state.
static uint32_t freq_hopping_start_trim;      // HSI trim when a hop stage began
static uint32_t freq_hopping_iteration;       // stage-1 hop index
static uint32_t freq_hopping_stage2_start_time;     // tick the stage-2 window opened
static uint32_t freq_hopping_stage2_start_counter;  // heartbeat count at that tick
static uint32_t freq_hopping_stage2_attempts;       // stage-2 re-trims so far

/*
 * Advance the clock-sync state machine by one step.
 *
 * The handler of the current state decides which state comes next. If that
 * differs from the current state, the entry hook of the new state is run
 * before clock_sync_state is updated. An unknown state falls back to
 * frequency-hopping stage 1.
 */
void clock_sync_update() {
  ClockSyncState next_state;

  if (clock_sync_state == CLOCK_SYNC_NORMAL_OPERATION) {
    next_state = clock_sync_update_normal_operation();
  } else if (clock_sync_state == CLOCK_SYNC_FREQ_HOPPING_STAGE1) {
    next_state = clock_sync_update_freq_hopping_stage1();
  } else if (clock_sync_state == CLOCK_SYNC_FREQ_HOPPING_STAGE2) {
    next_state = clock_sync_update_freq_hopping_stage2();
  } else {
    // Corrupted/unknown state: restart the search from scratch.
    next_state = CLOCK_SYNC_FREQ_HOPPING_STAGE1;
  }

  // Entry hooks only run on an actual transition.
  if (next_state != clock_sync_state) {
    if (next_state == CLOCK_SYNC_NORMAL_OPERATION) {
      clock_sync_start_normal_operation();
    } else if (next_state == CLOCK_SYNC_FREQ_HOPPING_STAGE1) {
      clock_sync_start_freq_hopping_stage1();
    } else if (next_state == CLOCK_SYNC_FREQ_HOPPING_STAGE2) {
      clock_sync_start_freq_hopping_stage2();
    }
  }

  clock_sync_state = next_state;
}

/* Entry hook for normal operation. Nothing needs to be reset here: the
 * normal-operation handler only reads the heartbeat timestamp and the CAN
 * transmit error counter. */
void clock_sync_start_normal_operation() {
  return;
}

/* Entry hook for stage 1: restart the hop sequence from iteration 0, centred
 * on whatever HSI trim is currently programmed. */
void clock_sync_start_freq_hopping_stage1() {
  const uint8_t current_trim = get_hsi_trim();
  freq_hopping_start_trim = current_trim;
  freq_hopping_iteration = 0;
}

/* Entry hook for stage 2. It snapshots three things that the stage-2
 * handler measures against: the current trim, the current tick and the
 * heartbeat count. It also clears the retry counter. */
void clock_sync_start_freq_hopping_stage2() {
  freq_hopping_stage2_attempts = 0;
  freq_hopping_start_trim = get_hsi_trim();
  freq_hopping_stage2_start_time = HAL_GetTick();
  freq_hopping_stage2_start_counter = master_heartbeat_counter;
}

/*
 * Normal-operation handler. Falls back to frequency-hopping stage 1 when
 * either of these holds:
 * - the last master heartbeat is older than MASTER_HEARTBEAT_DESYNC_THRESH;
 * - the TEC field of the CAN error status register exceeds
 *   CLOCK_SYNC_MAX_TRANSMIT_ERRORS.
 */
ClockSyncState clock_sync_update_normal_operation() {
  const uint32_t tick = HAL_GetTick();
  const uint8_t tx_error_count =
      (ams_can_handle->Instance->ESR & CAN_ESR_TEC_Msk) >> CAN_ESR_TEC_Pos;

  const int heartbeat_stale =
      (tick - last_master_heartbeat_time) > MASTER_HEARTBEAT_DESYNC_THRESH;
  const int tx_errors_high = tx_error_count > CLOCK_SYNC_MAX_TRANSMIT_ERRORS;

  if (!heartbeat_stale && !tx_errors_high) {
    return CLOCK_SYNC_NORMAL_OPERATION;
  }
  return CLOCK_SYNC_FREQ_HOPPING_STAGE1;
}

/*
 * Stage-1 handler: coarse frequency hopping.
 *
 * If a clock-sync frame arrived within CLOCK_SYNC_SANITY_INTERVAL_MAX ticks,
 * the trim is considered close and we move on to stage 2. Otherwise, while
 * the master heartbeat is missing, each call programs the next trim of the
 * alternating search produced by calculate_freq_hopping_trim().
 */
ClockSyncState clock_sync_update_freq_hopping_stage1() {
  const uint32_t tick = HAL_GetTick();

  // NOTE(review): last_clock_sync_frame_time is 0 until the first frame, so
  // this also triggers while the tick itself is still below the threshold
  // (presumably right after boot) -- confirm that is intended.
  if (tick - last_clock_sync_frame_time < CLOCK_SYNC_SANITY_INTERVAL_MAX) {
    return CLOCK_SYNC_FREQ_HOPPING_STAGE2;
  }

  const uint32_t heartbeat_age = tick - last_master_heartbeat_time;
  if (heartbeat_age <= MASTER_HEARTBEAT_SANITY_INTERVAL_MAX) {
    // Heartbeats are still arriving: stay put, don't hop.
    return CLOCK_SYNC_FREQ_HOPPING_STAGE1;
  }

  set_hsi_trim(calculate_freq_hopping_trim(freq_hopping_iteration));

  ++freq_hopping_iteration;
  // Wrap the search once the next offset would leave the trim range.
  if ((freq_hopping_iteration + 1) * FREQ_HOPPING_TRIM_STEPS >
      RCC_CR_HSITRIM_MAX) {
    freq_hopping_iteration = 0;
  }
  return CLOCK_SYNC_FREQ_HOPPING_STAGE1;
}

/*
 * Stage-2 handler: fine search around the trim found in stage 1.
 *
 * - Once more than FREQ_HOPPING_STAGE2_FRAMES heartbeats have arrived since
 *   the window opened, sync is declared and we return to normal operation.
 * - If the window (FREQ_HOPPING_STAGE2_FRAMES heartbeat intervals) expires
 *   first, the attempt counter is bumped and the next hop trim is applied.
 *   After FREQ_HOPPING_STAGE2_MAX_ATTEMPTS the handler gives up and returns
 *   to stage 1.
 */
ClockSyncState clock_sync_update_freq_hopping_stage2() {
  const uint32_t heartbeats_seen =
      master_heartbeat_counter - freq_hopping_stage2_start_counter;
  if (heartbeats_seen > FREQ_HOPPING_STAGE2_FRAMES) {
    return CLOCK_SYNC_NORMAL_OPERATION;
  }

  const uint32_t tick = HAL_GetTick();
  const uint32_t window_elapsed = tick - freq_hopping_stage2_start_time;
  if (window_elapsed <=
      FREQ_HOPPING_STAGE2_FRAMES * MASTER_HEARTBEAT_SANITY_INTERVAL_MAX) {
    // Window still open: keep waiting for heartbeats.
    return CLOCK_SYNC_FREQ_HOPPING_STAGE2;
  }

  ++freq_hopping_stage2_attempts;
  if (freq_hopping_stage2_attempts > FREQ_HOPPING_STAGE2_MAX_ATTEMPTS) {
    return CLOCK_SYNC_FREQ_HOPPING_STAGE1;
  }

  // Missed heartbeats in this window: nudge the trim and open a new window.
  set_hsi_trim(calculate_freq_hopping_trim(freq_hopping_stage2_attempts));
  freq_hopping_stage2_start_counter = master_heartbeat_counter;
  freq_hopping_stage2_start_time = tick;

  return CLOCK_SYNC_FREQ_HOPPING_STAGE2;
}

/*
 * Handle a received clock-sync frame carrying an 8-bit sequence counter.
 *
 * In normal operation, for a frame that directly follows the previous one
 * (counter incremented mod 2^8) and whose spacing in HAL ticks lies within
 * [CLOCK_SYNC_SANITY_INTERVAL_MIN, CLOCK_SYNC_SANITY_INTERVAL_MAX]:
 * - estimate our real clock frequency from the measured spacing;
 * - if the previous frame applied a trim, learn the per-step frequency change;
 * - trim the HSI by +/-1 when we are at least one step away from target.
 * Any frame that fails the checks discards a pending trim measurement.
 * The frame's timestamp and counter are always recorded.
 */
void clock_sync_handle_clock_sync_frame(uint8_t counter) {
  // Frequency measured on the frame at which we last trimmed (pre-trim value).
  static uint32_t f_pre_trim = CLOCK_TARGET_FREQ;
  // Non-zero if the previous valid frame applied a trim that this frame can
  // measure the effect of.
  static int32_t trimmed_last_frame = 0;
  // Estimated magnitude of one trim step in Hz; kept non-zero (divisor below).
  static int32_t last_trim_delta = HSI_TRIM_FREQ;
  static uint8_t last_clock_sync_frame_counter = 0;
  // Explicit "seen a frame" flag: a timestamp of 0 is a legitimate tick value
  // (e.g. after HAL tick wrap-around) and must not double as a sentinel.
  static int32_t have_previous_frame = 0;

  uint32_t now = HAL_GetTick();
  uint32_t n_measured = now - last_clock_sync_frame_time;
  uint8_t expected_counter = (uint8_t)(last_clock_sync_frame_counter + 1);

  /* Sanity checks:
   * - Are we actually in normal operation mode?
   * - Have we received a sync frame before?
   * - Did the counter increment by one (mod 2^8)? I.e., did we miss a frame?
   * - Is the measured time elapsed within feasible bounds?
   */
  int32_t frame_is_sane = clock_sync_state == CLOCK_SYNC_NORMAL_OPERATION &&
                          have_previous_frame && counter == expected_counter &&
                          n_measured >= CLOCK_SYNC_SANITY_INTERVAL_MIN &&
                          n_measured <= CLOCK_SYNC_SANITY_INTERVAL_MAX;

  if (frame_is_sane) {
    // Assumes the HAL tick is derived from the HSI being trimmed and that
    // CLOCK_SYNC_INTERVAL is the nominal frame spacing in ticks -- confirm.
    // Multiply before dividing (in 64 bits, so no overflow) to avoid the
    // truncation of CLOCK_TARGET_FREQ / CLOCK_SYNC_INTERVAL.
    uint32_t f_real = (uint32_t)(((uint64_t)n_measured * CLOCK_TARGET_FREQ) /
                                 (CLOCK_SYNC_INTERVAL));

    if (trimmed_last_frame) {
      // Exactly one trim step separates f_pre_trim and f_real: learn its size.
      // Signed subtraction avoids an implementation-defined narrowing of a
      // wrapped unsigned difference.
      int32_t step_effect = (int32_t)f_pre_trim - (int32_t)f_real;
      if (step_effect < 0) {
        step_effect = -step_effect;
      }
      last_trim_delta = (step_effect != 0) ? step_effect : HSI_TRIM_FREQ;
      trimmed_last_frame = 0;
    }

    int32_t delta_f = (int32_t)CLOCK_TARGET_FREQ - (int32_t)f_real;
    // Truncating division: only trim once at least one full step off target.
    int32_t delta_quants = delta_f / last_trim_delta;
    if (delta_quants != 0) {
      // We were able to receive the frame, so we should be reasonably close. It
      // should thus be enough to trim by -1 or 1.
      trim_hsi_by((delta_quants < 0) ? -1 : 1);
      f_pre_trim = f_real;
      trimmed_last_frame = 1;
    }
  } else {
    // The pre-trim frequency is no longer comparable to the next measurement
    // (missed frame, bad spacing, or the state machine hopped the trim while
    // out of normal operation). Drop it so last_trim_delta is not corrupted.
    trimmed_last_frame = 0;
  }

  last_clock_sync_frame_time = now;
  last_clock_sync_frame_counter = counter;
  have_previous_frame = 1;
}

/* Record a master heartbeat: remember when it arrived and count it. */
void clock_sync_handle_master_heartbeat() {
  const uint32_t received_at = HAL_GetTick();
  last_master_heartbeat_time = received_at;
  ++master_heartbeat_counter;
}

/* Read the HSITRIM field of RCC->CR, shifted down to bit 0. */
uint8_t get_hsi_trim() {
  const uint32_t cr = RCC->CR;
  const uint32_t field = cr & RCC_CR_HSITRIM_Msk;
  return field >> RCC_CR_HSITRIM_Pos;
}

/* Program the HSITRIM field of RCC->CR with `trim`. Every other bit of
 * RCC->CR is preserved, and bits of `trim` outside the field are discarded
 * by the mask. */
void set_hsi_trim(uint8_t trim) {
  const uint32_t trim_field =
      ((uint32_t)trim << RCC_CR_HSITRIM_Pos) & RCC_CR_HSITRIM_Msk;
  const uint32_t other_bits = RCC->CR & ~RCC_CR_HSITRIM_Msk;
  RCC->CR = other_bits | trim_field;
}

/*
 * Adjust the HSI trim by `delta` steps, saturating at [0, RCC_CR_HSITRIM_MAX].
 *
 * The arithmetic is done in int64_t against a signed copy of the upper bound
 * for two reasons:
 * - adding an arbitrary int32_t delta cannot overflow (signed overflow is UB);
 * - the comparison stays signed even if RCC_CR_HSITRIM_MAX is an unsigned
 *   macro. With a 32-bit signed `trim`, a negative value would be converted
 *   to a huge unsigned one and clamped to MAX instead of 0.
 */
void trim_hsi_by(int32_t delta) {
  const int64_t max_trim = RCC_CR_HSITRIM_MAX;
  int64_t trim = (int64_t)get_hsi_trim() + delta;

  if (trim > max_trim) {
    trim = max_trim;
  } else if (trim < 0) {
    trim = 0;
  }
  set_hsi_trim((uint8_t)trim);
}

/*
 * Compute the HSI trim to use for hop number `freq_hopping_iteration`.
 *
 * The offset from freq_hopping_start_trim grows by FREQ_HOPPING_TRIM_STEPS
 * per iteration and alternates in sign (even iterations go down, odd go up).
 * The result wraps modulo the trim range, so it is always in
 * [0, RCC_CR_HSITRIM_MAX].
 *
 * Note: the parameter shadows the file-scope freq_hopping_iteration; only
 * the parameter is used here.
 */
uint8_t calculate_freq_hopping_trim(uint32_t freq_hopping_iteration) {
  const int32_t trim_range = RCC_CR_HSITRIM_MAX + 1;

  int32_t trim_delta =
      (int32_t)((freq_hopping_iteration + 1) * FREQ_HOPPING_TRIM_STEPS);
  if (freq_hopping_iteration % 2 == 0) {
    trim_delta = -trim_delta;
  }

  // Signed sum: mixing the uint32_t start trim in directly would make this
  // unsigned arithmetic followed by an implementation-defined narrowing.
  int32_t new_trim = (int32_t)freq_hopping_start_trim + trim_delta;

  // Full modular wrap (not just one period): C's % may yield a negative
  // remainder, so shift it back into range.
  new_trim %= trim_range;
  if (new_trim < 0) {
    new_trim += trim_range;
  }

  // The original fell off the end without returning (UB: callers use it).
  return (uint8_t)new_trim;
}