#!/usr/bin/env python

import os
import string

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration

# Style constants
CUSTOM_CSS = """
.container {
    max-width: 1000px;
    margin: auto;
    padding: 2rem;
    background: linear-gradient(to bottom right, #ffffff, #f8f9fa);
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.title {
    font-size: 2.5rem;
    color: #1a73e8;
    text-align: center;
    margin-bottom: 2rem;
    font-weight: bold;
}
.tab-nav {
    background: #f8f9fa;
    border-radius: 10px;
    padding: 0.5rem;
    margin-bottom: 1rem;
}
.input-box {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    transition: all 0.3s ease;
}
.input-box:focus {
    border-color: #1a73e8;
    box-shadow: 0 0 0 2px rgba(26, 115, 232, 0.2);
}
.button-primary {
    background: #1a73e8;
    color: white;
    padding: 0.75rem 1.5rem;
    border-radius: 8px;
    border: none;
    cursor: pointer;
    transition: all 0.3s ease;
}
.button-primary:hover {
    background: #1557b0;
    transform: translateY(-1px);
}
.button-secondary {
    background: #f8f9fa;
    color: #1a73e8;
    border: 1px solid #1a73e8;
    padding: 0.75rem 1.5rem;
    border-radius: 8px;
    cursor: pointer;
    transition: all 0.3s ease;
}
.button-secondary:hover {
    background: #e8f0fe;
}
.output-box {
    background: #ffffff;
    border-radius: 8px;
    padding: 1rem;
    margin-top: 1rem;
    border: 1px solid #e0e0e0;
}
.chatbot-message {
    padding: 1rem;
    margin: 0.5rem 0;
    border-radius: 8px;
    background: #f8f9fa;
}
.advanced-settings {
    background: #ffffff;
    border-radius: 8px;
    padding: 1rem;
    margin-top: 1rem;
}
.slider-container {
    padding: 0.5rem;
    background: #f8f9fa;
    border-radius: 6px;
}
.examples-container {
    margin-top: 2rem;
    padding: 1rem;
    background: #ffffff;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
"""
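
# CUSTOM_CSS is handed to gr.Blocks(css=...) at the bottom of the file; the
# class names above correspond to the elem_classes set on individual components.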

DESCRIPTION = """
# 🖼️ BLIP-2 Visual Intelligence System

Advanced AI system for image understanding and natural conversation
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n\nRunning on CPU 🥶 This demo requires GPU to function properly."
" device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") MODEL_ID_OPT_2_7B = "Salesforce/blip2-opt-2.7b" MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b" MODEL_ID_FLAN_T5_XL = "Salesforce/blip2-flan-t5-xl" MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl" MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL) if MODEL_ID not in [MODEL_ID_OPT_2_7B, MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XL, MODEL_ID_FLAN_T5_XXL]: error_message = f"Invalid MODEL_ID: {MODEL_ID}" raise ValueError(error_message) if torch.cuda.is_available(): processor = AutoProcessor.from_pretrained(MODEL_ID) model = Blip2ForConditionalGeneration.from_pretrained( MODEL_ID, device_map="auto", quantization_config=BitsAndBytesConfig(load_in_8bit=True) ) @spaces.GPU def generate_caption( image: PIL.Image.Image, decoding_method: str = "Nucleus sampling", temperature: float = 1.0, length_penalty: float = 1.0, repetition_penalty: float = 1.5, max_length: int = 50, min_length: int = 1, num_beams: int = 5, top_p: float = 0.9, ) -> str: inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) generated_ids = model.generate( pixel_values=inputs.pixel_values, do_sample=decoding_method == "Nucleus sampling", temperature=temperature, length_penalty=length_penalty, repetition_penalty=repetition_penalty, max_length=max_length, min_length=min_length, num_beams=num_beams, top_p=top_p, ) return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() @spaces.GPU def answer_question( image: PIL.Image.Image, prompt: str, decoding_method: str = "Nucleus sampling", temperature: float = 1.0, length_penalty: float = 1.0, repetition_penalty: float = 1.5, max_length: int = 50, min_length: int = 1, num_beams: int = 5, top_p: float = 0.9, ) -> str: inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16) generated_ids = model.generate( **inputs, do_sample=decoding_method == "Nucleus sampling", temperature=temperature, length_penalty=length_penalty, repetition_penalty=repetition_penalty, max_length=max_length, min_length=min_length, num_beams=num_beams, top_p=top_p, ) return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() def postprocess_output(output: str) -> str: if output and output[-1] not in string.punctuation: output += "." 
def chat(
    image: PIL.Image.Image,
    text: str,
    decoding_method: str = "Nucleus sampling",
    temperature: float = 1.0,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.5,
    max_length: int = 50,
    min_length: int = 1,
    num_beams: int = 5,
    top_p: float = 0.9,
    history_orig: list[str] | None = None,
    history_qa: list[str] | None = None,
) -> tuple[list[tuple[str, str]], list[str], list[str]]:
    history_orig = history_orig or []
    history_qa = history_qa or []
    history_orig.append(text)
    text_qa = f"Question: {text} Answer:"
    history_qa.append(text_qa)
    prompt = " ".join(history_qa)
    output = answer_question(
        image=image,
        prompt=prompt,
        decoding_method=decoding_method,
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        top_p=top_p,
    )
    output = postprocess_output(output)
    history_orig.append(output)
    history_qa.append(output)
    # Pair consecutive (user, assistant) turns for gr.Chatbot.
    chat_val = list(zip(history_orig[0::2], history_orig[1::2], strict=False))
    return chat_val, history_orig, history_qa


chat.zerogpu = True  # type: ignore

examples = [
    ["images/house.png", "How could someone get out of the house?"],
    ["images/flower.jpg", "What is this flower and where is its origin?"],
    ["images/pizza.jpg", "What are the steps to cook it?"],
    ["images/sunset.jpg", "Here is a romantic message going along the photo:"],
    ["images/forbidden_city.webp", "In what dynasties was this place built?"],
]

with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group(elem_classes="container"):
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Upload Image", elem_classes="input-box")
            with gr.Column(scale=2):
                with gr.Tabs(elem_classes="tab-nav"):
                    with gr.Tab(label="✨ Image Captioning"):
                        caption_button = gr.Button("Generate Caption", elem_classes="button-primary")
                        caption_output = gr.Textbox(label="Generated Caption", elem_classes="output-box")
                    with gr.Tab(label="💭 Visual Q&A"):
                        chatbot = gr.Chatbot(elem_classes="chatbot-message")
                        history_orig = gr.State(value=[])
                        history_qa = gr.State(value=[])
                        vqa_input = gr.Textbox(
                            placeholder="Ask me anything about the image...",
                            elem_classes="input-box",
                        )
                        with gr.Row():
                            clear_button = gr.Button("Clear Chat", elem_classes="button-secondary")
                            submit_button = gr.Button("Send Message", elem_classes="button-primary")

    with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
        with gr.Row():
            with gr.Column():
                text_decoding_method = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Nucleus sampling",
                    label="Decoding Method",
                )
                temperature = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=1.0,
                    label="Temperature",
                    info="Used with nucleus sampling",
                    elem_classes="slider-container",
                )
                length_penalty = gr.Slider(
                    minimum=-1.0,
                    maximum=2.0,
                    value=1.0,
                    label="Length Penalty",
                    info="Set to a larger value for longer sequences",
                    elem_classes="slider-container",
                )
            with gr.Column():
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=1.5,
                    label="Repetition Penalty",
                    info="Larger values prevent repetition",
                    elem_classes="slider-container",
                )
                max_length = gr.Slider(minimum=20, maximum=512, value=50, label="Max Length", elem_classes="slider-container")
                min_length = gr.Slider(minimum=1, maximum=100, value=1, label="Min Length", elem_classes="slider-container")
                num_beams = gr.Slider(minimum=1, maximum=10, value=5, label="Number of Beams", elem_classes="slider-container")
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.9,
                    label="Top P",
                    info="Used with nucleus sampling",
                    elem_classes="slider-container",
                )
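
    # Clicking an example below fills both the image input and the VQA prompt box.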
info="Used with nucleus sampling", elem_classes="slider-container" ) with gr.Group(elem_classes="examples-container"): gr.Examples( examples=examples, inputs=[image, vqa_input], label="Try these examples" ) # Event handlers caption_button.click( fn=generate_caption, inputs=[ image, text_decoding_method, temperature, length_penalty, repetition_penalty, max_length, min_length, num_beams, top_p, ], outputs=caption_output, api_name="caption" ) chat_inputs = [ image, vqa_input, text_decoding_method, temperature, length_penalty, repetition_penalty, max_length, min_length, num_beams, top_p, history_orig, history_qa, ] chat_outputs = [ chatbot, history_orig, history_qa, ] vqa_input.submit( fn=chat, inputs=chat_inputs, outputs=chat_outputs ).success( fn=lambda: "", outputs=vqa_input, queue=False, api_name=False ) submit_button.click( fn=chat, inputs=chat_inputs, outputs=chat_outputs, api_name="chat" ).success( fn=lambda: "", outputs=vqa_input, queue=False, api_name=False ) clear_button.click( fn=lambda: ("", [], [], []), inputs=None, outputs=[ vqa_input, chatbot, history_orig, history_qa, ], queue=False, api_name="clear" ) image.change( fn=lambda: ("", [], [], []), inputs=None, outputs=[ caption_output, chatbot, history_orig, history_qa, ], queue=False ) if __name__ == "__main__": demo.queue(max_size=10).launch()