#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/stat.h>
#include <string.h>
#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
#include <pwd.h>
#include <limits.h>
#include "self_criu.h"

/* Set to 1 by criu_init() when CRIU_DEBUG is passed; handle_checkpoint() then
 * lets "criu dump" write to the terminal instead of criu_dump.log. */
int criu_verbose = 0;
/* Checkpoint image directory for this process, "<home>/.checkpoint/<prog>_<pid>",
 * allocated by criu_init(); dumped into by handle_checkpoint() and removed by
 * criu_cleanup(). NULL until criu_init() runs. */
static char *ckpt_path = NULL;
/* Marker file "<home>/.checkpoint/.restore_flag_<prog>": criu_init() creates it
 * just before exec'ing "criu restore", and the restored process, still polling
 * inside handle_checkpoint(), deletes it and resumes. */
static char *restore_flag_file = NULL;

/*
 * Return the home directory of the user who actually launched the program.
 *
 * When running under sudo, HOME may already refer to root's home, so the
 * invoking user named by SUDO_USER is looked up in the password database
 * first. Otherwise HOME is used, with "/root" as the last resort. Empty
 * strings are treated as unset so the caller never builds a path rooted at
 * "/" by accident.
 *
 * The returned pointer may refer to getpwnam()'s static storage or to the
 * environment; use or copy it before calling getpwnam()/setenv() again.
 */
static const char *get_real_user_home(void) {
    const char *user = getenv("SUDO_USER");
    if (user && *user) {
        const struct passwd *pw = getpwnam(user);
        if (pw && pw->pw_dir && *pw->pw_dir) return pw->pw_dir;
    }
    const char *home = getenv("HOME");
    return (home && *home) ? home : "/root";
}

/*
 * atexit() handler (also called from handle_exit_signal): recursively delete
 * this process's checkpoint directory if ckpt_path names an existing
 * directory.
 *
 * "rm -rf" is exec'd directly instead of going through system(). The path is
 * passed as one argv entry, so spaces or shell metacharacters in the home
 * directory or program name can neither break nor inject into the command.
 * "--" stops a path beginning with '-' from being parsed as an option.
 * rm's stderr is discarded, as before. Best effort: failures are ignored.
 */
static void criu_cleanup(void) {
    struct stat st;
    if (!ckpt_path || stat(ckpt_path, &st) != 0 || !S_ISDIR(st.st_mode)) return;

    pid_t child = fork();
    if (child < 0) return;
    if (child == 0) {
        int devnull = open("/dev/null", O_WRONLY);
        if (devnull >= 0) {
            dup2(devnull, STDERR_FILENO);
            if (devnull != STDERR_FILENO) close(devnull);
        }
        char *rm_argv[] = {"rm", "-rf", "--", ckpt_path, NULL};
        execvp("rm", rm_argv);
        _exit(127);  /* exec failed; _exit so the child runs no atexit handlers */
    }

    while (waitpid(child, NULL, 0) == -1 && errno == EINTR) {
        /* interrupted by a signal: keep waiting for rm to finish */
    }
}

/*
 * SIGINT handler, installed by criu_init() only when CRIU_CLEANUP_ON_SIGINT
 * is set: remove this process's checkpoint directory, then terminate with
 * status 0. _exit() is used so atexit handlers (criu_cleanup among them)
 * are not run a second time.
 */
static void handle_exit_signal(int sig) {
    (void)sig; /* one handler, one signal: the number is not needed */

    criu_cleanup();
    _exit(0);
}

/*
 * Return a newly malloc'd copy of s wrapped in single quotes, so that
 * /bin/sh reads it as one literal word. Each embedded ' is written as '\''.
 * Returns NULL if allocation fails.
 */
static char *criu_shell_quote(const char *s) {
    size_t len = 0, quotes = 0;
    for (const char *p = s; *p; p++) {
        len++;
        if (*p == '\'') quotes++;
    }
    /* each ' grows by 3 chars; plus 2 surrounding quotes and the NUL */
    char *out = malloc(len + 3 * quotes + 3);
    if (!out) return NULL;

    char *o = out;
    *o++ = '\'';
    for (const char *p = s; *p; p++) {
        if (*p == '\'') {
            memcpy(o, "'\\''", 4);
            o += 4;
        } else {
            *o++ = *p;
        }
    }
    *o++ = '\'';
    *o = '\0';
    return out;
}

/*
 * SIGQUIT handler: checkpoint this process with CRIU.
 *
 * A backgrounded shell job (the trailing "&" lets the shell spawned by
 * system() return immediately) runs "criu dump --leave-running" on our PID.
 * Only if the dump succeeds does it SIGKILL us, leaving the checkpoint as the
 * surviving copy. If the dump (or the kill) fails, the job touches
 * <ckpt_path>/.dump_failed instead. Meanwhile this handler polls:
 *   - restore_flag_file appears: we are the restored image (criu_init()
 *     of a later launch creates that flag before exec'ing "criu restore").
 *     Consume the flag and return to the interrupted code.
 *   - .dump_failed appears: report it and keep running, instead of
 *     waiting forever.
 *
 * All paths are single-quoted before being placed in the shell command.
 * Not async-signal-safe (malloc, stdio, system()); errno is preserved for the
 * interrupted code.
 */
static void handle_checkpoint(int sig) {
    int saved_errno = errno;
    pid_t target = getpid();
    char *log_file = NULL, *fail_flag = NULL, *cmd = NULL;
    char *q_dir = NULL, *q_log = NULL, *q_fail = NULL;
    int n;
    int status;

    (void)sig;
    if (!ckpt_path) goto out;

    if (asprintf(&log_file, "%s/criu_dump.log", ckpt_path) == -1) {
        log_file = NULL;
        goto out;
    }
    if (asprintf(&fail_flag, "%s/.dump_failed", ckpt_path) == -1) {
        fail_flag = NULL;
        goto out;
    }
    q_dir = criu_shell_quote(ckpt_path);
    q_log = criu_shell_quote(log_file);
    q_fail = criu_shell_quote(fail_flag);
    if (!q_dir || !q_log || !q_fail) goto out;

    if (criu_verbose) {
        /* criu output goes straight to the inherited terminal. */
        n = asprintf(&cmd,
                     "((criu dump -t %d -D %s --shell-job --leave-running && kill -9 %d)"
                     " || touch %s) &",
                     (int)target, q_dir, (int)target, q_fail);
    } else {
        /* criu output goes to the log, which is shown only on failure. The
         * kill is chained to the dump itself, so a failed dump never kills us. */
        n = asprintf(&cmd,
                     "((criu dump -t %d -D %s --shell-job --leave-running > %s 2>&1 && kill -9 %d)"
                     " || { cat %s; touch %s; }) &",
                     (int)target, q_dir, q_log, (int)target, q_log, q_fail);
    }
    if (n == -1) {
        cmd = NULL;
        goto out;
    }

    unlink(fail_flag); /* discard a marker left by an earlier failed attempt */
    printf("\n[CRIU] Checkpointing PID %d...\n", (int)target);
    fflush(stdout);    /* emit now rather than from the checkpointed buffer */

    status = system(cmd);
    if (status == -1 || !WIFEXITED(status) || WEXITSTATUS(status) != 0) {
        /* The background job never started, so nothing will ever signal us. */
        fprintf(stderr, "[CRIU] Could not start criu dump.\n");
        goto out;
    }

    for (;;) {
        if (restore_flag_file && access(restore_flag_file, F_OK) == 0) {
            unlink(restore_flag_file);
            printf("[CRIU] Resumed successfully.\n");
            break;
        }
        if (access(fail_flag, F_OK) == 0) {
            unlink(fail_flag);
            fprintf(stderr, "[CRIU] Checkpoint of PID %d failed; continuing.\n", (int)target);
            break;
        }
        usleep(100000);
    }

out:
    free(cmd);
    free(q_fail);
    free(q_log);
    free(q_dir);
    free(fail_flag);
    free(log_file);
    errno = saved_errno;
}

/* Re-execute the current command line as "sudo argv..."; never returns. */
static void criu_elevate(int argc, char **argv) {
    char **sudo_argv = malloc(((size_t)argc + 2) * sizeof *sudo_argv);
    if (!sudo_argv) exit(1);
    sudo_argv[0] = "sudo";
    for (int i = 0; i < argc; i++) sudo_argv[i + 1] = argv[i];
    sudo_argv[argc + 1] = NULL;

    fflush(NULL); /* exec discards unflushed stdio output */
    execvp("sudo", sudo_argv);
    perror("[CRIU] Sudo elevation failed");
    free(sudo_argv);
    exit(1);
}

/*
 * If name is exactly "<prog>_<digits>" (the layout criu_init uses for
 * checkpoint directories), store the numeric PID in *pid and return 1.
 * Otherwise return 0. Requiring '_' right after prog and an all-digit suffix
 * keeps "foo" from picking up checkpoints of "foobar" or "foo_bar".
 */
static int criu_parse_ckpt_name(const char *name, const char *prog, pid_t *pid) {
    size_t plen = strlen(prog);
    if (strncmp(name, prog, plen) != 0 || name[plen] != '_') return 0;

    const char *digits = name + plen + 1;
    if (*digits < '0' || *digits > '9') return 0;

    errno = 0;
    char *end;
    long v = strtol(digits, &end, 10);
    if (errno != 0 || *end != '\0' || v <= 0 || v > INT_MAX) return 0;

    *pid = (pid_t)v;
    return 1;
}

/*
 * Try to restore the checkpoint stored in dir, which was taken from PID
 * old_pid, by exec'ing "criu restore". Returns only if the exec fails.
 */
static void criu_try_restore(char *dir, pid_t old_pid) {
    printf("[CRIU] Found checkpoint. Reclaiming PID %d...\n", (int)old_pid);
    fflush(stdout); /* exec would discard this line if stdout is not a tty */

    /* Tells the restored process, which is polling in handle_checkpoint(),
     * that it has been brought back. */
    int flag_fd = creat(restore_flag_file, 0666);
    if (flag_fd >= 0) close(flag_fd);

    /* Writing old_pid - 1 makes old_pid the next PID the kernel allocates.
     * NOTE(review): racy, and possibly redundant with CRIU's own PID
     * restoration; kept as in the original. Best effort. */
    int fd = open("/proc/sys/kernel/ns_last_pid", O_WRONLY);
    if (fd >= 0) {
        char buf[32];
        int len = snprintf(buf, sizeof buf, "%d", (int)old_pid - 1);
        if (len > 0 && write(fd, buf, (size_t)len) < 0) {
            /* ignored: best effort */
        }
        close(fd);
    }

    char *args[] = {"criu", "restore", "-D", dir, "--shell-job", NULL};
    execvp("criu", args);
    perror("[CRIU] Could not exec criu restore");
    unlink(restore_flag_file);
}

/*
 * Set up self-checkpointing for the calling program.
 *
 * 1. If not running as root, re-exec the same command line under sudo.
 * 2. Scan <home>/.checkpoint for "<prog>_<pid>" directories whose PID is no
 *    longer in use. For each one, try to exec "criu restore"; on success
 *    this call never returns.
 * 3. Otherwise create <home>/.checkpoint/<prog>_<our pid>, remove it at
 *    exit, install SIGQUIT -> checkpoint, and, with CRIU_CLEANUP_ON_SIGINT,
 *    SIGINT -> delete checkpoint and exit.
 *
 * CRIU_DEBUG makes the checkpoint handler show criu's output directly.
 * Exits the process if an allocation fails.
 */
void criu_init(int argc, char **argv, int flags) {
    if (flags & CRIU_DEBUG) criu_verbose = 1;

    if (geteuid() != 0) criu_elevate(argc, argv);

    const char *slash = strrchr(argv[0], '/');
    const char *p_name = slash ? slash + 1 : argv[0];
    const char *home = get_real_user_home();
    char *base_dir;

    if (asprintf(&base_dir, "%s/.checkpoint", home) == -1) exit(1);
    if (asprintf(&restore_flag_file, "%s/.restore_flag_%s", base_dir, p_name) == -1) exit(1);

    DIR *dir = opendir(base_dir);
    if (dir) {
        struct dirent *entry;
        while ((entry = readdir(dir)) != NULL) {
            pid_t old_pid;
            if (!criu_parse_ckpt_name(entry->d_name, p_name, &old_pid)) continue;

            /* A checkpoint can only be restored when its PID is free. Only
             * ESRCH proves that; EPERM means the process exists. */
            if (kill(old_pid, 0) == 0 || errno != ESRCH) continue;

            char *found;
            if (asprintf(&found, "%s/%s", base_dir, entry->d_name) == -1) break;
            criu_try_restore(found, old_pid); /* returns only if exec failed */
            free(found);
        }
        closedir(dir);
    }

    if (asprintf(&ckpt_path, "%s/%s_%d", base_dir, p_name, (int)getpid()) == -1) exit(1);
    mkdir(base_dir, 0755);
    mkdir(ckpt_path, 0755);
    free(base_dir);

    atexit(criu_cleanup);
    signal(SIGQUIT, handle_checkpoint);
    if (flags & CRIU_CLEANUP_ON_SIGINT) signal(SIGINT, handle_exit_signal);
}
