#include "battery.h"
#include "ADBMS_Driver.h"
#include "NTC.h"
#include "can.h"
#include "config_ADBMS6830.h"
#include "ts_state_machine.h"
#include <string.h>
#include <math.h>

#define SWO_LOG_PREFIX "[BATTERY] "
#include "swo_log.h"

#define MAX_ERRORS                4 // max number of errors in window before panic
#define MAX_ERRORS_WINDOW_SIZE   16 // size of the error window for error detection

#define MAX_TEMP 60 // max temperature in C
 
// Pack-wide cell voltage extremes. battery_update() resets them to these
// sentinel values at the start of every successful pass, before scanning cells.
uint16_t min_voltage = 0xFFFF;
uint16_t max_voltage = 0;
// Per-module cell voltage extremes, one entry per BMS. battery_update() does
// not reset these; it only raises .max / lowers .min, so they widen over time.
// NOTE(review): initializer presumably maps to {min, max} - confirm the field
// order against the declaration in battery.h.
typeof(module_voltages) module_voltages = {[0 ... N_BMS - 1] = {0xFFFF, 0}};
// Pack-wide temperature extremes over the converted NTC readings. Reset to
// these sentinels and re-derived by battery_update() on every successful pass.
int16_t min_temp = INT16_MAX;
int16_t max_temp = INT16_MIN;
// Per-module temperature extremes, one entry per BMS. Like module_voltages,
// battery_update() only widens them and does not reset them.
// NOTE(review): initializer presumably maps to {min, max} - confirm in battery.h.
typeof(module_temps) module_temps = {[0 ... N_BMS - 1] = {INT16_MAX, INT16_MIN}};
// Standard deviation of the cell voltages of each module. Only written by
// battery_update() when DEBUG_CHANNEL_ENABLED(DEBUG_CHANNEL) is true.
float module_std_deviation[N_BMS] = {};

// Temperatures returned by ntc_mv_to_celsius() for each module's aux voltages.
// battery_update() fills only the first 10 entries of each row.
int16_t cellTemps[N_BMS][N_CELLS];

// Circular history of the last MAX_ERRORS_WINDOW_SIZE update outcomes
// (true = the update reported an error). Maintained by update_error_window().
static bool error_window[MAX_ERRORS_WINDOW_SIZE] = {};
// Slot in error_window that the next outcome will overwrite.
static size_t error_window_index = 0;
// Running number of `true` entries currently held in error_window.
static size_t error_count = 0;

// Records the outcome of one update in the sliding error window and keeps the
// slave-panic error source in sync with it.
//
// error: true if the update that just ran reported an error.
// id:    value forwarded to can_send_error() when the panic threshold is hit.
//
// The panic condition is evaluated on every call: it is raised while at least
// MAX_ERRORS of the tracked outcomes are errors and cleared otherwise.
static inline void update_error_window(bool error, int id) {
  // Swap the outgoing sample for the incoming one in the running total.
  if (error_window[error_window_index]) {
    error_count--;
  }
  if (error) {
    error_count++;
  }

  const bool panic = error_count >= MAX_ERRORS;
  if (panic) {
    can_send_error(TS_ERRORKIND_SLAVE_PANIC, id);
  }
  ts_sm_set_error_source(TS_ERROR_SOURCE_SLAVES, TS_ERRORKIND_SLAVE_PANIC, panic);

  // Store the new sample and advance the ring position with wrap-around.
  error_window[error_window_index] = error;
  error_window_index = (error_window_index + 1) % MAX_ERRORS_WINDOW_SIZE;
}

// Initializes the BMS driver on the given SPI handle.
//
// Returns HAL_OK when AMS_Init() reports ADBMS_NO_ERROR. Otherwise logs the
// driver status (plus the BMS ID when one is reported, i.e. not -1) and
// returns HAL_ERROR.
HAL_StatusTypeDef battery_init(SPI_HandleTypeDef *hspi) {
    auto init_result = AMS_Init(hspi);

    if (init_result.status == ADBMS_NO_ERROR) {
        debug_log(LOG_LEVEL_INFO, "Battery initialized successfully");
        return HAL_OK;
    }

    debug_log(LOG_LEVEL_ERROR, "Failed to initialize BMS: %s",
              ADBMS_Status_ToString(init_result.status));
    if (init_result.bms_id != -1) {
        debug_log_cont(LOG_LEVEL_ERROR, " (on BMS ID: %hd)", init_result.bms_id);
    }
    return HAL_ERROR;
}

/**
 * Runs one BMS idle-loop cycle and refreshes the derived battery statistics.
 *
 * On a driver error the status is logged (including the affected cells for
 * over-/undervoltage), the error is recorded in the sliding error window and
 * HAL_ERROR is returned without touching the statistics.
 *
 * On success the pack-wide voltage/temperature extremes are reset and
 * recomputed, the per-module extremes are widened, cellTemps is refreshed and
 * the cell-overtemp error source is updated once from the whole scan.
 *
 * @return HAL_OK if the BMS data was read successfully, HAL_ERROR otherwise.
 */
[[gnu::optimize("no-math-errno")]]
HAL_StatusTypeDef battery_update() {
  auto ret = AMS_Idle_Loop();
  if (ret.status != ADBMS_NO_ERROR) {
    debug_log(LOG_LEVEL_ERROR, "Error while updating battery data: %s",
              ADBMS_Status_ToString(ret.status));
    if (ret.bms_id != -1) {
      debug_log_cont(LOG_LEVEL_ERROR, " (on BMS ID: %hd)", ret.bms_id);
    }

    if (ret.status == ADBMS_OVERVOLT || ret.status == ADBMS_UNDERVOLT) {
      if (ret.bms_id != -1 && ret.bms_id < N_BMS) {
        const char* error_type = (ret.status == ADBMS_OVERVOLT) ? "overvoltage" : "undervoltage";
        // One bit per cell; a set bit marks a cell that tripped the limit.
        uint32_t voltage_flags = (ret.status == ADBMS_OVERVOLT) ?
                                modules[ret.bms_id].overVoltage :
                                modules[ret.bms_id].underVoltage;

        debug_log(LOG_LEVEL_ERROR, "Cell %s detected on module %d, affected cells: ",
                  error_type, ret.bms_id);

        for (size_t cell = 0; cell < N_CELLS; cell++) {
          if (voltage_flags & (1UL << cell)) {
            debug_log_cont(LOG_LEVEL_ERROR, "%zu (%d mV) ", cell, modules[ret.bms_id].cellVoltages[cell]);
          }
        }

        if (!voltage_flags) {
          debug_log_cont(LOG_LEVEL_ERROR, "none (something went wrong?)");
        }
      }
    }

    update_error_window(true, ret.bms_id);
    return HAL_ERROR;
  }

  update_error_window(false, ret.bms_id);

  // Pack-wide extremes describe the current pass only, so start from sentinels.
  min_voltage = 0xFFFF;
  max_voltage = 0;
  min_temp = INT16_MAX;
  max_temp = INT16_MIN;

  // Set as soon as any sensor on any module exceeds the limit. The error
  // source is updated once after the scan so that a later healthy sensor
  // cannot clear an overtemp raised by an earlier one.
  bool any_overtemp = false;

  for (size_t i = 0; i < N_BMS; i++) {
    // Standard deviation of the module's cell voltages (debug builds only).
    if (DEBUG_CHANNEL_ENABLED(DEBUG_CHANNEL)) {
      float sum = 0;
      float mean = 0;
      float variance = 0;

      // First pass: calculate mean
      for (size_t j = 0; j < N_CELLS; j++) {
        sum += modules[i].cellVoltages[j];
      }
      mean = sum / N_CELLS;

      // Second pass: calculate variance
      for (size_t j = 0; j < N_CELLS; j++) {
        float diff = modules[i].cellVoltages[j] - mean;
        variance += diff * diff;
      }
      variance /= N_CELLS;

      module_std_deviation[i] = sqrtf(variance);
    }

    // NOTE(review): module_voltages / module_temps are not reset per pass, so
    // they hold the extremes since start-up - confirm this is intended.
    for (size_t j = 0; j < N_CELLS; j++) {
      // Lower values pull the minimum down, higher values push the maximum up.
      if (modules[i].cellVoltages[j] < min_voltage) {
        min_voltage = modules[i].cellVoltages[j];
      }
      if (modules[i].cellVoltages[j] > max_voltage) {
        max_voltage = modules[i].cellVoltages[j];
      }
      if (modules[i].cellVoltages[j] > module_voltages[i].max) {
        module_voltages[i].max = modules[i].cellVoltages[j];
      }
      if (modules[i].cellVoltages[j] < module_voltages[i].min) {
        module_voltages[i].min = modules[i].cellVoltages[j];
      }
    }

    // NOTE(review): cellTemps rows have N_CELLS entries; this assumes
    // N_CELLS >= 10 - confirm against the configuration header.
    for (size_t j = 0; j < 10; j++) { //10 GPIOs
      cellTemps[i][j] = ntc_mv_to_celsius(modules[i].auxVoltages[j]);

      if (cellTemps[i][j] > max_temp) {
        max_temp = cellTemps[i][j];
      }
      if (cellTemps[i][j] < min_temp) {
        min_temp = cellTemps[i][j];
      }

      if (cellTemps[i][j] > module_temps[i].max) {
        module_temps[i].max = cellTemps[i][j];
      }
      if (cellTemps[i][j] < module_temps[i].min) {
        module_temps[i].min = cellTemps[i][j];
      }

      if (cellTemps[i][j] > (MAX_TEMP * (uint16_t)(TEMP_CONV))) {
        debug_log(LOG_LEVEL_ERROR, "Cell %zu on BMS %zu overtemp: %d0 mC", j, i, cellTemps[i][j]);
        can_send_error(TS_ERRORKIND_CELL_OVERTEMP, i);
        any_overtemp = true;
      }
    }
  }

  // Single, aggregated update: raised if any sensor was too hot, cleared only
  // when every sensor on every module is within the limit.
  ts_sm_set_error_source(TS_ERROR_SOURCE_SLAVES, TS_ERRORKIND_CELL_OVERTEMP, any_overtemp);

  return HAL_OK;
}