// SPDX-License-Identifier: GPL-2.0
#ifndef TEST
#include <linux/types.h>
#include <linux/slab.h>
#include <linux/sched.h>
#include <linux/init.h>
#include <asm/limits.h> // ramend!
#include <generated/asm-offsets.h>
#define dbg_print(...) do{}while(0)
#define fprintf(...) do{}while(0)
#endif

/* Values stored in struct memblock::type */
#define MEMBLOCK_FREE       0
#define MEMBLOCK_OCCUPIED   1

/*
 * GFP flags consulted by the allocator. When the including environment does
 * not provide them they default to 0, so "flags & GFP_x" is never true.
 */
#if !defined(GFP_ZERO)
#define GFP_ZERO 0
#endif

#if !defined(GFP_KTHREAD)
#define GFP_KTHREAD 0
#endif

/*
 * struct memblock - header in front of every heap block, free or allocated
 * @size: number of data bytes that follow the header
 * @type: MEMBLOCK_FREE or MEMBLOCK_OCCUPIED
 *
 * The layout is chosen from RAM_END to keep the header small. When every
 * possible size fits in one bit less than the storage unit, @type is a 1-bit
 * bitfield sharing that unit. Otherwise @type gets a byte of its own.
 */
#if RAM_END > 0xffffff
/* NOTE(review): size is 31 bits wide, so block sizes must stay below 2^31 */
struct memblock {
	u32 size :31;
	u32 type : 1;
};
#elif RAM_END > 0x7fffff
struct memblock {
	u24 size;
	u8  type;
};
#elif RAM_END > 0xffff
struct memblock {
	u24 size :23;
	u24 type : 1;
};
#elif RAM_END > 0x7fff
/*
 * GCC documents the attribute position only directly after the struct
 * keyword or after the closing brace, not between the tag and '{'.
 * packed keeps the header at 3 bytes on targets that would otherwise pad
 * it to u16 alignment.
 */
struct __attribute__((packed)) memblock {
	u16 size;
	u8  type;
};
#elif RAM_END > 0xff
struct memblock {
	u16 size :15;
	u16 type : 1;
};
#elif RAM_END > 0x7f
struct memblock {
	u8 size;
	u8 type;
};
#else
struct memblock {
	u8 size :7;
	u8 type :1;
};
#endif

#if !defined(TEST)
/* The heap runs from the end of .bss up to and including RAM_END. */
extern u8 __bss_end[];
#define memblock_start ((struct memblock*)__bss_end)
#define memblock_end ((struct memblock*)(RAM_END + 1))
#endif
/* Header that follows @mb: skip @mb's own header plus its data bytes. @mb is evaluated twice. */
#define memblock_next(mb)  ((struct memblock*) ((uintptr_t)(mb) + sizeof(struct memblock) + (mb)->size))


/*
 * _kmalloc - first-fit allocation from the memblock list
 * @size_: number of bytes requested; 0 is not supported (see FIXME below)
 * @flags: GFP_KTHREAD and/or GFP_ZERO
 *
 * Walks the block list from memblock_start and takes the first free block
 * whose data area holds at least @size_ bytes.
 *
 * GFP_KTHREAD: per the in-code note, the task stack lives TASK_stack bytes
 * into the allocation. The stack range wraps a 256-byte boundary when the
 * low address byte of its start is greater than that of the allocation end.
 * In that case the allocation is moved forward by "offset" bytes so the
 * stack starts on a 256-byte boundary. The skipped bytes then go to one of
 * three places:
 *  - a new free block, if they can hold a header;
 *  - the previous block, if they cannot;
 *  - a gap at the start of the allocated block, when there is no previous
 *    block (fit == memblock_start).
 *
 * GFP_ZERO: the memory from the returned pointer to the end of the block is
 * cleared.
 *
 * A remainder larger than a header is split off as a new free block. A
 * smaller remainder is added to the allocation.
 *
 * Does no locking itself; kmalloc() wraps it in kmalloc_mutex.
 * Return: pointer to the memory, or NULL if no free block is large enough.
 */
static void *_kmalloc(alloc_size_t size_, int flags) {
	struct memblock *fit = NULL;
	struct memblock *fit_prev = NULL;
	struct memblock *prev = NULL;
	struct memblock *block;
	u8 offset = 0;
	u8 fit_offset = 0;
	void *ret;
	/* Not int: a u16 block size does not fit a 16-bit int */
	uintptr_t fit_size = 0;
	uintptr_t size = size_;
	/* Data bytes owned by the allocated block: leading gap + payload (+ slack) */
	uintptr_t blk_size;

#if 0
	if (!size) // FIXME: We assume nobody will call kmalloc with a size of 0
		return NULL;
#endif
	for (block = memblock_start; block != memblock_end; prev = block, block = memblock_next(block)) {
		if (block->type != MEMBLOCK_FREE)
			continue;
		if (block->size < size)
			continue;
		// check if block allows our alignment
		if (flags & GFP_KTHREAD) {
#ifdef TEST
			uintptr_t addr = dbg_addr((block + 1));
#else
			uintptr_t addr = (uintptr_t)(block + 1);
#endif
			// Stack is at the end of task_struct; only the low address byte is compared
			u8 stack_start = (u8)(addr + TASK_stack);
			u8 stack_end = (u8)(addr + size);
			/* FIXME: check if alignment at the end of the free block would be ok */
			/* But only if we already have current! */
			if (stack_start > stack_end) {
				/* We move our block until it fits the alignment requirement */
				/* We could also try to use the end of the free block and see if it meets our
				 * alignment requirement
				 */
				offset = 0x100 - stack_start;
				if (block->size < size + offset + 2)
					continue;
			} else {
				offset = 0;
			}
		}
		fit = block;
		fit_prev = prev;
		fit_offset = offset;
		fit_size = block->size;
		break;
	}
	if (!fit)
		return NULL;

	dbg_print(stderr, "kmalloc: Found fit @%03x offset=%x\n", (int)((uintptr_t)fit - (uintptr_t)test_ram), fit_offset);
	if (fit_offset) {
		/* FIXME: Join free segments afterwards in a separate loop? */
		if (fit_offset > sizeof(struct memblock)) {
			/* Room for a header: turn the skipped bytes into their own free block */
			dbg_print(stderr, "kmalloc: Adding free block before\n");
			fit->size = fit_offset - sizeof(struct memblock);
			fit->type = MEMBLOCK_FREE; // Is already free, but we get better code in the bitfield case if we set it again
			fit = (void*)fit + fit_offset;
			fit_size -= fit_offset;
			fit_offset = 0;
		} else if (fit_prev) {
			/* No room for a header: enlarge the previous block, free or not */
			dbg_print(stderr, "kmalloc: enlarging fit by offset\n");
			fit_prev->size += fit_offset;
			fit = (void*)fit + fit_offset;
			fit_size -= fit_offset;
			fit_offset = 0;
		}
		/* Otherwise fit == memblock_start: the gap stays inside this block */
	}
	dbg_print(stderr, "kmalloc: Mangled fit @%03x offset=%x\n", (int)((uintptr_t)fit - (uintptr_t)test_ram), fit_offset);
	ret = (void*)(fit + 1) + fit_offset;

	/* The block must cover a remaining leading gap as well as the payload */
	blk_size = fit_offset + size;
	{
		uintptr_t newfreeblk_size = fit_size - blk_size;

		dbg_print(stderr, "kmalloc: newfreeblk_size=%x\n", (unsigned int)newfreeblk_size);
		if (newfreeblk_size > sizeof(struct memblock)) {
			/* Split the remainder off as a new free block right after ours */
			block = (struct memblock*)((uintptr_t)(fit + 1) + blk_size);
			dbg_print(stderr, "kmalloc: newfreeblk pos=%x fit=%x fitsize=%x\n", (int)((uintptr_t)block - (uintptr_t)test_ram), (int)((uintptr_t)fit - (uintptr_t)test_ram), fit->size);
			block->size = newfreeblk_size - sizeof(struct memblock);
			block->type = MEMBLOCK_FREE;
		} else {
			/* Remainder cannot hold a header: it becomes part of this allocation */
			blk_size += newfreeblk_size;
		}
	}
	fit->size = blk_size;
	fit->type = MEMBLOCK_OCCUPIED;

	if (flags & GFP_ZERO) {
		/* Clear from ret up to the end of the block, never past it */
		uintptr_t zero_len = blk_size - fit_offset;
		char *p = ret + zero_len;
		memop_size_t i;
		for (i = 0; i < zero_len; i++)
			*--p = 0;
	}

	return ret;
}

/*
 * _kfree - give a block returned by _kmalloc() back to the heap
 * @mem: pointer obtained from _kmalloc(), or NULL (ignored)
 *
 * Marks the block free and merges it with the following block and with the
 * preceding block whenever those are free.
 * Does no locking itself; kfree() wraps it in kmalloc_mutex.
 */
static void _kfree(const void *mem) {
	struct memblock *freeblock = (void*)mem;
	struct memblock *next;
	uintptr_t size;

	/* Freeing NULL is a no-op; without this the header would be read from below address 0 */
	if (mem == NULL)
		return;
	/* The header sits directly in front of the returned pointer */
	freeblock--;

	/*
	 * Kmalloc can produce gaps at the start smaller than struct memblock.
	 * When the very first block is shifted by at most sizeof(struct memblock)
	 * bytes, the header computed above lands inside the first header (or
	 * right after it). In that case, use the real header at memblock_start.
	 */
	// FIXME: This can only happen with GFP_KTHREAD, and then only in very rare cases!
#ifdef CONFIG_TINY_STACK
	if ((uintptr_t)freeblock <= (uintptr_t)memblock_start + sizeof(struct memblock)) {
		freeblock = memblock_start;
	}
#endif

	size = freeblock->size;

	/* First, absorb the next block (header included) if it is free */
	next = memblock_next(freeblock);
	if (next != memblock_end && next->type == MEMBLOCK_FREE)
		size += next->size + sizeof(struct memblock);

	/* Then find our predecessor; if it is free, it absorbs us */
	if (freeblock != memblock_start) {
		struct memblock *block;
		// FIXME: If we scan anyway, we could also reject objects not found in the list!
		/* Bounded by memblock_end so the scan stops even if @mem is not a block start */
		for (block = memblock_start; block != memblock_end; block = next) {
			next = memblock_next(block);
			if (next == freeblock) {
				if (block->type == MEMBLOCK_FREE) {
					freeblock = block;
					size += block->size + sizeof(struct memblock);
				}
				break;
			}
		}
	}
	freeblock->size = size;
	freeblock->type = MEMBLOCK_FREE;
}


#ifdef CONFIG_MUTEX
/* Lock taken by kmalloc() and kfree() around the list operations */
struct mutex kmalloc_mutex;
#else
/* No mutex support: the lock calls expand to nothing and ignore their argument */
#define mutex_lock(x) do{}while(0)
#define mutex_unlock(x) do{}while(0)
#warning kmalloc is used without locking
#endif

/*
 * kmalloc - allocate @size bytes from the kernel heap
 * @size: requested size in bytes
 * @flags: GFP_* flags, passed unchanged to _kmalloc()
 *
 * Calls _kmalloc() between mutex_lock()/mutex_unlock() on kmalloc_mutex.
 * Without CONFIG_MUTEX those calls expand to nothing.
 * Return: the allocated memory, or NULL if no free block is large enough.
 */
void *kmalloc(alloc_size_t size, int flags) {
	void *mem;

	mutex_lock(&kmalloc_mutex);
	mem = _kmalloc(size, flags);
	mutex_unlock(&kmalloc_mutex);
	return mem;
}

/*
 * kfree - release memory obtained from kmalloc()
 * @mem: pointer previously returned by kmalloc()
 *
 * Calls _kfree() between mutex_lock()/mutex_unlock() on kmalloc_mutex.
 * Without CONFIG_MUTEX those calls expand to nothing.
 */
void kfree(const void *mem) {
	mutex_lock(&kmalloc_mutex);
	_kfree(mem);
	mutex_unlock(&kmalloc_mutex);
}

/*
 * kmalloc_init - turn the whole heap area into one free block
 *
 * The area runs from memblock_start to memblock_end. The single header at
 * memblock_start covers all remaining bytes. Always returns 0.
 */
static int kmalloc_init(void) {
	struct memblock *const first = memblock_start;
	const uintptr_t span = (uintptr_t)memblock_end - (uintptr_t)first;

	dbg_print(stderr, "memblock_end - memblock_start = %x", (int)span);
	first->size = span - sizeof(struct memblock);
	first->type = MEMBLOCK_FREE;
	return 0;
}

early_initcall(kmalloc_init);
