import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model and tokenizer
model_path = "./phi2-qlora-final"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cpu",           # Force CPU usage
    torch_dtype=torch.float32,  # Use float32 for CPU
    trust_remote_code=True
)

# Custom CSS for better styling
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
}
.title {
    text-align: center;
    color: #2c3e50;
    margin-bottom: 20px;
}
.description {
    text-align: center;
    color: #7f8c8d;
    margin-bottom: 30px;
}
.loading {
    display: flex;
    justify-content: center;
    align-items: center;
    height: 100px;
}
.error {
    color: #e74c3c;
    padding: 10px;
    border-radius: 5px;
    background-color: #fde8e8;
    margin: 10px 0;
}
"""


def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a response from the model for the given prompt."""
    try:
        if not prompt.strip():
            return "Please enter a prompt."

        inputs = tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():  # Disable gradient computation for inference
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_p=top_p,
                top_k=top_k,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response: {str(e)}"


def clear_all():
    """Clear all inputs and outputs and restore the slider defaults."""
    return "", "", 256, 0.7, 0.9, 50


# Example prompts
example_prompts = [
    "What is the capital of France?",
    "Explain quantum computing in simple terms.",
    "Write a short story about a robot learning to paint.",
    "What are the benefits of meditation?",
    "How does photosynthesis work?",
]

# Create the Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
    gr.Markdown(
        """
        # 🤖 Phi-2 QLoRA Chat Interface (CPU Version)
        Chat with the fine-tuned Phi-2 model using QLoRA. This version runs on CPU for better compatibility.
        """,
        elem_classes="title"
    )

    gr.Markdown(
        """
        This interface allows you to interact with a fine-tuned Phi-2 model.
        Note that responses may be slower due to CPU-only inference.
        """,
        elem_classes="description"
    )

    with gr.Row():
        with gr.Column(scale=2):
            # Input section
            with gr.Group():
                gr.Markdown("### 💭 Input")
                prompt = gr.Textbox(
                    label="Enter your prompt:",
                    placeholder="Type your message here...",
                    lines=3,
                    show_label=True,
                    container=True
                )

                with gr.Row():
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=512,  # Reduced max length for CPU
                        value=256,    # Reduced default length
                        step=64,
                        label="Max Length",
                        info="Maximum length of generated response"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make output more random"
                    )

                with gr.Row():
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top P",
                        info="Nucleus sampling parameter"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K",
                        info="Top-k sampling parameter"
                    )

            # Buttons
            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All", variant="secondary")

        with gr.Column(scale=2):
            # Output section
            with gr.Group():
                gr.Markdown("### 🤖 Response")
                output = gr.Textbox(
                    label="Model Response:",
                    lines=5,
                    show_label=True,
                    container=True
                )

    # Examples section
    with gr.Group():
        gr.Markdown("### 📝 Example Prompts")
        gr.Examples(
            examples=example_prompts,
            inputs=prompt,
            outputs=output,
            fn=generate_response,
            cache_examples=True  # Runs the model once per example at startup
        )

    # Footer
    gr.Markdown(
        """
        ---
        Made with ❤️ using Phi-2 and QLoRA (CPU Version)
        """,
        elem_classes="footer"
    )

    # Event handlers
    submit_btn.click(
        fn=generate_response,
        inputs=[prompt, max_length, temperature, top_p, top_k],
        outputs=output
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[prompt, output, max_length, temperature, top_p, top_k]
    )

if __name__ == "__main__":
    iface.launch(
        share=True,             # Enable sharing via a public link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,       # Default Gradio port
        show_error=True         # Show detailed error messages
    )