#include "can-halal.h"

#include <string.h>

#if defined(FTCAN_IS_BXCAN)
// Peripheral handle registered by ftcan_init(); the RX callback ignores
// events coming from any other handle.
static CAN_HandleTypeDef *hcan = NULL;

/*
 * Register `handle` as the CAN peripheral used by this module, enable the
 * RX FIFO0 "message pending" interrupt and start the peripheral.
 * Returns the first non-HAL_OK status encountered, or the result of
 * HAL_CAN_Start().
 */
HAL_StatusTypeDef ftcan_init(CAN_HandleTypeDef *handle) {
  hcan = handle;

  HAL_StatusTypeDef result =
      HAL_CAN_ActivateNotification(handle, CAN_IT_RX_FIFO0_MSG_PENDING);
  return (result == HAL_OK) ? HAL_CAN_Start(handle) : result;
}

/*
 * Queue a classic CAN data frame with an 11-bit standard identifier.
 *
 * id      - standard identifier
 * data    - payload bytes
 * datalen - payload length; values above 8 are truncated to 8, the maximum
 *           a classic CAN frame can carry
 *
 * Returns the status of HAL_CAN_AddTxMessage().
 */
HAL_StatusTypeDef ftcan_transmit(uint16_t id, const uint8_t *data,
                                 size_t datalen) {
  // DLC must describe at most 8 bytes; never hand the HAL an out-of-range
  // DLC value.
  if (datalen > 8) {
    datalen = 8;
  }

  // Automatic, zero-initialised header: no header state is shared between
  // calls, and fields not set below (ExtId, TransmitGlobalTime) are zero.
  CAN_TxHeaderTypeDef header = {0};
  header.StdId = id;
  header.IDE = CAN_ID_STD;
  header.RTR = CAN_RTR_DATA;
  header.DLC = (uint32_t)datalen;

  // The HAL reports which mailbox was used; this module does not need it.
  uint32_t mailbox;
  return HAL_CAN_AddTxMessage(hcan, &header, data, &mailbox);
}

/*
 * Add a standard-ID acceptance filter (frames match when
 * (frame_id & mask) == (id & mask)) that routes matches to RX FIFO0.
 *
 * Filters are packed two per bank in 16-bit scale:
 *   - filter 2k programs both halves of bank k;
 *   - filter 2k+1 then replaces the "low" half of that same bank.
 * `filter` is static so that the high half programmed by the even call
 * survives into the following odd call.
 *
 * Returns HAL_ERROR when no filter bank is left, otherwise the status of
 * HAL_CAN_ConfigFilter(). The filter counter only advances on HAL_OK.
 */
HAL_StatusTypeDef ftcan_add_filter(uint16_t id, uint16_t mask) {
  static uint32_t next_filter_no = 0;
  static CAN_FilterTypeDef filter;

  // SlaveStartFilterBank is set to FTCAN_NUM_FILTERS below, so this
  // instance owns banks 0 .. FTCAN_NUM_FILTERS - 1. Check before touching
  // `filter` so a rejected call leaves the retained configuration intact.
  uint32_t bank = next_filter_no / 2;
  if (bank >= FTCAN_NUM_FILTERS) {
    return HAL_ERROR;
  }

  // The 11-bit standard ID occupies bits 15:5 of a 16-bit filter word; the
  // five low bits of the mask stay 0, so they are not compared.
  uint32_t id_field = (uint32_t)id << 5;
  uint32_t mask_field = (uint32_t)mask << 5;

  if (next_filter_no % 2 == 0) {
    // Fresh bank: program both halves with this filter so the half that is
    // not yet in use cannot accept additional frames.
    filter.FilterIdHigh = id_field;
    filter.FilterMaskIdHigh = mask_field;
  }
  // For odd filters the high half is left as configured by the previous call.
  filter.FilterIdLow = id_field;
  filter.FilterMaskIdLow = mask_field;

  filter.FilterFIFOAssignment = CAN_FILTER_FIFO0;
  filter.FilterBank = bank;
  filter.FilterMode = CAN_FILTERMODE_IDMASK;
  filter.FilterScale = CAN_FILTERSCALE_16BIT;
  filter.FilterActivation = CAN_FILTER_ENABLE;

  // Disable slave filters: every bank below FTCAN_NUM_FILTERS belongs to
  // this instance.
  // TODO: Some STM32 have multiple CAN peripherals, and one uses the slave
  // filter bank
  filter.SlaveStartFilterBank = FTCAN_NUM_FILTERS;

  HAL_StatusTypeDef status = HAL_CAN_ConfigFilter(hcan, &filter);
  if (status == HAL_OK) {
    next_filter_no++;
  }
  return status;
}

/*
 * HAL hook invoked when RX FIFO0 has a pending message. Reads one frame and
 * forwards standard-ID data frames to ftcan_msg_received_cb(). Events from
 * other CAN handles, read failures, extended-ID frames and remote frames are
 * ignored.
 */
void HAL_CAN_RxFifo0MsgPendingCallback(CAN_HandleTypeDef *handle) {
  if (handle != hcan) {
    return;
  }
  CAN_RxHeaderTypeDef header;
  uint8_t data[8];
  if (HAL_CAN_GetRxMessage(hcan, CAN_RX_FIFO0, &header, data) != HAL_OK) {
    return;
  }

  // Remote frames carry no payload, so there is nothing meaningful to hand
  // to the application (the FDCAN path rejects them as well).
  if (header.IDE != CAN_ID_STD || header.RTR != CAN_RTR_DATA) {
    return;
  }

  // The 4-bit DLC field can hold 9..15, which on classic CAN still means
  // 8 data bytes. Clamp it so the callback never reads past data[].
  size_t datalen = (header.DLC > 8) ? 8 : (size_t)header.DLC;

  ftcan_msg_received_cb((uint16_t)header.StdId, datalen, data);
}
#elif defined(FTCAN_IS_FDCAN)
// Peripheral handle registered by ftcan_init(); the RX callback ignores
// events coming from any other handle.
static FDCAN_HandleTypeDef *hcan = NULL;

/*
 * Register `handle` as the FDCAN peripheral used by this module, enable the
 * RX FIFO0 new-message interrupt, configure the global filter and start the
 * peripheral. Returns the first non-HAL_OK status encountered, or the result
 * of HAL_FDCAN_Start().
 */
HAL_StatusTypeDef ftcan_init(FDCAN_HandleTypeDef *handle) {
  hcan = handle;

  HAL_StatusTypeDef rc =
      HAL_FDCAN_ActivateNotification(hcan, FDCAN_IT_RX_FIFO0_NEW_MESSAGE, 0);
  if (rc == HAL_OK) {
    // Reject standard and extended frames that match no configured filter,
    // and reject all remote frames.
    rc = HAL_FDCAN_ConfigGlobalFilter(hcan, FDCAN_REJECT, FDCAN_REJECT,
                                      FDCAN_REJECT_REMOTE,
                                      FDCAN_REJECT_REMOTE);
  }
  if (rc == HAL_OK) {
    rc = HAL_FDCAN_Start(hcan);
  }
  return rc;
}

/*
 * Queue a classic CAN data frame (no bit-rate switch, no TX event) with an
 * 11-bit standard identifier.
 *
 * id      - standard identifier
 * data    - payload bytes
 * datalen - payload length; values above 8 are truncated to 8
 *
 * Returns the status of HAL_FDCAN_AddMessageToTxFifoQ().
 */
HAL_StatusTypeDef ftcan_transmit(uint16_t id, const uint8_t *data,
                                 size_t datalen) {
  // HAL DLC code for each classic-CAN payload length 0..8.
  static const uint32_t dlc_codes[] = {
      FDCAN_DLC_BYTES_0, FDCAN_DLC_BYTES_1, FDCAN_DLC_BYTES_2,
      FDCAN_DLC_BYTES_3, FDCAN_DLC_BYTES_4, FDCAN_DLC_BYTES_5,
      FDCAN_DLC_BYTES_6, FDCAN_DLC_BYTES_7, FDCAN_DLC_BYTES_8,
  };

  // A classic CAN frame carries at most 8 bytes.
  if (datalen > 8) {
    datalen = 8;
  }

  // Automatic, zero-initialised header: no header state is shared between
  // calls; fields not set below are zero.
  FDCAN_TxHeaderTypeDef header = {0};
  header.Identifier = id;
  header.IdType = FDCAN_STANDARD_ID;
  header.TxFrameType = FDCAN_DATA_FRAME;
  header.DataLength = dlc_codes[datalen];
  header.ErrorStateIndicator = FDCAN_ESI_PASSIVE;
  header.BitRateSwitch = FDCAN_BRS_OFF;
  header.FDFormat = FDCAN_CLASSIC_CAN;
  header.TxEventFifoControl = FDCAN_NO_TX_EVENTS;

  // HAL_FDCAN_AddMessageToTxFifoQ doesn't modify the data, but its parameter
  // is not const-qualified.
  uint8_t *data_nonconst = (uint8_t *)data;
  return HAL_FDCAN_AddMessageToTxFifoQ(hcan, &header, data_nonconst);
}

/*
 * Add a standard-ID classic mask filter (FilterID1 = id, FilterID2 = mask)
 * that routes matching frames to RX FIFO0. Filters occupy consecutive
 * indices starting at 0.
 *
 * Returns HAL_ERROR when no filter index is left, otherwise the status of
 * HAL_FDCAN_ConfigFilter(). The index only advances on HAL_OK.
 */
HAL_StatusTypeDef ftcan_add_filter(uint16_t id, uint16_t mask) {
  static uint32_t next_filter_no = 0;

  // Valid indices are 0 .. FTCAN_NUM_FILTERS - 1. NOTE(review): assumes
  // FTCAN_NUM_FILTERS does not exceed the StdFiltersNbr the peripheral was
  // initialised with -- confirm.
  if (next_filter_no >= FTCAN_NUM_FILTERS) {
    return HAL_ERROR;
  }

  FDCAN_FilterTypeDef filter = {0};
  filter.IdType = FDCAN_STANDARD_ID;
  filter.FilterIndex = next_filter_no;
  filter.FilterType = FDCAN_FILTER_MASK;
  filter.FilterConfig = FDCAN_FILTER_TO_RXFIFO0;
  filter.FilterID1 = id;
  filter.FilterID2 = mask;

  HAL_StatusTypeDef status = HAL_FDCAN_ConfigFilter(hcan, &filter);
  if (status == HAL_OK) {
    next_filter_no++;
  }
  return status;
}

void HAL_FDCAN_RxFifo0Callback(FDCAN_HandleTypeDef *handle,
                               uint32_t RxFifo0ITs) {
  if (handle != hcan || (RxFifo0ITs & FDCAN_IT_RX_FIFO0_NEW_MESSAGE) == RESET) {
    return;
  }

  static FDCAN_RxHeaderTypeDef header;
  static uint8_t data[8];
  if (HAL_FDCAN_GetRxMessage(hcan, FDCAN_RX_FIFO0, &header, data) != HAL_OK) {
    return;
  }

  if (header.FDFormat != FDCAN_CLASSIC_CAN ||
      header.RxFrameType != FDCAN_DATA_FRAME ||
      header.IdType != FDCAN_STANDARD_ID) {
    return;
  }

  size_t datalen;
  switch (header.DataLength) {
  case FDCAN_DLC_BYTES_0:
    datalen = 0;
    break;
  case FDCAN_DLC_BYTES_1:
    datalen = 1;
    break;
  case FDCAN_DLC_BYTES_2:
    datalen = 2;
    break;
  case FDCAN_DLC_BYTES_3:
    datalen = 3;
    break;
  case FDCAN_DLC_BYTES_4:
    datalen = 4;
    break;
  case FDCAN_DLC_BYTES_5:
    datalen = 5;
    break;
  case FDCAN_DLC_BYTES_6:
    datalen = 6;
    break;
  case FDCAN_DLC_BYTES_7:
    datalen = 7;
    break;
  case FDCAN_DLC_BYTES_8:
    datalen = 8;
    break;
  default:
    return;
  }

  ftcan_msg_received_cb(header.Identifier, datalen, data);
}
#endif

/*
 * Default receive hook, called for each accepted frame with its identifier,
 * payload length and payload. It does nothing; being __weak, it can be
 * overridden by a strong definition elsewhere.
 */
__weak void ftcan_msg_received_cb(uint16_t id, size_t datalen,
                                  const uint8_t *data) {
  (void)id;
  (void)datalen;
  (void)data;
}

/*
 * Decode a big-endian unsigned integer of `num_bytes` bytes (at most 8; larger
 * requests read 8) from *data_ptr, then advance *data_ptr past the bytes read.
 */
uint64_t ftcan_unmarshal_unsigned(const uint8_t **data_ptr, size_t num_bytes) {
  size_t count = (num_bytes < 8) ? num_bytes : 8;
  const uint8_t *cursor = *data_ptr;
  const uint8_t *const end = cursor + count;

  uint64_t value = 0;
  while (cursor != end) {
    value = (value << 8) | *cursor++;
  }

  *data_ptr = end;
  return value;
}

/*
 * Decode a big-endian two's-complement integer of `num_bytes` bytes (at most
 * 8; larger requests read 8) from *data_ptr, sign-extend it to 64 bits and
 * advance *data_ptr past the bytes read. A zero-byte field decodes as 0.
 */
int64_t ftcan_unmarshal_signed(const uint8_t **data_ptr, size_t num_bytes) {
  if (num_bytes > 8) {
    num_bytes = 8;
  }

  uint64_t raw = ftcan_unmarshal_unsigned(data_ptr, num_bytes);

  // A zero-width field has no sign bit; computing a 64-bit shift for it
  // would be undefined behaviour.
  if (num_bytes == 0) {
    return 0;
  }

  // Sign-extend with unsigned operations only: set all bits above the field
  // when the field's top bit is 1. (Right-shifting a negative signed value
  // is implementation-defined, so avoid it.)
  size_t bits = num_bytes * 8;
  if (bits < 64 && ((raw >> (bits - 1)) & 1u) != 0) {
    raw |= UINT64_MAX << bits;
  }

  // int64_t is two's complement, so copying the bit pattern yields the
  // intended value without an implementation-defined conversion.
  int64_t result;
  memcpy(&result, &raw, sizeof result);
  return result;
}

/*
 * Encode the low `num_bytes` bytes of `val` (at most 8; larger requests write
 * 8) into `data` in big-endian order. Returns a pointer just past the bytes
 * written.
 */
uint8_t *ftcan_marshal_unsigned(uint8_t *data, uint64_t val, size_t num_bytes) {
  size_t width = (num_bytes > 8) ? 8 : num_bytes;

  // Fill from the last position backwards so the least significant byte
  // lands at the end.
  for (size_t pos = width; pos > 0; pos--) {
    data[pos - 1] = (uint8_t)(val & 0xFFu);
    val >>= 8;
  }

  return data + width;
}

/*
 * Encode the low `num_bytes` bytes of the two's-complement representation of
 * `val` into `data` in big-endian order (at most 8 bytes). Returns a pointer
 * just past the bytes written.
 */
uint8_t *ftcan_marshal_signed(uint8_t *data, int64_t val, size_t num_bytes) {
  // Conversion to uint64_t is modulo 2^64, i.e. exactly the two's-complement
  // bit pattern; the unsigned encoder then truncates it to num_bytes.
  const uint64_t bit_pattern = (uint64_t)val;
  return ftcan_marshal_unsigned(data, bit_pattern, num_bytes);
}