import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import gradio as gr
import os

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model and adapter paths
base_model_name = "microsoft/phi-2"  # Pull from HF Hub directly
adapter_path = "Shriti09/Microsoft-Phi-QLora"  # Update with your Hugging Face repo path
print("π§ Loading base model...") | |
# Using the Accelerator to load the model and dispatch to the correct devices | |
base_model = AutoModelForCausalLM.from_pretrained( | |
base_model_name, | |
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32 | |
) | |
print("π§ Loading LoRA adapter...") | |
adapter_model = PeftModel.from_pretrained(base_model, adapter_path) | |
print("π Merging adapter into base model...") | |
merged_model = adapter_model.merge_and_unload() | |
merged_model.eval() | |
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
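# Optional safeguard (an assumption worth checking): Phi-2's tokenizer typically
# ships without a pad token, so fall back to the EOS token to avoid padding
# warnings if inputs are ever batched.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token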
print("β Model ready for inference!") | |
# Chat function with history
def chat_fn(message, history):
    # history holds the conversation as a list of {"role", "content"} dicts,
    # matching gr.Chatbot(type="messages"); rebuild the prompt from it
    full_prompt = ""
    for turn in history:
        speaker = "User" if turn["role"] == "user" else "AI"
        full_prompt += f"{speaker}: {turn['content']}\n"
    full_prompt += f"User: {message}\nAI:"

    # Tokenize inputs
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = merged_model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode and keep only the AI's latest response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("AI:")[-1].strip()

    # Append the new turn to history in the gr.Chatbot "messages" format
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})

    return history, history
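# Usage sketch (hypothetical, outside the UI): calling chat_fn directly with an
# empty history returns the updated "messages"-style list twice, e.g.
#   updated, _ = chat_fn("What is QLoRA?", [])
#   # -> [{"role": "user", "content": "What is QLoRA?"},
#   #     {"role": "assistant", "content": "..."}]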
# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1>Phi-2 QLoRA Chatbot</h1>")

    # type="messages" makes the chatbot expect role/content dictionaries
    chatbot = gr.Chatbot(type="messages")
    message = gr.Textbox(label="Your message:")
    clear = gr.Button("Clear chat")
    state = gr.State([])

    message.submit(chat_fn, [message, state], [chatbot, state])
    clear.click(lambda: [], None, chatbot)
    clear.click(lambda: [], None, state)
# Run the app without the 'concurrency_count' argument and share it publicly
demo.queue().launch(share=True)