Akshayram1 committed (verified)
Commit 900613f · 1 Parent(s): 72fe4af

Update app.py

Files changed (1):
  1. app.py +11 -37
app.py CHANGED
@@ -5,12 +5,13 @@ import torch
 from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
 
 # Model and Processor Setup
-model_id = "google/paligemma2-3b-mix-448"
+model_id = "gv-hf/paligemma2-3b-mix-448"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 HF_KEY = os.getenv("HF_KEY")
 if not HF_KEY:
     raise ValueError("Please set the HF_KEY environment variable with your Hugging Face API token")
 
+# Load model and processor
 model = PaliGemmaForConditionalGeneration.from_pretrained(
     model_id,
     token=HF_KEY,
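Note on the setup hunk: a bad HF_KEY otherwise only surfaces inside from_pretrained(), after the multi-gigabyte checkpoint download has been attempted. A minimal fail-fast sketch using huggingface_hub (already a transformers dependency); whoami() raises if the Hub rejects the token:

    # Optional preflight sketch: validate the token before loading the model.
    from huggingface_hub import HfApi

    HfApi().whoami(token=HF_KEY)  # raises on an invalid or expired token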
@@ -47,25 +48,8 @@ def detect_objects(image: PIL.Image.Image) -> str:
 def vqa(image: PIL.Image.Image, question: str) -> str:
     return infer(image, f"Q: {question} A:", max_new_tokens=50)
 
-# Custom CSS for Styling
-custom_css = """
-.gradio-container {
-    font-family: 'Arial', sans-serif;
-}
-.upload-button {
-    background-color: #4285f4;
-    color: white;
-    border-radius: 5px;
-    padding: 10px 20px;
-}
-.output-text {
-    font-size: 18px;
-    font-weight: bold;
-}
-"""
-
 # Gradio App
-with gr.Blocks(css=custom_css) as demo:
+with gr.Blocks() as demo:
     gr.Markdown("# PaliGemma Multi-Modal App")
     gr.Markdown("Upload an image and explore its features using the PaliGemma model!")
 
 
@@ -75,9 +59,9 @@ with gr.Blocks(css=custom_css) as demo:
     with gr.Row():
         with gr.Column():
             caption_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
-            caption_btn = gr.Button("Generate Caption", elem_classes="upload-button")
+            caption_btn = gr.Button("Generate Caption")
         with gr.Column():
-            caption_output = gr.Text(label="Generated Caption", elem_classes="output-text")
+            caption_output = gr.Text(label="Generated Caption")
         caption_btn.click(fn=generate_caption, inputs=[caption_image], outputs=[caption_output])
 
     # Tab 2: Object Detection
@@ -85,9 +69,9 @@ with gr.Blocks(css=custom_css) as demo:
     with gr.Row():
         with gr.Column():
             detect_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
-            detect_btn = gr.Button("Detect Objects", elem_classes="upload-button")
+            detect_btn = gr.Button("Detect Objects")
         with gr.Column():
-            detect_output = gr.Text(label="Detected Objects", elem_classes="output-text")
+            detect_output = gr.Text(label="Detected Objects")
         detect_btn.click(fn=detect_objects, inputs=[detect_image], outputs=[detect_output])
 
     # Tab 3: Visual Question Answering (VQA)
@@ -96,9 +80,9 @@ with gr.Blocks(css=custom_css) as demo:
         with gr.Column():
             vqa_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
             vqa_question = gr.Text(label="Ask a Question", placeholder="What is in the image?")
-            vqa_btn = gr.Button("Ask", elem_classes="upload-button")
+            vqa_btn = gr.Button("Ask")
         with gr.Column():
-            vqa_output = gr.Text(label="Answer", elem_classes="output-text")
+            vqa_output = gr.Text(label="Answer")
         vqa_btn.click(fn=vqa, inputs=[vqa_image, vqa_question], outputs=[vqa_output])
 
     # Tab 4: Text Generation (Original Feature)
@@ -107,21 +91,11 @@ with gr.Blocks(css=custom_css) as demo:
         with gr.Column():
             text_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
             text_input = gr.Text(label="Input Text", placeholder="Describe the image...")
-            text_btn = gr.Button("Generate Text", elem_classes="upload-button")
+            text_btn = gr.Button("Generate Text")
         with gr.Column():
-            text_output = gr.Text(label="Generated Text", elem_classes="output-text")
+            text_output = gr.Text(label="Generated Text")
         text_btn.click(fn=infer, inputs=[text_image, text_input, gr.Slider(10, 200, value=50)], outputs=[text_output])
 
-    # Image Upload/Download
-    with gr.Row():
-        upload_button = gr.UploadButton("Upload Image", file_types=["image"], elem_classes="upload-button")
-        download_button = gr.DownloadButton("Download Results", elem_classes="upload-button")
-
-    # Real-Time Updates
-    caption_image.change(fn=generate_caption, inputs=[caption_image], outputs=[caption_output], live=True)
-    detect_image.change(fn=detect_objects, inputs=[detect_image], outputs=[detect_output], live=True)
-    vqa_image.change(fn=lambda x: vqa(x, "What is in the image?"), inputs=[vqa_image], outputs=[vqa_output], live=True)
-
 # Launch the App
 if __name__ == "__main__":
     demo.queue(max_size=10).launch(debug=True)
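The deleted "Real-Time Updates" block is a genuine bug fix, not just cleanup: in Gradio's Blocks API, event listeners such as Image.change() accept no live keyword (live is a gr.Interface constructor argument), so all three calls would raise a TypeError at startup. If live previews are wanted later, plain .change() wiring is valid; for example:

    # Sketch: real-time captioning without the unsupported live= keyword.
    caption_image.change(fn=generate_caption, inputs=[caption_image], outputs=[caption_output])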
 
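After this commit, app.py runs as a plain script: the __main__ guard queues up to 10 pending events and launches the UI. A launch sketch with the token injected in-process rather than exported in the shell (the token string is a placeholder):

    # Sketch: run app.py with HF_KEY set from Python.
    import os
    import runpy

    os.environ["HF_KEY"] = "hf_..."  # placeholder; use a real Hugging Face token
    runpy.run_path("app.py", run_name="__main__")  # executes the __main__ launch block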