# Phi2_Qlora/app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# Load the model and tokenizer
model_path = "./phi2-qlora-final"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)
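
# Optional sketch: load the weights 4-bit-quantized (NF4), as is typical for a
# QLoRA setup. This assumes bitsandbytes is installed and a CUDA GPU is
# available; the USE_4BIT environment flag is a hypothetical opt-in and, when
# set, rebinds `model` from the float16 load above.
if os.environ.get("USE_4BIT") == "1":
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )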
# Custom CSS for better styling
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
}
.title {
    text-align: center;
    color: #2c3e50;
    margin-bottom: 20px;
}
.description {
    text-align: center;
    color: #7f8c8d;
    margin-bottom: 30px;
}
.loading {
    display: flex;
    justify-content: center;
    align-items: center;
    height: 100px;
}
.error {
    color: #e74c3c;
    padding: 10px;
    border-radius: 5px;
    background-color: #fde8e8;
    margin: 10px 0;
}
"""
def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a response from the fine-tuned model."""
    try:
        if not prompt.strip():
            return "Please enter a prompt."
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,  # cap on generated tokens, excluding the prompt
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
        )
        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        return f"Error generating response: {e}"
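
# Optional sketch: a streaming variant built on transformers' TextIteratorStreamer.
# It is not wired into the UI below; generate() runs in a background thread and
# partial text is yielded as tokens arrive.
def generate_response_stream(prompt, max_length=512, temperature=0.7, top_p=0.9, top_k=50):
    """Yield the response incrementally while tokens are being generated."""
    from threading import Thread
    from transformers import TextIteratorStreamer

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_length,
        temperature=temperature,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text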
def clear_all():
    """Clear all inputs and outputs and reset the sliders to their defaults."""
    return "", "", 512, 0.7, 0.9, 50
# Example prompts
example_prompts = [
    "What is the capital of France?",
    "Explain quantum computing in simple terms.",
    "Write a short story about a robot learning to paint.",
    "What are the benefits of meditation?",
    "How does photosynthesis work?",
]
# Create the Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
    gr.Markdown(
        """
        # πŸ€– Phi-2 QLoRA Chat Interface
        Chat with a Phi-2 model fine-tuned with QLoRA. Adjust the parameters below to control generation.
        """,
        elem_classes="title"
    )
    gr.Markdown(
        """
        This interface lets you interact with a fine-tuned Phi-2 model. Use the sliders to control the generation process.
        """,
        elem_classes="description"
    )
    with gr.Row():
        with gr.Column(scale=2):
            # Input section
            with gr.Group():
                gr.Markdown("### πŸ’­ Input")
                prompt = gr.Textbox(
                    label="Enter your prompt:",
                    placeholder="Type your message here...",
                    lines=3,
                    show_label=True,
                    container=True
                )
                with gr.Row():
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Max Length",
                        info="Maximum length of generated response"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make output more random"
                    )
                with gr.Row():
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top P",
                        info="Nucleus sampling parameter"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K",
                        info="Top-k sampling parameter"
                    )
            # Buttons
            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All", variant="secondary")
        with gr.Column(scale=2):
            # Output section
            with gr.Group():
                gr.Markdown("### πŸ€– Response")
                output = gr.Textbox(
                    label="Model Response:",
                    lines=5,
                    show_label=True,
                    container=True
                )
    # Examples section
    with gr.Group():
        gr.Markdown("### πŸ“ Example Prompts")
        gr.Examples(
            examples=example_prompts,
            inputs=prompt,
            outputs=output,
            fn=generate_response,
            cache_examples=True
        )
    # Footer
    gr.Markdown(
        """
        ---
        Made with ❀️ using Phi-2 and QLoRA
        """,
        elem_classes="footer"
    )
    # Event handlers
    submit_btn.click(
        fn=generate_response,
        inputs=[prompt, max_length, temperature, top_p, top_k],
        outputs=output
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[prompt, output, max_length, temperature, top_p, top_k]
    )
if __name__ == "__main__":
    iface.launch(
        share=True,             # Enable sharing
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,       # Default Gradio port
        show_error=True         # Show detailed error messages
    )
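
# Usage: run `python app.py` and open http://localhost:7860; with share=True,
# Gradio also prints a temporary public link in the console.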