BLIP2 / app.py
openfree's picture
Update app.py
5c85beb verified
raw
history blame
5.63 kB
#!/usr/bin/env python
import os
import string
import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration
# Style constants.
#
# CUSTOM_CSS is handed to ``gr.Blocks(css=...)``. Every selector below matches
# a class name that the UI assigns through ``elem_classes`` (or, for
# ``.title``, through the HTML in DESCRIPTION).
#
# Two things to be aware of:
#   * ``.button-secondary`` is styled here because the "Clear Chat" button
#     uses that class.
#   * ``.input-box`` is applied to a component wrapper, which does not take
#     keyboard focus itself; ``:focus-within`` makes the highlight appear
#     when the inner input/textarea is focused.
CUSTOM_CSS = """
.container {
max-width: 1000px;
margin: auto;
padding: 2rem;
background: linear-gradient(to bottom right, #ffffff, #f8f9fa);
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.title {
font-size: 2.5rem;
color: #1a73e8;
text-align: center;
margin-bottom: 2rem;
font-weight: bold;
}
.tab-nav {
background: #f8f9fa;
border-radius: 10px;
padding: 0.5rem;
margin-bottom: 1rem;
}
.input-box {
border: 2px solid #e0e0e0;
border-radius: 8px;
transition: all 0.3s ease;
}
.input-box:focus,
.input-box:focus-within {
border-color: #1a73e8;
box-shadow: 0 0 0 2px rgba(26, 115, 232, 0.2);
}
.button-primary {
background: #1a73e8;
color: white;
padding: 0.75rem 1.5rem;
border-radius: 8px;
border: none;
cursor: pointer;
transition: all 0.3s ease;
}
.button-primary:hover {
background: #1557b0;
transform: translateY(-1px);
}
.button-secondary {
background: #f8f9fa;
color: #1a73e8;
padding: 0.75rem 1.5rem;
border-radius: 8px;
border: 1px solid #e0e0e0;
cursor: pointer;
transition: all 0.3s ease;
}
.button-secondary:hover {
background: #e0e0e0;
transform: translateY(-1px);
}
.output-box {
background: #ffffff;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
border: 1px solid #e0e0e0;
}
.chatbot-message {
padding: 1rem;
margin: 0.5rem 0;
border-radius: 8px;
background: #f8f9fa;
}
.advanced-settings {
background: #ffffff;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.slider-container {
padding: 0.5rem;
background: #f8f9fa;
border-radius: 6px;
}
"""
# Header markup rendered at the top of the page via ``gr.Markdown``.
# The ``title`` class is styled in CUSTOM_CSS.
DESCRIPTION = """
<div class="title">
🖼️ BLIP-2 Visual Intelligence System
</div>
<p style='text-align: center; color: #666;'>
Advanced AI system for image understanding and natural conversation
</p>
"""

# Append a visible warning when no CUDA device is detected at import time.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p style='color: #dc3545;'>Running on CPU 🥶 This demo requires GPU to function properly.</p>"
# The model setup section is kept the same...
def create_interface():
    """Assemble the BLIP-2 demo UI and return it without launching.

    Layout built below:
      * the ``DESCRIPTION`` header, rendered as Markdown;
      * a styled group with an image-upload column next to a tabbed column
        ("Image Captioning" and "Visual Q&A");
      * a collapsed "Advanced Settings" accordion with decoding controls.

    No event listeners are attached in this function as written, so the
    buttons and text inputs do nothing until handlers are wired in at the
    marked spot near the end.

    Returns:
        The ``gr.Blocks`` app, ready for ``queue()`` / ``launch()``.
    """
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Group(elem_classes="container"):
            with gr.Row():
                # Left column: the image that both tabs operate on.
                with gr.Column(scale=1):
                    image = gr.Image(
                        type="pil",
                        label="Upload Image",
                        elem_classes="input-box",
                    )

                # Right column (twice as wide): captioning and VQA tabs.
                with gr.Column(scale=2):
                    with gr.Tabs(elem_classes="tab-nav"):
                        with gr.Tab(label="✨ Image Captioning"):
                            caption_button = gr.Button(
                                "Generate Caption",
                                elem_classes="button-primary",
                            )
                            caption_output = gr.Textbox(
                                label="Generated Caption",
                                elem_classes="output-box",
                            )

                        with gr.Tab(label="💭 Visual Q&A"):
                            chatbot = gr.Chatbot(
                                elem_classes="chatbot-message",
                            )
                            vqa_input = gr.Textbox(
                                placeholder="Ask me anything about the image...",
                                elem_classes="input-box",
                            )
                            with gr.Row():
                                clear_button = gr.Button(
                                    "Clear Chat",
                                    elem_classes="button-secondary",
                                )
                                submit_button = gr.Button(
                                    "Send Message",
                                    elem_classes="button-primary",
                                )

            # NOTE(review): the accordion is placed inside the group, below
            # the main row; this nesting is reconstructed — confirm it
            # matches the intended layout.
            with gr.Accordion(
                "🛠️ Advanced Settings",
                open=False,
                elem_classes="advanced-settings",
            ):
                # Advanced settings controls...
                with gr.Row():
                    with gr.Column():
                        text_decoding_method = gr.Radio(
                            choices=["Beam search", "Nucleus sampling"],
                            value="Nucleus sampling",
                            label="Decoding Method",
                        )
                        temperature = gr.Slider(
                            minimum=0.5,
                            maximum=1.0,
                            value=1.0,
                            label="Temperature",
                            elem_classes="slider-container",
                        )
                    with gr.Column():
                        length_penalty = gr.Slider(
                            minimum=-1.0,
                            maximum=2.0,
                            value=1.0,
                            label="Length Penalty",
                            elem_classes="slider-container",
                        )
                        repetition_penalty = gr.Slider(
                            minimum=1.0,
                            maximum=5.0,
                            value=1.5,
                            label="Repetition Penalty",
                            elem_classes="slider-container",
                        )

        # Connect event handlers...
        # (None are registered here; the component variables above are kept
        # so listeners can reference them when they are added.)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue
    # (max_size=10), then start serving.
    demo = create_interface()
    demo.queue(max_size=10)
    demo.launch()