from openai import OpenAI
import gradio as gr
import os
import json
import html
import random
import datetime

# Featherless exposes an OpenAI-compatible endpoint, so the stock OpenAI client
# is reused with a custom base URL. Without a key the app cannot serve anything,
# so fail fast at import time (an empty value counts as missing).
api_key = os.getenv('FEATHERLESS_API_KEY')
if not api_key:
    raise RuntimeError("Cannot start without required API key. Please register for one at https://featherless.ai")

client = OpenAI(api_key=api_key, base_url="https://api.featherless.ai/v1")

# model-cache.json maps each model class to the list of model ids in that class.
# JSON is UTF-8 by spec, so decode it explicitly rather than with the platform's
# locale encoding.
with open('./model-cache.json', 'r', encoding='utf-8') as f_model_cache:
    model_cache = json.load(f_model_cache)

# Inverted lookup: model_id -> model_class. If an id were listed under several
# classes, the class iterated last wins.
model_class_from_model_id = {
    model_id: model_class
    for model_class, model_ids in model_cache.items()
    for model_id in model_ids
}

# Which model classes from model-cache.json are offered in this Space.
# The first group is enabled (True); the second, larger classes are disabled
# (False). Classes absent from this table are treated as disabled by
# build_model_choices().
model_class_filter = {
    **dict.fromkeys(
        [
            "mistral-v02-7b-std-lc",
            "llama3-8b-8k",
            "llama31-8b-16k",
            "llama2-solar-10b7-4k",
            "mistral-nemo-12b-lc",
            "llama2-13b-4k",
            "llama3-15b-8k",
        ],
        True,
    ),
    **dict.fromkeys(
        [
            "qwen2-32b-lc",
            "llama3-70b-8k",
            "llama31-70b-16k",
            "qwen2-72b-lc",
            "mixtral-8x22b-lc",
            "llama3-405b-lc",
        ],
        False,
    ),
}

# Additional models served here besides the classes enabled above.
REFLECTION = "mattshumer/Reflection-Llama-3.1-70B"
QWEN25_72B = "Qwen/Qwen2.5-72B"
NEMOTRON = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"

# Always offered in the dropdown, regardless of model_class_filter.
bigger_whitelisted_models = [QWEN25_72B, NEMOTRON]

# Register model classes by hand for models that need it. REFLECTION is in
# backup hosting; presumably it and NEMOTRON are not (reliably) listed in
# model-cache.json — verify before removing.
model_class_from_model_id.update({
    REFLECTION: 'llama31-70b-16k',
    NEMOTRON: 'llama31-70b-16k',
})
def build_model_choices():
    """Build the ``(label, model_id)`` pairs offered in the model dropdown.

    Choices come from two sources:

    * every model in ``model_cache`` whose class is enabled in
      ``model_class_filter`` (classes missing from the filter are warned
      about and skipped);
    * the models in ``bigger_whitelisted_models``, which are added
      regardless of the filter.

    Both sources use the same ``"<model_id> (<model_class>)"`` label format,
    and a model that appears in both is listed only once.

    Returns:
        list[tuple[str, str]]: Dropdown choices, in display order.
    """
    all_choices = []
    for model_class, model_ids in model_cache.items():
        if model_class not in model_class_filter:
            print(f"Warning: new model class {model_class}. Treating as blacklisted")
            continue

        if not model_class_filter[model_class]:
            continue
        all_choices += [(f"{model_id} ({model_class})", model_id) for model_id in model_ids]

    already_listed = {model_id for _, model_id in all_choices}
    for model_id in bigger_whitelisted_models:
        if model_id in already_listed:
            continue
        # A whitelisted model may be absent from model-cache.json and never
        # registered by hand; label it instead of crashing at import time.
        model_class = model_class_from_model_id.get(model_id)
        if model_class is None:
            print(f"Warning: whitelisted model {model_id} has no known model class")
            model_class = "unknown class"
        all_choices.append((f"{model_id} ({model_class})", model_id))
        already_listed.add(model_id)

    return all_choices


# Computed once at import; used for the dropdown and for request validation.
model_choices = build_model_choices()
def model_in_list(model, choices=None):
    """Return True if ``model`` is the value of one of the dropdown choices.

    Args:
        model: Model id to look up (matched against choice values, not labels).
        choices: ``(label, model_id)`` pairs to search. Defaults to the
            module-level ``model_choices``.

    Returns:
        bool: Whether ``model`` is an offered model id.
    """
    if choices is None:
        choices = model_choices
    return any(choice_id == model for _label, choice_id in choices)

# Default model shown in the dropdown. It is currently pinned to NEMOTRON.
# Set PINNED_INITIAL_MODEL to None to use a pseudo-random model instead. That
# model is stable for a given day and seeded by RANDOM_SEED plus today's date.
PINNED_INITIAL_MODEL = NEMOTRON
if PINNED_INITIAL_MODEL is not None:
    initial_model = PINNED_INITIAL_MODEL
else:
    key = os.environ.get('RANDOM_SEED', 'kcOtfNHA+e')
    o = random.Random(f"{key}-{datetime.date.today().strftime('%Y-%m-%d')}")
    initial_model = o.choice(model_choices)[1]
# Disabled: choosing the initial model from the Referer header doesn't work in HF Spaces, because the app is served inside an iframe.
# def initial_model(referer=None):
#     return REFLECTION

#     if referer == 'http://127.0.0.1:7860/':
#         return 'Sao10K/Venomia-1.1-m7'

#     if referer and referer.startswith("https://huggingface.co/"):
#         possible_model = referer[23:]
#         full_model_list = functools.reduce(lambda x,y: x+y, model_cache.values(), [])
#         model_is_supported = possible_model in full_model_list
#         if model_is_supported:
#             return possible_model

#     # let's use a random but different model each day.
#     key=os.environ.get('RANDOM_SEED', 'kcOtfNHA+e')
#     o = random.Random(f"{key}-{datetime.date.today().strftime('%Y-%m-%d')}")
#     return o.choice(model_choices)[1]


# System prompt prepended to conversations with the REFLECTION model.
REFLECTION_SYSTEM_PROMPT = (
    "You are a world-class AI system, capable of complex reasoning and reflection. "
    "Reason through the query inside <thinking> tags, and then provide your final "
    "response inside <output> tags. If you detect that you made a mistake in your "
    "reasoning at any point, correct yourself inside <reflection> tags."
)

def _build_messages(message, history, model):
    """Convert ChatInterface history plus the new message into OpenAI chat messages.

    Args:
        message: The user's latest message.
        history: Earlier turns as ``(user, assistant)`` pairs. The assistant
            turns are assumed to be the HTML-escaped text previously yielded
            by ``respond`` (Gradio echoes the chatbot contents back).
        model: Selected model id; REFLECTION also gets its system prompt.

    Returns:
        list[dict]: Messages in OpenAI ``chat.completions`` format.
    """
    messages = []
    if model == REFLECTION:
        messages.append({"role": "system", "content": REFLECTION_SYSTEM_PROMPT})
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        # respond() HTML-escapes what it streams, so undo that here; otherwise
        # the model would see entities such as "&lt;thinking&gt;" in its own
        # earlier turns. Non-string values are passed through unchanged.
        if isinstance(assistant, str):
            assistant = html.unescape(assistant)
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages


def _attribution_headers(request):
    """Build the attribution headers sent to Featherless with each completion.

    ``HTTP-Referer`` is only included when the incoming request actually has a
    referer. A ``None`` header value is presumably rejected by the underlying
    HTTP client, since httpx requires str/bytes header values.
    """
    headers = {'X-Title': "HF's missing inference widget"}
    referer = request.headers.get('referer') if request is not None else None
    if referer:
        headers['HTTP-Referer'] = referer
    return headers


def respond(message, history, model, request: gr.Request):
    """Stream a reply from ``model`` for ``gr.ChatInterface``.

    Args:
        message: The user's latest message.
        history: Earlier ``(user, assistant)`` turns supplied by ChatInterface.
        model: Model id selected in the dropdown (an additional input).
        request: The incoming Gradio request; its referer, if any, is
            forwarded for attribution.

    Yields:
        str: The HTML-escaped reply accumulated so far.

    Raises:
        RuntimeError: If ``model`` is not one of the offered choices.
    """
    # Only serve models this Space offers. The value is not guaranteed to come
    # from the dropdown, e.g. when the app is called through its API.
    if not model_in_list(model):
        raise RuntimeError(f"{model} is not supported in this hf space. Visit https://featherless.ai to see and use the complete model catalogue")

    response = client.chat.completions.create(
        model=model,
        messages=_build_messages(message, history, model),
        temperature=1.0,
        stream=True,
        max_tokens=2000,
        extra_headers=_attribution_headers(request),
    )

    partial_message = ""
    for chunk in response:
        # Some streamed chunks may carry no choices (e.g. usage-only chunks).
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content is None:
            continue
        # Escape model output before display so HTML-like tags (e.g. <thinking>)
        # are not rendered as markup. html.escape works per character, so
        # escaping chunk by chunk equals escaping the whole reply.
        partial_message += html.escape(content)
        yield partial_message

def _read_asset(path):
    """Return the text of a bundled asset file, closing the handle immediately."""
    with open(path, encoding='utf-8') as asset_file:
        return asset_file.read()


# Inline SVG markup, embedded directly into the page HTML.
logo = _read_asset('./logo.svg')
logo_small = _read_asset('./logo-small.svg')
title_text="HuggingFace's missing inference widget"
css = """
.logo-mark { fill: #ffe184; }

/* from https://github.com/gradio-app/gradio/issues/4001
 * necessary as putting ChatInterface in gr.Blocks changes behaviour
 */

 .row {
    display: flex;
    justify-content: center;
 }

 .footer p {
    width: 450px;
 }

.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto;}
"""

# Page layout: header with logo, model picker + model-card link, chat, footer.
# `title` must be passed by keyword: the first positional parameter of
# gr.Blocks is `theme`, not `title`.
with gr.Blocks(title=title_text, css=css) as demo:
    gr.HTML(f"""
        <div class="header">
            <h1 class="row">HuggingFace's missing inference widget</h1>
            <h3 class="row">powered by</h3>
            <div class="row">
                <a href="https://featherless.ai">
                {logo}
                </a>
            </div>
        </div>
    """)

    # hidden_state = gr.State(value=initial_model)
    with gr.Row():
        model_selector = gr.Dropdown(
            label="Select your Model",
            # Reuse the list that respond() validates against (via model_in_list)
            # rather than rebuilding it, which would also repeat any warnings.
            choices=model_choices,
            value=initial_model,
            # value=hidden_state,
            scale=4
        )
        # Client-side only (fn=None): open the selected model's Featherless page.
        gr.Button(
            value="Visit Model Card ↗️",
            scale=1
        ).click(
            inputs=[model_selector],
            js="(model_selection) => { window.open(`https://featherless.ai/models/${model_selection}/readme`, '_blank') }",
            fn=None,
        )

    # The selected model is passed to respond() as its third argument.
    gr.ChatInterface(
        respond,
        additional_inputs=[model_selector],
        head="""
        <script>console.log("Hello from gradio!")</script>
        """,
        concurrency_limit=5
    )

    # logo_small_no_text = open('./logo-small-no-text.svg').read()
    # x_logo = open('./x-logo.svg').read()
    # discord_logo = open('./discord-logo.svg').read()

    gr.HTML("""
        <div class="footer">
            <div class="row">
                If you enjoyed this space,
                check out&nbsp;<a href="https://featherless.ai">featherless.ai</a>,
                and follow us&nbsp;<a href="https://x.com/FeatherlessAI">on twitter</a>!
            </div>
            <!-- <div class="row">If you enjoyed this space,</div>
            <div class="row">check out&nbsp;<a href="https://featherless.ai">featherless.ai</a>,</div>
            <div class="row">and follow us&nbsp;<a href="https://x.com/FeatherlessAI">on twitter</a>!</div> -->
        </div>
    """)
    # def update_initial_model_choice(request: gr.Request):
    #     return initial_model(request.headers.get('referer'))

    # demo.load(update_initial_model_choice, outputs=model_selector)

demo.launch()