#include "soc_estimation.h"

#include "shunt_monitoring.h"
#include "slave_monitoring.h"

#include "stm32h7xx_hal.h"

#include <stddef.h>
#include <stdint.h>

#define SOC_ESTIMATION_NO_CURRENT_THRESH 200     // mA
#define SOC_ESTIMATION_NO_CURRENT_TIME 100000    // ms
#define SOC_ESTIMATION_BATTERY_CAPACITY 70300800 // mAs
ocv_soc_pair_t OCV_SOC_PAIRS[] = {
    {25000, 0.00f},  {29900, 3.97f},  {32300, 9.36f},  {33200, 12.60f},
    {33500, 13.68f}, {34100, 20.15f}, {35300, 32.01f}, {38400, 66.53f},
    {40100, 83.79f}, {40200, 90.26f}, {40400, 94.58f}, {41000, 98.89f},
    {42000, 100.00f}};

// Latest state-of-charge estimate in percent, written by soc_update().
float current_soc;

// Bookkeeping used by soc_update() to switch between the voltage lookup and
// coulomb counting.
uint32_t last_current_time; // HAL tick (ms) of the last sample with current
int current_was_flowing;    // non-zero if the previous sample had current
float soc_before_current;   // SoC at the moment coulomb counting was anchored
float mAs_before_current;   // shunt_data.current_counter at that same moment

void soc_init() {
  current_soc = 0;
  last_current_time = 0;
  current_was_flowing = 1;
}

void soc_update() {
  uint32_t now = HAL_GetTick();
  if (shunt_data.current >= SOC_ESTIMATION_NO_CURRENT_THRESH) {
    last_current_time = now;
    if (!current_was_flowing) {
      soc_before_current = current_soc;
      mAs_before_current = shunt_data.current_counter;
    }
    current_was_flowing = 1;
  } else {
    current_was_flowing = 0;
  }

  if (now - last_current_time >= SOC_ESTIMATION_NO_CURRENT_TIME ||
      last_current_time == 0) {
    // Assume we're measuring OCV if there's been no current for a while (or
    // we've just turned on the battery).
    current_soc = soc_for_ocv(min_voltage);
  } else {
    // Otherwise, use the current counter to update SoC
    float as_delta = shunt_data.current_counter - mAs_before_current;
    float soc_delta = as_delta / SOC_ESTIMATION_BATTERY_CAPACITY * 100;
    current_soc = soc_before_current + soc_delta;
  }
}

/**
 * Estimate state of charge from an open-circuit voltage by piecewise-linear
 * interpolation over OCV_SOC_PAIRS.
 *
 * @param ocv  Voltage in the same units as the table's `ocv` column.
 * @return     SoC in percent. Inputs below the first breakpoint return the
 *             first SoC value; inputs at or above the last breakpoint return
 *             the last SoC value.
 */
float soc_for_ocv(uint16_t ocv) {
  const size_t n_pairs = sizeof(OCV_SOC_PAIRS) / sizeof(OCV_SOC_PAIRS[0]);

  // Below the lowest breakpoint: clamp to the first entry.
  if (OCV_SOC_PAIRS[0].ocv > ocv) {
    return OCV_SOC_PAIRS[0].soc;
  }

  // The first breakpoint above `ocv` closes the segment that contains it.
  for (size_t upper = 1; upper < n_pairs; upper++) {
    if (OCV_SOC_PAIRS[upper].ocv > ocv) {
      uint16_t v_lo = OCV_SOC_PAIRS[upper - 1].ocv;
      uint16_t v_hi = OCV_SOC_PAIRS[upper].ocv;
      float s_lo = OCV_SOC_PAIRS[upper - 1].soc;
      float s_hi = OCV_SOC_PAIRS[upper].soc;

      float gradient = (s_hi - s_lo) / (v_hi - v_lo);
      float estimate = s_lo + gradient * (ocv - v_lo);
      return estimate;
    }
  }

  // At or above the highest breakpoint: clamp to the last entry.
  return OCV_SOC_PAIRS[n_pairs - 1].soc;
}