# Captured from a Hugging Face Spaces page (Space status at capture: Sleeping).
import os
import random
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Load the fine-tuned model and its tokenizer once, at import time, so every
# request served by the UI reuses the same in-memory weights.
model_path = "./phi2-qlora-final"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
    # NOTE(review): this executes Python code shipped with the checkpoint;
    # keep it only while model_path points at a trusted local directory.
    trust_remote_code=True,
)
# Stylesheet handed to gr.Blocks(css=...). The class names below are the ones
# referenced through `elem_classes` in the layout, plus a few generic helpers.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
}
.title {
    text-align: center;
    color: #2c3e50;
    margin-bottom: 20px;
}
.description {
    text-align: center;
    color: #7f8c8d;
    margin-bottom: 30px;
}
.loading {
    display: flex;
    justify-content: center;
    align-items: center;
    height: 100px;
}
.error {
    color: #e74c3c;
    padding: 10px;
    border-radius: 5px;
    background-color: #fde8e8;
    margin: 10px 0;
}
"""
def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a sampled completion for ``prompt`` with the loaded model.

    Args:
        prompt: Text to send to the model. ``None``, non-string, or
            whitespace-only input is rejected before the model is touched.
        max_length: Maximum number of tokens to generate, not counting the
            prompt (this is what the "Max Length" slider describes).
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        The decoded first output sequence, or a human-readable message when
        the prompt is empty or generation raises.
    """
    if not isinstance(prompt, str) or not prompt.strip():
        return "Please enter a prompt."

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            # Budget only the new tokens: `max_length` would also count the
            # prompt, so a long prompt could leave no room to generate.
            max_new_tokens=int(max_length),
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_p=top_p,
            # Slider values may arrive as floats; this argument must be an int.
            top_k=int(top_k),
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of crashing.
        return f"Error generating response: {str(e)}"
def clear_all():
    """Return reset values for the prompt, the output, and the four sliders.

    The tuple order is (prompt, output, max_length, temperature, top_p,
    top_k); the numeric values mirror the sliders' initial settings.
    """
    return "", "", 512, 0.7, 0.9, 50
# Sample prompts offered through gr.Examples; each entry fills the single
# prompt textbox.
example_prompts = [
    "What is the capital of France?",
    "Explain quantum computing in simple terms.",
    "Write a short story about a robot learning to paint.",
    "What are the benefits of meditation?",
    "How does photosynthesis work?",
]
# Build the Gradio interface: an input column (prompt, sampling sliders,
# buttons) next to an output column (response box, example prompts), followed
# by a footer. Event handlers are registered inside the Blocks context.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
    gr.Markdown(
        """
        # 🤖 Phi-2 QLoRA Chat Interface
        Chat with the fine-tuned Phi-2 model using QLoRA. Adjust the parameters below to control the generation.
        """,
        elem_classes="title",
    )
    gr.Markdown(
        """
        This interface allows you to interact with a fine-tuned Phi-2 model. You can adjust various parameters to control the generation process.
        """,
        elem_classes="description",
    )

    with gr.Row():
        with gr.Column(scale=2):
            # Input section
            with gr.Group():
                gr.Markdown("### 📝 Input")
                prompt = gr.Textbox(
                    label="Enter your prompt:",
                    placeholder="Type your message here...",
                    lines=3,
                    show_label=True,
                    container=True,
                )
                with gr.Row():
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Max Length",
                        info="Maximum length of generated response",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make output more random",
                    )
                with gr.Row():
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top P",
                        info="Nucleus sampling parameter",
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K",
                        info="Top-k sampling parameter",
                    )

            # Buttons
            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All", variant="secondary")

        with gr.Column(scale=2):
            # Output section
            with gr.Group():
                gr.Markdown("### 🤖 Response")
                output = gr.Textbox(
                    label="Model Response:",
                    lines=5,
                    show_label=True,
                    container=True,
                )

            # Examples section. Only the prompt is supplied per example, so
            # generate_response runs with its default sampling parameters.
            with gr.Group():
                gr.Markdown("### 📚 Example Prompts")
                gr.Examples(
                    examples=example_prompts,
                    inputs=prompt,
                    outputs=output,
                    fn=generate_response,
                    cache_examples=True,
                )

    # Footer
    gr.Markdown(
        """
        ---
        Made with ❤️ using Phi-2 and QLoRA
        """,
        elem_classes="footer",
    )

    # Event handlers
    submit_btn.click(
        fn=generate_response,
        inputs=[prompt, max_length, temperature, top_p, top_k],
        outputs=output,
    )
    # The output order must match the tuple returned by clear_all().
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[prompt, output, max_length, temperature, top_p, top_k],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    iface.launch(
        share=True,  # Enable sharing
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,  # Default Gradio port
        show_error=True,  # Show detailed error messages
    )