#include "spi_flash.h"

/**
 * Initialise the SPI peripheral used by this flash driver.
 *
 * Delegates to MX_SPI2_Init(); presumably the CubeMX-generated initialiser
 * for hspi2, the handle spi_flash_send_byte() transfers on (confirm).
 * Declared with (void) so the definition is a real prototype.
 */
void spi_flash_init(void)
{
    MX_SPI2_Init();
}

/**
 * Release the flash chip-select by driving the CS pin high.
 *
 * Every transaction in this file lowers CS to start a command and raises it
 * to end one, i.e. CS is treated as active-low.
 */
static inline void spi_flash_cs_high(void)
{
    HAL_GPIO_WritePin(spi_flash_cs_GPIO_Port, spi_flash_cs_Pin, GPIO_PIN_SET);
}

/**
 * Select the flash by driving the CS pin low; this starts a command frame
 * that lasts until spi_flash_cs_high() is called.
 */
static inline void spi_flash_cs_low(void)
{
    HAL_GPIO_WritePin(spi_flash_cs_GPIO_Port, spi_flash_cs_Pin, GPIO_PIN_RESET);
}

/**
 * Exchange one byte full-duplex on hspi2.
 *
 * @param data_tx  byte clocked out on MOSI
 * @return         byte clocked in on MISO during the same transfer, or 0 if
 *                 the HAL transfer did not fill it in
 */
static uint8_t spi_flash_send_byte(uint8_t data_tx)
{
    /* Initialised so that a failed transfer (which may leave the receive
     * buffer untouched) returns a defined value rather than stack garbage. */
    uint8_t data_rx = 0;

    /* NOTE(review): the HAL status is deliberately discarded because the
     * driver's API has no error channel; callers cannot detect bus failures. */
    (void)HAL_SPI_TransmitReceive(&hspi2, &data_tx, &data_rx, 1, HAL_MAX_DELAY);
    return data_rx;
}

/**
 * Clock in one byte from the flash by transmitting SPI_FLASH_DUMMY_BYTE.
 *
 * @return the byte received while the dummy byte was sent
 */
static uint8_t spi_flash_read_byte(void)
{
    return spi_flash_send_byte(SPI_FLASH_DUMMY_BYTE);
}

/**
 * Read the 3-byte ID returned by the SPI_FLASH_READ_JEDEC_ID command.
 *
 * @return the three response bytes packed big-endian: the first byte received
 *         ends up in bits 23..16, the last in bits 7..0 (per JEDEC convention
 *         these are typically manufacturer, memory type, capacity).
 */
uint32_t spi_flash_read_id(void)
{
    uint32_t id = 0;

    spi_flash_cs_low();
    spi_flash_send_byte(SPI_FLASH_READ_JEDEC_ID);
    for (int i = 0; i < 3; i++) {
        /* Shift earlier bytes up so the first one received is most significant. */
        id = (id << 8) | spi_flash_read_byte();
    }
    spi_flash_cs_high();

    return id;
}

/**
 * Send the SPI_FLASH_WRITE_ENABLE command as its own CS-framed transaction.
 *
 * Every erase and page-program path in this file calls this first.
 */
void spi_flash_write_enable(void)
{
    spi_flash_cs_low();
    spi_flash_send_byte(SPI_FLASH_WRITE_ENABLE);
    spi_flash_cs_high();
}

/**
 * Busy-wait until the flash clears SPI_FLASH_STATUS_BUSY in status register 1.
 *
 * Sends SPI_FLASH_READ_STATUS_REG1 once, then keeps clocking dummy bytes with
 * CS held low. This relies on the device re-outputting the status register
 * for as long as CS stays asserted (typical of W25Qxx-style parts; confirm in
 * the datasheet).
 *
 * NOTE(review): there is no timeout. If the device is absent or MISO is stuck
 * high, the BUSY bit reads as set forever and this never returns.
 */
void spi_flash_wait_for_write_end(void)
{
    uint8_t status;

    spi_flash_cs_low();
    spi_flash_send_byte(SPI_FLASH_READ_STATUS_REG1);
    do {
        status = spi_flash_read_byte();
    } while (status & SPI_FLASH_STATUS_BUSY);
    spi_flash_cs_high();
}

/**
 * Issue an erase command and block until the device reports it finished.
 *
 * @param mode  which erase to perform (chip, sector, 32 KiB or 64 KiB block)
 * @param addr  24-bit target address; ignored for erase_mode_chip
 *
 * An unrecognised mode is a no-op. Previously it still sent WRITE ENABLE, a
 * dummy "opcode" and three address bytes, and then busy-waited, which could
 * leave the device write-enabled.
 */
static void spi_flash_erase(erase_mode_t mode, uint32_t addr)
{
    uint8_t opcode;

    /* Resolve the opcode before touching the bus so invalid modes bail early. */
    switch (mode) {
    case erase_mode_chip:
        opcode = SPI_FLASH_CHIP_ERASE;
        break;
    case erase_mode_sector:
        opcode = SPI_FLASH_SECTOR_ERASE;
        break;
    case erase_mode_block_32kb:
        opcode = SPI_FLASH_BLOCK_ERASE_32;
        break;
    case erase_mode_block_64kb:
        opcode = SPI_FLASH_BLOCK_ERASE_64;
        break;
    default:
        return;
    }

    spi_flash_write_enable();
    spi_flash_cs_low();
    spi_flash_send_byte(opcode);
    if (mode != erase_mode_chip) {
        /* 24-bit address, most significant byte first. */
        spi_flash_send_byte((uint8_t)(addr >> 16));
        spi_flash_send_byte((uint8_t)(addr >> 8));
        spi_flash_send_byte((uint8_t)addr);
    }
    spi_flash_cs_high();
    spi_flash_wait_for_write_end();
}

/**
 * Erase the sector addressed by sector_addr and block until the device is idle.
 *
 * Defined without `inline`: a non-static plain-inline definition is, under
 * C99/C11 rules, only an inline definition unless another declaration in this
 * translation unit omits `inline`. That can leave callers in other files with
 * an undefined reference at link time.
 */
void spi_flash_erase_sector(uint32_t sector_addr)
{
    spi_flash_erase(erase_mode_sector, sector_addr);
}

/**
 * Erase the 32 KiB block addressed by block_addr and block until the device
 * is idle.
 *
 * Defined without `inline` so an external definition is always emitted;
 * a plain C99 `inline` definition may not provide one.
 */
void spi_flash_erase_block_32kb(uint32_t block_addr)
{
    spi_flash_erase(erase_mode_block_32kb, block_addr);
}

/**
 * Erase the 64 KiB block addressed by block_addr and block until the device
 * is idle.
 *
 * Defined without `inline` so an external definition is always emitted;
 * a plain C99 `inline` definition may not provide one.
 */
void spi_flash_erase_block_64kb(uint32_t block_addr)
{
    spi_flash_erase(erase_mode_block_64kb, block_addr);
}

/**
 * Erase the entire chip and block until the device is idle.
 *
 * Defined without `inline` so an external definition is always emitted, and
 * with (void) so the definition is a real prototype.
 */
void spi_flash_erase_chip(void)
{
    /* The address argument is ignored for a chip erase. */
    spi_flash_erase(erase_mode_chip, 0);
}

/**
 * Program up to one page with a single SPI_FLASH_PAGE_PROG command and wait
 * for completion.
 *
 * @param pbuf  source bytes (only read)
 * @param addr  24-bit start address
 * @param len   bytes to program; the caller must keep addr..addr+len-1
 *              within one page (spi_flash_write_buffer splits at page
 *              boundaries for this reason)
 *
 * A zero length returns immediately. Otherwise WRITE ENABLE plus a program
 * header with no data bytes would be sent.
 */
static void spi_flash_write_page(const uint8_t *pbuf, uint32_t addr, uint32_t len)
{
    if (len == 0) {
        return;
    }

    spi_flash_write_enable();
    spi_flash_cs_low();
    spi_flash_send_byte(SPI_FLASH_PAGE_PROG);
    /* 24-bit address, most significant byte first. */
    spi_flash_send_byte((uint8_t)(addr >> 16));
    spi_flash_send_byte((uint8_t)(addr >> 8));
    spi_flash_send_byte((uint8_t)addr);
    for (uint32_t i = 0; i < len; i++) {
        spi_flash_send_byte(pbuf[i]);
    }
    spi_flash_cs_high();
    spi_flash_wait_for_write_end();
}

/**
 * Write an arbitrary-length buffer to flash, splitting it so that no single
 * page-program command crosses a SPI_FLASH_PAGE_SIZE boundary.
 *
 * @param pbuf  source bytes (only read)
 * @param addr  24-bit start address (need not be page aligned)
 * @param len   number of bytes to write; 0 performs no bus activity
 *
 * The target area is presumably expected to be erased already; nothing here
 * erases it (confirm against callers).
 *
 * The first chunk runs from addr up to the next page boundary. After that,
 * every chunk is a whole page except possibly the last. Unlike the previous
 * branch tree, this never issues a zero-length page program, which used to
 * happen for len == 0 and for page-aligned writes whose length was an exact
 * multiple of the page size.
 */
void spi_flash_write_buffer(uint8_t *pbuf, uint32_t addr, uint32_t len)
{
    while (len > 0) {
        /* Bytes remaining before the next page boundary. SPI NOR page program
         * typically wraps at the page end, so a chunk must not cross it. */
        uint32_t chunk = SPI_FLASH_PAGE_SIZE - (addr % SPI_FLASH_PAGE_SIZE);
        if (chunk > len) {
            chunk = len;
        }

        spi_flash_write_page(pbuf, addr, chunk);
        pbuf += chunk;
        addr += chunk;
        len -= chunk;
    }
}

/**
 * Read len bytes starting at addr into pbuf using one SPI_FLASH_READ_DATA
 * command (a single CS-framed transaction).
 *
 * @param pbuf  destination; must have room for len bytes
 * @param addr  24-bit start address (bits above 23 are not sent)
 * @param len   number of bytes to read
 */
void spi_flash_read_buffer(uint8_t *pbuf, uint32_t addr, uint32_t len)
{
    const uint8_t addr_hi  = (uint8_t)(addr >> 16);
    const uint8_t addr_mid = (uint8_t)(addr >> 8);
    const uint8_t addr_lo  = (uint8_t)addr;

    spi_flash_cs_low();
    spi_flash_send_byte(SPI_FLASH_READ_DATA);
    spi_flash_send_byte(addr_hi);
    spi_flash_send_byte(addr_mid);
    spi_flash_send_byte(addr_lo);
    /* Each dummy byte clocked out returns the next sequential data byte. */
    for (uint32_t i = 0; i < len; i++) {
        pbuf[i] = spi_flash_read_byte();
    }
    spi_flash_cs_high();
}
