import spaces
import gradio as gr
from transformers import pipeline, TextIteratorStreamer
import torch
from threading import Thread
import os

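# ZeroGPU Space: functions decorated with @spaces.GPU are allocated a GPU only for
# the duration of the call. The HF access token is read from the Space secret named
# "token" so unreleased/private models can be loaded.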
@spaces.GPU
def load_model(model_name):
    return pipeline("text-generation", model=model_name, device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True, token=os.environ["token"])
@spaces.GPU()
def generate(
    model_name,
    user_input,
    temperature=0.4,
    top_p=0.95,
    min_p=0.1,
    top_k=50,
    max_new_tokens=256,
):
    pipe = load_model(model_name)
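    # Note: the pipeline is rebuilt for every request; nothing is cached between calls.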

    # Only apply the ChatML prompt template to models that expect it;
    # tau-1.8B takes the raw prompt as-is.
    if model_name == "M4-ai/tau-1.8B":
        prompt = user_input
    else:
        prompt = f"<|im_start|>user\n{user_input}<|im_end|>\n<|im_start|>assistant\n"
    streamer = TextIteratorStreamer(
        pipe.tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        text_inputs=prompt,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        min_p=min_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=1.1,
    )
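    # Run generation in a background thread so this function can consume the
    # streamer and yield partial output while tokens are still being produced.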
    t = Thread(target=pipe.__call__, kwargs=generation_kwargs)
    t.start()
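    # Accumulate streamed chunks and yield the running text so the Gradio output
    # box updates incrementally.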
    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)

model_choices = ["Locutusque/Apollo-0.4-InternLM-2.5-7B", "Locutusque/Apollo-0.4-Llama-3.1-8B", "Locutusque/Llama-3-NeuralHermes-Pro-8B", "Locutusque/Hercules-5.0-Qwen2-7B", "Locutusque/Llama-3-NeuralHercules-5.0-8B", "Locutusque/Hercules-5.0-Index-1.9B", "Locutusque/Llama-3-Hercules-5.0-8B"]
# What are the best default options?
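# The input components below are passed to generate() positionally, so the
# slider values here override the function's keyword defaults.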
g = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(choices=model_choices, label="Model", value=model_choices[0], interactive=True),
        gr.Textbox(lines=2, label="Prompt", value="Write me a Python program that calculates the factorial of a given number."),
        gr.Slider(minimum=0, maximum=1, value=0.8, label="Temperature"),
        gr.Slider(minimum=0, maximum=1, value=0.95, label="Top p"),
        gr.Slider(minimum=0, maximum=1, value=0.1, label="Min P"),
        gr.Slider(minimum=0, maximum=100, step=1, value=15, label="Top k"),
        gr.Slider(minimum=1, maximum=2048, step=1, value=1024, label="Max tokens"),
    ],
    outputs=[gr.Textbox(lines=10, label="Output")],
    title="Locutusque's Language Models",
    description="Try out Locutusque's language models here! Credit goes to Mediocreatmybest for this space. You may also find some experimental preview models that have not been made public here.",
    concurrency_limit=1
)

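# max_threads caps Gradio's worker thread pool; generation itself is serialized
# by concurrency_limit=1 above.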
g.launch(max_threads=4)