import spaces
import gradio as gr
from pathlib import Path
import re
import torch
import gc
from typing import Any
from huggingface_hub import hf_hub_download, HfApi
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
import wrapt_timeout_decorator
from llama_cpp_agent.messages_formatter import MessagesFormatter
from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
from llmenv import llm_models, llm_models_dir, llm_formats, llm_languages, dolphin_system_prompt
import subprocess
# Delete everything under /data-nvme/zerogpu-offload/ at import time. The
# shell is needed to expand the `*` glob; env={} runs it with an empty
# environment, and the exit status is not checked (no check=True), so a
# failure here is ignored.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)


# (label, value) model choices; rebuilt by update_llm_model_tupled_list()
# and download_llm_models().
llm_models_tupled_list = []
# The first key of llm_models (imported from llmenv) is the fallback model.
default_llm_model_filename = list(llm_models.keys())[0]
# NOTE(review): `device` is not read anywhere in this section — confirm it is
# used elsewhere before removing.
device = "cuda" if torch.cuda.is_available() else "cpu"


def to_list(s: str):
    """Split a comma-separated string into a list of stripped, non-empty items.

    >>> to_list(" a, b ,, c ")
    ['a', 'b', 'c']
    >>> to_list("")
    []
    """
    # Filter per item: guarding only the whole string (`s == ""`) let blank
    # entries such as the middle of "a,,b" through as "".
    return [x.strip() for x in s.split(",") if x.strip()]


def list_uniq(l: list):
    """Return the items of *l* with duplicates removed, keeping first-occurrence order.

    Items must be hashable.
    """
    # dict keys keep insertion order, so this is an O(n) order-preserving
    # de-duplication (sorted(set(l), key=l.index) rescans the list per item).
    return list(dict.fromkeys(l))


# Fallback values returned by get_state() when a key is missing from the
# per-session state dict.
DEFAULT_STATE = dict(
    dolphin_sysprompt_mode="Default",
    dolphin_output_language=llm_languages[0],
)


def get_state(state: dict, key: str):
    """Look up *key* in the per-session *state* dict.

    Falls back to DEFAULT_STATE (with a log line) when the key is missing from
    *state*. Returns None (also logged) when neither contains it. A ``None``
    state is treated like an empty one.
    """
    if state is not None and key in state:
        return state[key]
    if key in DEFAULT_STATE:
        print(f"State '{key}' not found. Use default value.")
        return DEFAULT_STATE[key]
    print(f"State '{key}' not found.")
    return None


def set_state(state: dict, key: str, value: Any):
    """Store *value* under *key* in *state*, mutating the dict in place."""
    state.update({key: value})


@wrapt_timeout_decorator.timeout(dec_timeout=3.5)
def to_list_ja(s: str):
    """Split comma-separated text into stripped, non-empty items.

    The Japanese punctuation marks 、 and 。 are treated as separators in
    addition to ",". The call is guarded by a timeout decorator
    (dec_timeout=3.5).
    """
    s = re.sub(r'[、。]', ',', s)
    # Filter per item: guarding only the whole string (`s == ""`) let blank
    # entries such as "a、、b" through as "".
    return [x.strip() for x in s.split(",") if x.strip()]


def is_japanese(s: str):
    """Return True if *s* has at least one CJK ideograph, hiragana or katakana character."""
    import unicodedata
    markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")
    return any(
        any(marker in unicodedata.name(ch, "") for marker in markers)
        for ch in s
    )


def update_llm_model_tupled_list():
    """Rebuild and return the global (label, value) list of selectable models.

    Entries come from the llm_models registry first, then from every *.gguf
    file found directly in llm_models_dir. Label and value are both the
    filename. Duplicates are dropped, keeping first-occurrence order.
    """
    global llm_models_tupled_list
    llm_models_tupled_list = [(key, key) for key in llm_models]
    llm_models_tupled_list.extend(
        (gguf.name, gguf.name) for gguf in Path(llm_models_dir).glob('*.gguf')
    )
    llm_models_tupled_list = list_uniq(llm_models_tupled_list)
    return llm_models_tupled_list


def download_llm_models():
    """Download every model registered in llm_models into llm_models_dir.

    Rebuilds the global llm_models_tupled_list with one (filename, filename)
    entry per model that downloaded successfully. Failures are logged and
    skipped (best effort) so one bad repo does not abort the rest.
    """
    global llm_models_tupled_list
    llm_models_tupled_list = []
    for filename, entry in llm_models.items():
        try:
            hf_hub_download(repo_id=entry[0], filename=filename, local_dir=llm_models_dir)
        except Exception as e:
            # Report instead of silently dropping the model from the list.
            print(f"Failed to download {filename}: {e}")
            continue
        llm_models_tupled_list.append((filename, filename))


def download_llm_model(filename: str):
    """Ensure the registered model *filename* is present in llm_models_dir.

    Returns *filename* on success, after refreshing the model choice list.
    Returns default_llm_model_filename when *filename* is not in llm_models
    or when the download raises (the exception is printed).
    """
    if filename not in llm_models.keys():
        return default_llm_model_filename
    try:
        hf_hub_download(
            repo_id=llm_models[filename][0],
            filename=filename,
            local_dir=llm_models_dir,
        )
    except Exception as err:
        print(err)
        return default_llm_model_filename
    else:
        update_llm_model_tupled_list()
        return filename


def get_dolphin_model_info(filename: str):
    """Return a Markdown link to the Hugging Face repo of *filename*, or "None" if unregistered."""
    entry = llm_models.get(filename, None)
    if not entry:
        return "None"
    repo = entry[0]
    return f'Repo: [{repo}](https://huggingface.co/{repo})'


def select_dolphin_model(filename: str, state: dict, progress=gr.Progress(track_tqdm=True)):
    """Switch the chat model to *filename*.

    Clears any chat-format override in *state* and downloads the model if
    needed (download_llm_model() may fall back to the default model).
    Returns Gradio updates for:
    - the model choices (value and refreshed list),
    - the format name,
    - the repo-info Markdown,
    followed by the mutated *state*.
    """
    set_state(state, "override_llm_format", None)
    progress(0, desc="Loading model...")
    value = download_llm_model(filename)
    progress(1, desc="Model loaded.")
    # Describe the model that is actually selected (`value`), so the info
    # panel agrees with the returned model and format even after a fallback.
    md = get_dolphin_model_info(value)
    return gr.update(value=value, choices=get_dolphin_models()), gr.update(value=get_dolphin_model_format(value)), gr.update(value=md), state


def select_dolphin_format(format_name: str, state: dict):
    """Force the chat format *format_name* for subsequent generations.

    Stores llm_formats[format_name] in *state* under "override_llm_format" and
    returns an update echoing *format_name*, plus the mutated state.
    """
    chosen_format = llm_formats[format_name]
    set_state(state, "override_llm_format", chosen_format)
    return gr.update(value=format_name), state


# Fetch the default model at import time. download_llm_model() prints and
# swallows download errors, so a failure here does not stop the import.
download_llm_model(default_llm_model_filename)


def get_dolphin_models():
    """Refresh and return the (label, value) list of available models."""
    choices = update_llm_model_tupled_list()
    return choices


def get_llm_formats():
    """Return the names of all known chat formats, in registry order."""
    return [name for name in llm_formats]


def get_key_from_value(d, val):
    """Return the first key of *d* whose value equals *val*, or None if there is none."""
    matches = [key for key, value in d.items() if value == val]
    return matches[0] if matches else None


def get_dolphin_model_format(filename: str):
    """Return the format name registered for *filename*.

    Filenames not in llm_models fall back to the default model's format.
    Returns None if that format object is not present in llm_formats.
    """
    lookup_key = filename if filename in llm_models.keys() else default_llm_model_filename
    return get_key_from_value(llm_formats, llm_models[lookup_key][1])


def add_dolphin_models(query: str, format_name: str):
    """Register GGUF model(s) from a Hugging Face repo described by *query*.

    *query* may be ``user/repo`` or a ``huggingface.co`` URL, optionally
    pointing at one ``.gguf`` file. For a bare repo every ``.gguf`` file in it
    is registered; for a file URL only that file is. Each new entry is added
    to the global llm_models as ``filename -> [repo, llm_formats[format_name]]``.

    Returns a Gradio update with the refreshed model choices and a newly
    added model selected. Returns an empty ``gr.update()`` when nothing is
    added: an unparsable query, a missing repo or file, a repo without GGUF
    files, or a Hub API error (printed).

    Raises:
        KeyError: if *format_name* is not in llm_formats.
    """
    global llm_models
    api = HfApi()
    add_models = {}
    chat_format = llm_formats[format_name]
    try:
        matches = re.findall(
            r'^(?:https?://huggingface\.co/)?(.+?/.+?)(?:/.*/(.+?\.gguf).*?)?$',
            query.strip(),
        )
        if not matches:
            return gr.update()
        # Two capture groups: the repo id and a (possibly empty) .gguf filename.
        repo, filename = matches[0]
        if not api.repo_exists(repo_id=repo):
            return gr.update()
        if filename:
            if not api.file_exists(repo_id=repo, filename=filename):
                return gr.update()
            add_models[filename] = [repo, chat_format]
        else:
            for file in api.list_repo_files(repo_id=repo):
                # Key every entry by its own path; keying by the (empty)
                # `filename` collapsed all GGUF files into a single "" entry.
                if str(file).endswith(".gguf"):
                    add_models[str(file)] = [repo, chat_format]
    except Exception as e:
        print(e)
        return gr.update()
    if not add_models:
        # The repo exists but holds no .gguf files: nothing to register.
        return gr.update()
    llm_models = llm_models | add_models
    choices = get_dolphin_models()
    # Select a model that was just added. The last entry of `choices` is
    # usually a local .gguf file listed after the registry entries, not the
    # new model.
    return gr.update(choices=choices, value=list(add_models)[-1])


def get_dolphin_sysprompt(state: "dict | None" = None):
    """Build the system prompt for the session described by *state*.

    Uses the template for state["dolphin_sysprompt_mode"] (the first template
    in dolphin_system_prompt for unknown modes). Every "<LANGUAGE>"
    placeholder is replaced with the output language (llm_languages[0] when
    none is set). *state* defaults to an empty dict, i.e. all defaults.
    """
    if state is None:
        state = {}
    mode = get_state(state, "dolphin_sysprompt_mode")
    language = get_state(state, "dolphin_output_language")
    template = dolphin_system_prompt.get(mode, list(dolphin_system_prompt.values())[0])
    # Literal replacement: re.sub would interpret backslashes / group
    # references inside the language string as escape sequences.
    return template.replace('<LANGUAGE>', language if language else llm_languages[0])


def get_dolphin_sysprompt_mode():
    """Return the names of all available system-prompt modes, in definition order."""
    return [*dolphin_system_prompt]


def select_dolphin_sysprompt(key: str, state: dict):
    """Switch the system-prompt mode stored in *state* to *key*.

    Unknown *key* values fall back to "Default". Returns an update carrying
    the rebuilt system prompt, plus the mutated state.
    """
    # No prior get_state() read: its result was always overwritten, and it
    # could only emit a spurious "not found" log line.
    mode = key if key in dolphin_system_prompt else "Default"
    set_state(state, "dolphin_sysprompt_mode", mode)
    return gr.update(value=get_dolphin_sysprompt(state)), state


def get_dolphin_languages():
    """Return the selectable output languages (the llm_languages list itself, not a copy)."""
    languages = llm_languages
    return languages


def select_dolphin_language(lang: str, state: dict):
    """Store the output language *lang* in *state* and return the rebuilt system prompt update and state."""
    set_state(state, "dolphin_output_language", lang)
    prompt = get_dolphin_sysprompt(state)
    return gr.update(value=prompt), state


@wrapt_timeout_decorator.timeout(dec_timeout=5.0)
def get_raw_prompt(msg: str):
    """Extract the text between /GENBEGIN/ and /GENEND/ markers in *msg*.

    All marked segments (which may span lines) are joined with ", ". The
    characters * / : _ " # and newline are replaced by spaces, and the result
    is lower-cased. Returns "" when no marked segment exists.
    """
    segments = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
    if not segments:
        return ""
    joined = ", ".join(segments)
    return re.sub(r'[*/:_"#\n]', ' ', joined).lower()


@torch.inference_mode()
@spaces.GPU(duration=59)
def dolphin_respond(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: "dict | None" = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat reply from the GGUF file *model* in llm_models_dir.

    Builds a llama.cpp model and a LlamaCppAgent, replays *history* as
    user/assistant turns, sends *message*, and yields ``[(reply_so_far, None)]``
    after every streamed chunk.

    The chat template is state["override_llm_format"] when set, otherwise the
    format registered for *model* in llm_models.

    Raises:
        gr.Error: "Model file not found: ..." when the file is missing, or
            "Error: <exception>" for any other failure.
    """
    # None instead of a shared mutable {} default.
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        progress(0, desc="Processing...")
        override_llm_format = get_state(state, "override_llm_format")
        if override_llm_format:
            chat_template = override_llm_format
        else:
            chat_template = llm_models[model][1]

        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,  # NOTE(review): presumably meant to offload every layer — confirm
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)

        # A MessagesFormatter instance is passed as a custom formatter;
        # anything else is passed as a predefined formatter type.
        is_custom = isinstance(chat_template, MessagesFormatter)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom else chat_template,
            custom_messages_formatter=chat_template if is_custom else None,
            debug_output=False
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        # Assumes each history entry is a (user_text, assistant_text) pair.
        messages = BasicChatHistory()
        for msn in history:
            messages.add_message({'role': Roles.user, 'content': msn[0]})
            messages.add_message({'role': Roles.assistant, 'content': msn[1]})

        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False
        )

        progress(0.5, desc="Processing...")

        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)]
    except gr.Error:
        # Already a user-facing error: re-raise unchanged instead of wrapping
        # it a second time as "Error: ...".
        raise
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        # Release cached CUDA memory and run garbage collection after every call.
        torch.cuda.empty_cache()
        gc.collect()


def dolphin_parse(
    history: list[tuple[str, str]],
    state: dict,
):
    """Turn the last chat message into a comma-separated tag prompt.

    Reads the /GENBEGIN/.../GENEND/ section of history[-1][0] and converts it
    to tags: Japanese text goes through jatags_to_danbooru_tags in
    "Japanese to Danbooru Dictionary" mode, anything else is split on commas.
    The tags are de-duplicated after appending "nsfw" and "explicit".

    Returns (prompt, gr.update(interactive=True), gr.update(interactive=True)).
    In "Chat with LLM" mode, with empty history, or on any error (printed) it
    returns ("", gr.update(), gr.update()).
    """
    try:
        mode = get_state(state, "dolphin_sysprompt_mode")
        if mode == "Chat with LLM" or not history:
            return "", gr.update(), gr.update()
        raw_prompt = get_raw_prompt(history[-1][0])
        use_dictionary = mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt)
        tags = jatags_to_danbooru_tags(to_list_ja(raw_prompt)) if use_dictionary else to_list(raw_prompt)
        prompt = ", ".join(list_uniq(tags + ["nsfw", "explicit"]))
        return prompt, gr.update(interactive=True), gr.update(interactive=True)
    except Exception as err:
        print(err)
        return "", gr.update(), gr.update()


@torch.inference_mode()
@spaces.GPU(duration=59)
def dolphin_respond_auto(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: "dict | None" = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat reply from the GGUF file *model* without raising.

    Builds a llama.cpp model and a LlamaCppAgent, replays *history* as
    user/assistant turns, and sends *message*. Yields
    ``([(reply_so_far, None)], gr.update(), gr.update())`` after every
    streamed chunk.

    The chat template is state["override_llm_format"] when set, otherwise the
    format registered for *model* in llm_models.

    Any failure, including a missing model file, is printed and reported by
    yielding ``([("", None)], gr.update(), gr.update())``.
    """
    # None instead of a shared mutable {} default.
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        # Same explicit check as the other respond handlers, so the printed
        # error is clear. It is still caught below and becomes an empty reply.
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        progress(0, desc="Processing...")

        override_llm_format = get_state(state, "override_llm_format")
        if override_llm_format:
            chat_template = override_llm_format
        else:
            chat_template = llm_models[model][1]

        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,  # NOTE(review): presumably meant to offload every layer — confirm
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)

        # A MessagesFormatter instance is passed as a custom formatter;
        # anything else is passed as a predefined formatter type.
        is_custom = isinstance(chat_template, MessagesFormatter)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom else chat_template,
            custom_messages_formatter=chat_template if is_custom else None,
            debug_output=False
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        # Assumes each history entry is a (user_text, assistant_text) pair.
        messages = BasicChatHistory()
        for msn in history:
            messages.add_message({'role': Roles.user, 'content': msn[0]})
            messages.add_message({'role': Roles.assistant, 'content': msn[1]})

        progress(0, desc="Translating...")
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False
        )

        progress(0.5, desc="Processing...")

        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)], gr.update(), gr.update()
    except Exception as e:
        # Deliberate best effort: report and yield an empty reply instead of raising.
        print(e)
        yield [("", None)], gr.update(), gr.update()
    finally:
        # Release cached CUDA memory and run garbage collection after every call.
        torch.cuda.empty_cache()
        gc.collect()


def dolphin_parse_simple(
    message: str,
    history: list[tuple[str, str]],
    state: dict,
):
    """Return a comma-separated tag prompt built from the last chat message.

    Returns *message* unchanged in "Chat with LLM" mode or when *history* is
    empty. Otherwise it reads the /GENBEGIN/.../GENEND/ section of
    history[-1][0] and converts it to tags: Japanese text goes through
    jatags_to_danbooru_tags in "Japanese to Danbooru Dictionary" mode,
    anything else is split on commas. The tags are de-duplicated after
    appending "nsfw", "explicit" and "rating_explicit". Returns "" on any
    error (printed).
    """
    try:
        mode = get_state(state, "dolphin_sysprompt_mode")
        if mode == "Chat with LLM" or not history:
            return message
        raw_prompt = get_raw_prompt(history[-1][0])
        if mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
            tags = jatags_to_danbooru_tags(to_list_ja(raw_prompt))
        else:
            tags = to_list(raw_prompt)
        return ", ".join(list_uniq(tags + ["nsfw", "explicit", "rating_explicit"]))
    except Exception as err:
        print(err)
        return ""


# Reference: https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
# (presumably the origin of respond_playground below — confirm).
import cv2
# Limit OpenCV to a single worker thread. NOTE(review): cv2 is not otherwise
# used in this section; the reason for this setting is not visible here.
cv2.setNumThreads(1)


@torch.inference_mode()
@spaces.GPU(duration=59)
def respond_playground(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: "dict | None" = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat reply (as a plain string) from the GGUF file *model*.

    Builds a llama.cpp model and a LlamaCppAgent, replays *history* as
    user/assistant turns, sends *message*, and yields the accumulated reply
    string after every streamed chunk.

    The chat template is state["override_llm_format"] when set, otherwise the
    format registered for *model* in llm_models.

    Raises:
        gr.Error: "Model file not found: ..." when the file is missing, or
            "Error: <exception>" for any other failure.
    """
    # None instead of a shared mutable {} default.
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        override_llm_format = get_state(state, "override_llm_format")
        if override_llm_format:
            chat_template = override_llm_format
        else:
            chat_template = llm_models[model][1]

        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,  # NOTE(review): presumably meant to offload every layer — confirm
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)

        # A MessagesFormatter instance is passed as a custom formatter;
        # anything else is passed as a predefined formatter type.
        is_custom = isinstance(chat_template, MessagesFormatter)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom else chat_template,
            custom_messages_formatter=chat_template if is_custom else None,
            debug_output=False
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        # Add user and assistant messages to the history.
        # Assumes each history entry is a (user_text, assistant_text) pair.
        messages = BasicChatHistory()
        for msn in history:
            messages.add_message({'role': Roles.user, 'content': msn[0]})
            messages.add_message({'role': Roles.assistant, 'content': msn[1]})

        # Stream the response.
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False
        )

        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    except gr.Error:
        # Already a user-facing error: re-raise unchanged instead of wrapping
        # it a second time as "Error: ...".
        raise
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        # Release cached CUDA memory and run garbage collection after every call.
        torch.cuda.empty_cache()
        gc.collect()