Akshayram1 committed
Commit 0912f0e · verified · 1 Parent(s): 9dd61d7

Create app.py

Files changed (1):
  app.py +127 -0
app.py ADDED
import os
import gradio as gr
import PIL.Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Model and Processor Setup
model_id = "gv-hf/paligemma2-3b-mix-448"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_KEY = os.getenv("HF_KEY")
if not HF_KEY:
    raise ValueError("Please set the HF_KEY environment variable with your Hugging Face API token")

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True
).eval().to(device)

processor = PaliGemmaProcessor.from_pretrained(
    model_id,
    token=HF_KEY,
    trust_remote_code=True
)
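# Note (an assumption, not part of this commit): on GPU the checkpoint can be
# loaded in bfloat16 to roughly halve memory use. A minimal sketch:
#
#   model = PaliGemmaForConditionalGeneration.from_pretrained(
#       model_id, token=HF_KEY, torch_dtype=torch.bfloat16
#   ).eval().to(device)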

# Inference Function
def infer(image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded sequence echoes the prompt, so slice it off before returning.
    return result[0][len(text):].lstrip("\n")
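# Example usage (illustration only; "example.jpg" is a hypothetical local file):
#
#   img = PIL.Image.open("example.jpg")
#   print(infer(img, "caption", max_new_tokens=50))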

# Image Captioning
def generate_caption(image: PIL.Image.Image) -> str:
    return infer(image, "caption", max_new_tokens=50)

# Object Detection
def detect_objects(image: PIL.Image.Image) -> str:
    return infer(image, "detect objects", max_new_tokens=200)

# Visual Question Answering (VQA)
def vqa(image: PIL.Image.Image, question: str) -> str:
    return infer(image, f"Q: {question} A:", max_new_tokens=50)
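# PaliGemma "detect" prompts typically answer with four <locYYYY> tokens per box
# (y_min, x_min, y_max, x_max on a 1024-step grid) followed by a label. A
# minimal parsing sketch (an assumption, not wired into the app):
#
#   import re
#   def parse_detections(output: str):
#       pattern = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<]+)"
#       return [(label.strip(), tuple(int(c) for c in coords))
#               for *coords, label in re.findall(pattern, output)]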

# Custom CSS for Styling
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.upload-button {
    background-color: #4285f4;
    color: white;
    border-radius: 5px;
    padding: 10px 20px;
}
.output-text {
    font-size: 18px;
    font-weight: bold;
}
"""

# Gradio App
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# PaliGemma Multi-Modal App")
    gr.Markdown("Upload an image and explore its features using the PaliGemma model!")

    with gr.Tabs():
        # Tab 1: Image Captioning
        with gr.Tab("Image Captioning"):
            with gr.Row():
                with gr.Column():
                    caption_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    caption_btn = gr.Button("Generate Caption", elem_classes="upload-button")
                with gr.Column():
                    caption_output = gr.Text(label="Generated Caption", elem_classes="output-text")
            caption_btn.click(fn=generate_caption, inputs=[caption_image], outputs=[caption_output])

        # Tab 2: Object Detection
        with gr.Tab("Object Detection"):
            with gr.Row():
                with gr.Column():
                    detect_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    detect_btn = gr.Button("Detect Objects", elem_classes="upload-button")
                with gr.Column():
                    detect_output = gr.Text(label="Detected Objects", elem_classes="output-text")
            detect_btn.click(fn=detect_objects, inputs=[detect_image], outputs=[detect_output])

        # Tab 3: Visual Question Answering (VQA)
        with gr.Tab("Visual Question Answering"):
            with gr.Row():
                with gr.Column():
                    vqa_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    vqa_question = gr.Text(label="Ask a Question", placeholder="What is in the image?")
                    vqa_btn = gr.Button("Ask", elem_classes="upload-button")
                with gr.Column():
                    vqa_output = gr.Text(label="Answer", elem_classes="output-text")
            vqa_btn.click(fn=vqa, inputs=[vqa_image, vqa_question], outputs=[vqa_output])

        # Tab 4: Text Generation (Original Feature)
        with gr.Tab("Text Generation"):
            with gr.Row():
                with gr.Column():
                    text_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                    text_input = gr.Text(label="Input Text", placeholder="Describe the image...")
                    # Create the slider as a named, rendered component so it
                    # appears in the layout with a label and can be passed to .click().
                    text_tokens = gr.Slider(10, 200, value=50, step=1, label="Max New Tokens")
                    text_btn = gr.Button("Generate Text", elem_classes="upload-button")
                with gr.Column():
                    text_output = gr.Text(label="Generated Text", elem_classes="output-text")
            text_btn.click(fn=infer, inputs=[text_image, text_input, text_tokens], outputs=[text_output])

    # Image Upload/Download (placeholder buttons: not wired to any handler in this commit)
    with gr.Row():
        upload_button = gr.UploadButton("Upload Image", file_types=["image"], elem_classes="upload-button")
        download_button = gr.DownloadButton("Download Results", elem_classes="upload-button")

    # Real-Time Updates (.change() fires on every new image; no extra flag is needed)
    caption_image.change(fn=generate_caption, inputs=[caption_image], outputs=[caption_output])
    detect_image.change(fn=detect_objects, inputs=[detect_image], outputs=[detect_output])
    vqa_image.change(fn=lambda x: vqa(x, "What is in the image?"), inputs=[vqa_image], outputs=[vqa_output])

# Launch the App
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)
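# A plausible requirements.txt for this Space (an assumption; the commit adds
# only app.py):
#
#   gradio
#   torch
#   transformers
#   Pillow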