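"""Gradio demo for the PaliGemma 2 mix checkpoint.

Builds four tabs on top of a shared infer() helper: image captioning,
object detection, visual question answering, and free-form text generation.
"""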
import os
import gradio as gr
import PIL.Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Model and Processor Setup
model_id = "google/paligemma2-3b-mix-448"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_KEY = os.getenv("HF_KEY")
if not HF_KEY:
    raise ValueError("Please set the HF_KEY environment variable with your Hugging Face API token")

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True
).eval().to(device)

processor = PaliGemmaProcessor.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True
)

# Inference Function
def infer(image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
    # Guard against an empty image (e.g. when an image component is cleared).
    if image is None:
        return ""
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded sequence echoes the prompt, so strip it before returning.
    return result[0][len(text):].lstrip("\n")

# Image Captioning
def generate_caption(image: PIL.Image.Image) -> str:
    # PaliGemma mix checkpoints expect task prefixes with a language code, e.g. "caption en".
    return infer(image, "caption en", max_new_tokens=50)

# Object Detection
def detect_objects(image: PIL.Image.Image) -> str:
    # PaliGemma's detect prefix expects a semicolon-separated list of class names
    # (e.g. "detect person ; car"); "objects" is used here as a generic placeholder.
    return infer(image, "detect objects", max_new_tokens=200)

# Visual Question Answering (VQA)
def vqa(image: PIL.Image.Image, question: str) -> str:
    # PaliGemma's documented VQA prefix is "answer <lang> <question>".
    return infer(image, f"answer en {question}", max_new_tokens=50)

# Custom CSS for Styling
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.upload-button {
    background-color: #4285f4;
    color: white;
    border-radius: 5px;
    padding: 10px 20px;
}
.output-text {
    font-size: 18px;
    font-weight: bold;
}
"""

# Gradio App
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# PaliGemma Multi-Modal App")
    gr.Markdown("Upload an image and explore its features using the PaliGemma model!")

    with gr.Tabs():
        # Tab 1: Image Captioning
        with gr.Tab("Image Captioning"):
            with gr.Row():
                with gr.Column():
                    caption_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    caption_btn = gr.Button("Generate Caption", elem_classes="upload-button")
                with gr.Column():
                    caption_output = gr.Text(label="Generated Caption", elem_classes="output-text")
            caption_btn.click(fn=generate_caption, inputs=[caption_image], outputs=[caption_output])

        # Tab 2: Object Detection
        with gr.Tab("Object Detection"):
            with gr.Row():
                with gr.Column():
                    detect_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    detect_btn = gr.Button("Detect Objects", elem_classes="upload-button")
                with gr.Column():
                    detect_output = gr.Text(label="Detected Objects", elem_classes="output-text")
            detect_btn.click(fn=detect_objects, inputs=[detect_image], outputs=[detect_output])

        # Tab 3: Visual Question Answering (VQA)
        with gr.Tab("Visual Question Answering"):
            with gr.Row():
                with gr.Column():
                    vqa_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    vqa_question = gr.Text(label="Ask a Question", placeholder="What is in the image?")
                    vqa_btn = gr.Button("Ask", elem_classes="upload-button")
                with gr.Column():
                    vqa_output = gr.Text(label="Answer", elem_classes="output-text")
            vqa_btn.click(fn=vqa, inputs=[vqa_image, vqa_question], outputs=[vqa_output])

        # Tab 4: Text Generation (Original Feature)
        with gr.Tab("Text Generation"):
            with gr.Row():
                with gr.Column():
                    text_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    text_input = gr.Text(label="Input Text", placeholder="Describe the image...")
                    # Define the token slider in the layout so its value is passed to infer().
                    text_tokens = gr.Slider(10, 200, value=50, step=1, label="Max New Tokens")
                    text_btn = gr.Button("Generate Text", elem_classes="upload-button")
                with gr.Column():
                    text_output = gr.Text(label="Generated Text", elem_classes="output-text")
            text_btn.click(fn=infer, inputs=[text_image, text_input, text_tokens], outputs=[text_output])

    # Image Upload/Download (placeholders: not yet wired to any event handlers)
    with gr.Row():
        upload_button = gr.UploadButton("Upload Image", file_types=["image"], elem_classes="upload-button")
        download_button = gr.DownloadButton("Download Results", elem_classes="upload-button")

    # Real-Time Updates: .change() listeners re-run inference whenever a new image is uploaded
    # (Blocks event listeners fire automatically; they do not accept a `live` argument)
    caption_image.change(fn=generate_caption, inputs=[caption_image], outputs=[caption_output])
    detect_image.change(fn=detect_objects, inputs=[detect_image], outputs=[detect_output])
    vqa_image.change(fn=lambda x: vqa(x, "What is in the image?"), inputs=[vqa_image], outputs=[vqa_output])

# Launch the App
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)