#include <stdint.h>

// Memory-mapped peripheral registers used by gpio_init(), func1() and func2().
// RCU_APB2EN: bit 2 is set in gpio_init() to enable the GPIOA clock.
#define RCU_APB2EN 0x40021018
// GPIOA_CTL0: bits 7:4 hold the PA1 pin configuration written by gpio_init().
#define GPIOA_CTL0 0x40010800
// GPIOA_OCTL: bit 1 is the PA1 output level toggled by the two threads.
#define GPIOA_OCTL 0x4001080C

// Value loaded into the low word of mtimecmp by timer_init() (8000000/4 = 2000000).
// NOTE(review): presumably the mtime tick rate in Hz, i.e. roughly one
// timer interrupt per second -- confirm against the clock configuration.
#define TIMER_FREQ ((uint32_t)8000000/4)
// Base address of the machine timer block; the two offsets below are added to it.
#define TIMER_CTRL_ADDR 0xd1000000
// Offset of mtime; accessed here as two 32-bit words (low at +0, high at +4).
#define TIMER_MTIME 0x0
// Offset of mtimecmp; accessed here as two 32-bit words (low at +0, high at +4).
#define TIMER_MTIMECMP 0x8

// Trap entry point, defined outside this file; main() writes its address to mtvec.
extern void _trap_entry();

// Private 512-byte, 4-byte-aligned, zero-initialised stacks for the two threads.
// main() hands them to thread_init() together with func1 / func2.
__attribute__((aligned(4))) uint8_t func1_thread_stack[512] = {0};
__attribute__((aligned(4))) uint8_t func2_thread_stack[512] = {0};
// Saved stack pointers of the two threads: initialised from thread_init() in
// main() and passed by address to context_switch() / context_switch_to().
uint32_t func1_sp;
uint32_t func2_sp;
// Not referenced anywhere in this file.
// NOTE(review): possibly used by the external trap/context-switch code -- confirm.
uint32_t main_sp;
uint32_t thread_sp;
// Id of the thread currently selected to run: 1 (func1) or 2 (func2).
// Set in main() and flipped in handle_trap().
uint32_t current_thread;

/*
 * Configure pin PA1 as a push-pull output and drive it low.
 *
 * All register accesses go through volatile pointers: these are
 * memory-mapped I/O registers, so every read-modify-write must really be
 * performed, and in program order (the GPIOA clock has to be enabled
 * before the GPIOA registers are touched). Without volatile the compiler
 * is free to reorder or merge the accesses.
 */
void gpio_init()
{
    volatile uint32_t *const rcu_apb2en = (volatile uint32_t *)RCU_APB2EN;
    volatile uint32_t *const gpioa_ctl0 = (volatile uint32_t *)GPIOA_CTL0;
    volatile uint32_t *const gpioa_octl = (volatile uint32_t *)GPIOA_OCTL;

    // enable GPIOA clock (bit 2 of RCU_APB2EN)
    *rcu_apb2en |= (uint32_t)1 << 2;

    // set PA1 as output push pull: clear its 4-bit field (bits 7:4) and
    // write 0b0001 into it; parentheses make the &-before-| order explicit
    *gpioa_ctl0 = (*gpioa_ctl0 & (uint32_t)0xffffff0f) | ((uint32_t)1 << 4);

    // PA1 output low (clear bit 1 of the output register)
    *gpioa_octl &= (uint32_t)0xfffffffd;
}

/*
 * Reset the machine timer and program its compare value.
 *
 * mtime and mtimecmp are each written as two 32-bit words, high word
 * (index 1, byte offset +4) first and low word (index 0) second:
 *   mtime    = 0
 *   mtimecmp = TIMER_FREQ
 */
void timer_init()
{
    volatile uint32_t *const mtime =
        (volatile uint32_t *)(TIMER_CTRL_ADDR + TIMER_MTIME);
    volatile uint32_t *const mtimecmp =
        (volatile uint32_t *)(TIMER_CTRL_ADDR + TIMER_MTIMECMP);

    // restart the counter from zero
    mtime[1] = 0;
    mtime[0] = 0;

    // compare match after TIMER_FREQ ticks
    mtimecmp[1] = 0;
    mtimecmp[0] = TIMER_FREQ;
}

/*
 * Thread 1 body: never returns. Whenever PA1 (bit 1 of GPIOA_OCTL) is
 * found high, drive it low again.
 *
 * The register is accessed through a volatile pointer so that it is
 * re-read on every loop iteration. With a plain pointer the compiler may
 * hoist the load out of the loop and this thread would never observe the
 * bit being set by the other thread.
 */
void func1()
{
    volatile uint32_t *const gpioa_octl = (volatile uint32_t *)GPIOA_OCTL;
    const uint32_t pa1_mask = (uint32_t)1 << 1;

    while(1){
        if(*gpioa_octl & pa1_mask){
            *gpioa_octl &= ~pa1_mask;
        }
    }
}

/*
 * Thread 2 body: never returns. Whenever PA1 (bit 1 of GPIOA_OCTL) is
 * found low, drive it high again.
 *
 * The register is accessed through a volatile pointer so that it is
 * re-read on every loop iteration. With a plain pointer the compiler may
 * hoist the load out of the loop and this thread would never observe the
 * bit being cleared by the other thread.
 */
void func2()
{
    volatile uint32_t *const gpioa_octl = (volatile uint32_t *)GPIOA_OCTL;
    const uint32_t pa1_mask = (uint32_t)1 << 1;

    while(1){
        if(!(*gpioa_octl & pa1_mask)){
            *gpioa_octl |= pa1_mask;
        }
    }
}

/*
 * Prepare the initial saved context of a thread on its own stack.
 *
 * A frame of 32 words (128 bytes) is reserved at the very top of the
 * stack region. Only two slots are written; the rest keep whatever the
 * stack memory already contains.
 *
 *   entry       function the thread starts in (stored in slot 0)
 *   stack_addr  lowest address of the thread's stack memory
 *   stack_size  size of that memory in bytes (must be at least 128)
 *
 * Returns the address of the frame as a 32-bit value; the caller keeps it
 * as the thread's saved stack pointer.
 *
 * NOTE(review): slot 2 receives 0x1880 (bits 12, 11 and 7 set). If the
 * context-switch code restores that slot into mstatus this would mean
 * MPP = machine mode with MPIE set -- confirm against the assembly.
 */
uint32_t thread_init(void (*entry)(), uint8_t *stack_addr, uint32_t stack_size)
{
    enum { FRAME_WORDS = 32, ENTRY_SLOT = 0, STATUS_SLOT = 2 };

    uint8_t *const stack_end = stack_addr + stack_size;
    uint32_t *const frame =
        (uint32_t *)(stack_end - FRAME_WORDS * sizeof(uint32_t));

    frame[STATUS_SLOT] = 0x1880;
    frame[ENTRY_SLOT] = (uint32_t)entry;

    return (uint32_t)frame;
}

// Context-switch primitives implemented outside this file. Callers here pass
// the addresses of the per-thread saved-stack-pointer variables.
// NOTE(review): presumably context_switch() saves the running context and
// stores its stack pointer in *from, then resumes the context whose stack
// pointer is in *to -- confirm against the implementation.
extern void context_switch(uint32_t *from, uint32_t *to);
// NOTE(review): presumably resumes the context at *to without saving the
// current one (used once from main() to start the first thread) -- confirm.
extern void context_switch_to(uint32_t *to);

/*
 * C-level trap handler.
 *
 *   mcause  trap cause value supplied by the caller
 *
 * Only an interrupt (top bit of mcause set) with cause code 7 is handled:
 * mtime is reset to zero and execution switches to the other thread.
 * Any other trap parks the CPU in an endless loop.
 */
void handle_trap(uint32_t mcause)
{
    const uint32_t is_interrupt = mcause & 0x80000000;
    const uint32_t cause_code = mcause & 0xfff;

    if(!is_interrupt || cause_code != 7){
        // unexpected trap: stop here (the nop gives a spot for a breakpoint)
        asm volatile("nop");
        while(1);
    }

    // reset mtime, high word first
    volatile uint32_t *const mtime =
        (volatile uint32_t *)(TIMER_CTRL_ADDR + TIMER_MTIME);
    mtime[1] = 0;
    mtime[0] = 0;

    // pick the outgoing/incoming pair and record the new current thread
    uint32_t *from;
    uint32_t *to;
    if(current_thread == 1){
        current_thread = 2;
        from = &func1_sp;
        to = &func2_sp;
    } else{
        current_thread = 1;
        from = &func2_sp;
        to = &func1_sp;
    }

    context_switch(from, to);
}

/*
 * Program entry: set up PA1 and the machine timer, install the trap
 * vector, enable the timer interrupt source, create both thread contexts
 * and start running func1. Not expected to return.
 */
int main()
{
    // hardware bring-up
    gpio_init();
    timer_init();

    // install the trap entry address in mtvec
    const uint32_t trap_vector = (uint32_t)_trap_entry;
    asm volatile("csrw mtvec, %0" :: "r"(trap_vector));

    // set bit 7 of mie
    asm volatile("csrs mie, %0" :: "r"(1 << 7));

    // Bit 3 of mstatus is deliberately NOT set here.
    // NOTE(review): interrupts presumably become enabled when a thread
    // context carrying the 0x1880 status word is restored -- confirm.

    // build the initial contexts on the two private stacks
    func1_sp = thread_init(func1, func1_thread_stack, 512);
    func2_sp = thread_init(func2, func2_thread_stack, 512);

    // start with thread 1
    current_thread = 1;
    context_switch_to(&func1_sp);

    // only reached if context_switch_to() ever comes back
    for(;;);

    return 0;
}
