import gradio as gr
import huggingface_hub
import os
import spaces
import torch

from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
from datasets import load_dataset

# Authenticate against the Hugging Face Hub only when a token is configured.
# login(None) falls back to an interactive prompt, which cannot be answered
# in a headless deployment, so skip it and rely on anonymous access instead.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    huggingface_hub.login(_hf_token)
else:
    print("HF_TOKEN is not set; skipping Hugging Face login.")

# Hub repository holding the PEFT adapter (and its tokenizer) to serve.
peft_model_id = "debisoft/DeepSeek-R1-Distill-Qwen-7B-thinking-function_calling-quant-V0"
# Alternative adapters:
#peft_model_id = "debisoft/Qwen2.5-VL-7B-Instruct-thinking-function_calling-quant-V0"
#peft_model_id = "debisoft/Qwen2.5-VL-3B-Instruct-thinking-function_calling-V0"

# Quantization settings for the base model: 4-bit NF4 weights with double
# quantization enabled and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Device handles shared by the rest of the module.
device = "auto"
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda")

# The adapter's PEFT config names the base checkpoint it has to be applied to.
config = PeftConfig.from_pretrained(peft_model_id)

# The tokenizer is taken from the adapter repository rather than the base model.
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Load the quantized base model. For the Qwen2.5-VL adapters, use
# Qwen2_5_VLForConditionalGeneration.from_pretrained here instead.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    device_map="auto",
    quantization_config=bnb_config,
)

# Size the embedding table to the tokenizer's vocabulary
# (presumably the adapter's tokenizer adds tokens; confirm against the repo).
model.resize_token_embeddings(len(tokenizer))

#tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
#model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")


# Holds the adapter-wrapped model once it has been built, so that repeated
# requests do not reload the adapter and re-wrap the shared base model.
_PEFT_MODEL_CACHE = {}


def _get_peft_model():
    """Return the adapter-wrapped model in eval mode, building it on first use."""
    if "model" not in _PEFT_MODEL_CACHE:
        wrapped = PeftModel.from_pretrained(model, peft_model_id, device_map="cuda")
        wrapped.eval()
        _PEFT_MODEL_CACHE["model"] = wrapped
    return _PEFT_MODEL_CACHE["model"]


@spaces.GPU
def get_completion(msg):
    """Generate a completion for ``msg`` with the PEFT-adapted model.

    The adapter is attached lazily inside this GPU-decorated function (as the
    original code did) and then reused on later calls instead of being
    reloaded for every request.

    Args:
        msg: Prompt text to tokenize and feed to the model.

    Returns:
        The first generated sequence decoded to a string with special tokens
        removed. NOTE(review): for causal LMs ``generate`` normally returns
        the prompt tokens followed by the continuation, so the prompt is
        likely echoed at the start of the result -- confirm this is intended.
    """
    peft_model = _get_peft_model()

    inputs = tokenizer(msg, return_tensors="pt").to(cuda_device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = peft_model.generate(
            **inputs,
            max_new_tokens=512,
            # Reuse EOS as the padding token for generation.
            pad_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def greet(input):
    """Gradio callback: log the submitted text and return the model's completion.

    Args:
        input: Value submitted from the UI; it is formatted into the prompt string.

    Returns:
        Whatever ``get_completion`` returns for that prompt.
    """
    total_prompt = f"{input}"

    # Echo the prompt to stdout for debugging.
    print("***total_prompt:", total_prompt, sep="\n")

    return get_completion(total_prompt)

# Single-textbox UI: the entered text goes to greet() and the returned
# completion is shown in a text output.
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Textbox(label="Elevator pitcher", lines=1)],
    outputs=gr.Text(),
)
demo.launch()