"""Gradio chat demo: microsoft/phi-2 merged with a QLoRA (LoRA) adapter."""

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model and adapter paths
base_model_name = "microsoft/phi-2"            # Pull from the HF Hub directly
adapter_path = "Shriti09/Microsoft-Phi-QLora"  # Update with your Hugging Face repo path

print("🔧 Loading base model...")
# Load the base model in bfloat16 on GPU, float32 on CPU
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

print("🔧 Loading LoRA adapter...")
adapter_model = PeftModel.from_pretrained(base_model, adapter_path)

print("🔗 Merging adapter into base model...")
merged_model = adapter_model.merge_and_unload()
merged_model.to(device)
merged_model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

print("✅ Model ready for inference!")


# Chat function with history
def chat_fn(message, history):
    # History is kept in the gr.Chatbot "messages" format:
    # a list of {"role": ..., "content": ...} dictionaries.
    full_prompt = ""
    for msg in history:
        speaker = "User" if msg["role"] == "user" else "AI"
        full_prompt += f"{speaker}: {msg['content']}\n"
    full_prompt += f"User: {message}\nAI:"

    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = merged_model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode and keep only the AI's latest response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("AI:")[-1].strip()

    # Append the new turn to the history in the messages format
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return history, history


# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Phi-2 QLoRA Chatbot")

    # 'messages' type expects the role/content dictionary format used above
    chatbot = gr.Chatbot(type="messages")
    message = gr.Textbox(label="Your message:")
    clear = gr.Button("Clear chat")
    state = gr.State([])

    message.submit(chat_fn, [message, state], [chatbot, state])
    clear.click(lambda: [], None, chatbot)
    clear.click(lambda: [], None, state)

# Run the app (no 'concurrency_count' argument) and share it publicly
demo.queue().launch(share=True)