openfree committed on
Commit 5c85beb · verified · 1 Parent(s): b983b59

Update app.py

Files changed (1)
  1. app.py +173 -312
app.py CHANGED
@@ -2,326 +2,187 @@
 
 import os
 import string
-
 import gradio as gr
 import PIL.Image
 import spaces
 import torch
 from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration
 
-DESCRIPTION = "# [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"
 
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-MODEL_ID_OPT_2_7B = "Salesforce/blip2-opt-2.7b"
-MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b"
-MODEL_ID_FLAN_T5_XL = "Salesforce/blip2-flan-t5-xl"
-MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"
-MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL)
-if MODEL_ID not in [MODEL_ID_OPT_2_7B, MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XL, MODEL_ID_FLAN_T5_XXL]:
-    error_message = f"Invalid MODEL_ID: {MODEL_ID}"
-    raise ValueError(error_message)
-
-if torch.cuda.is_available():
-    processor = AutoProcessor.from_pretrained(MODEL_ID)
-    model = Blip2ForConditionalGeneration.from_pretrained(
-        MODEL_ID, device_map="auto", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
-    )
-
-
-@spaces.GPU
-def generate_caption(
-    image: PIL.Image.Image,
-    decoding_method: str = "Nucleus sampling",
-    temperature: float = 1.0,
-    length_penalty: float = 1.0,
-    repetition_penalty: float = 1.5,
-    max_length: int = 50,
-    min_length: int = 1,
-    num_beams: int = 5,
-    top_p: float = 0.9,
-) -> str:
-    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
-    generated_ids = model.generate(
-        pixel_values=inputs.pixel_values,
-        do_sample=decoding_method == "Nucleus sampling",
-        temperature=temperature,
-        length_penalty=length_penalty,
-        repetition_penalty=repetition_penalty,
-        max_length=max_length,
-        min_length=min_length,
-        num_beams=num_beams,
-        top_p=top_p,
-    )
-    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
-
-@spaces.GPU
-def answer_question(
-    image: PIL.Image.Image,
-    prompt: str,
-    decoding_method: str = "Nucleus sampling",
-    temperature: float = 1.0,
-    length_penalty: float = 1.0,
-    repetition_penalty: float = 1.5,
-    max_length: int = 50,
-    min_length: int = 1,
-    num_beams: int = 5,
-    top_p: float = 0.9,
-) -> str:
-    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
-    generated_ids = model.generate(
-        **inputs,
-        do_sample=decoding_method == "Nucleus sampling",
-        temperature=temperature,
-        length_penalty=length_penalty,
-        repetition_penalty=repetition_penalty,
-        max_length=max_length,
-        min_length=min_length,
-        num_beams=num_beams,
-        top_p=top_p,
-    )
-    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
-
-def postprocess_output(output: str) -> str:
-    if output and output[-1] not in string.punctuation:
-        output += "."
-    return output
-
-
-def chat(
-    image: PIL.Image.Image,
-    text: str,
-    decoding_method: str = "Nucleus sampling",
-    temperature: float = 1.0,
-    length_penalty: float = 1.0,
-    repetition_penalty: float = 1.5,
-    max_length: int = 50,
-    min_length: int = 1,
-    num_beams: int = 5,
-    top_p: float = 0.9,
-    history_orig: list[str] | None = None,
-    history_qa: list[str] | None = None,
-) -> tuple[list[tuple[str, str]], list[str], list[str]]:
-    history_orig = history_orig or []
-    history_qa = history_qa or []
-    history_orig.append(text)
-    text_qa = f"Question: {text} Answer:"
-    history_qa.append(text_qa)
-    prompt = " ".join(history_qa)
-
-    output = answer_question(
-        image=image,
-        prompt=prompt,
-        decoding_method=decoding_method,
-        temperature=temperature,
-        length_penalty=length_penalty,
-        repetition_penalty=repetition_penalty,
-        max_length=max_length,
-        min_length=min_length,
-        num_beams=num_beams,
-        top_p=top_p,
-    )
-    output = postprocess_output(output)
-    history_orig.append(output)
-    history_qa.append(output)
-
-    chat_val = list(zip(history_orig[0::2], history_orig[1::2], strict=False))
-    return chat_val, history_orig, history_qa
-
-
-chat.zerogpu = True  # type: ignore
-
-
-examples = [
-    [
-        "images/house.png",
-        "How could someone get out of the house?",
-    ],
-    [
-        "images/flower.jpg",
-        "What is this flower and where is it's origin?",
-    ],
-    [
-        "images/pizza.jpg",
-        "What are steps to cook it?",
-    ],
-    [
-        "images/sunset.jpg",
-        "Here is a romantic message going along the photo:",
-    ],
-    [
-        "images/forbidden_city.webp",
-        "In what dynasties was this place built?",
-    ],
-]
-
-with gr.Blocks(css_paths="style.css") as demo:
-    gr.Markdown(DESCRIPTION)
-
-    with gr.Group():
-        image = gr.Image(type="pil")
-        with gr.Tabs():
-            with gr.Tab(label="Image Captioning"):
-                caption_button = gr.Button("Caption it!")
-                caption_output = gr.Textbox(label="Caption Output", show_label=False, container=False)
-            with gr.Tab(label="Visual Question Answering"):
-                chatbot = gr.Chatbot(label="VQA Chat", show_label=False)
-                history_orig = gr.State(value=[])
-                history_qa = gr.State(value=[])
-                vqa_input = gr.Text(label="Chat Input", show_label=False, max_lines=1, container=False)
                 with gr.Row():
-                    clear_chat_button = gr.Button("Clear")
-                    chat_button = gr.Button("Submit", variant="primary")
-    with gr.Accordion(label="Advanced settings", open=False):
-        text_decoding_method = gr.Radio(
-            label="Text Decoding Method",
-            choices=["Beam search", "Nucleus sampling"],
-            value="Nucleus sampling",
-        )
-        temperature = gr.Slider(
-            label="Temperature",
-            info="Used with nucleus sampling.",
-            minimum=0.5,
-            maximum=1.0,
-            step=0.1,
-            value=1.0,
-        )
-        length_penalty = gr.Slider(
-            label="Length Penalty",
-            info="Set to larger for longer sequence, used with beam search.",
-            minimum=-1.0,
-            maximum=2.0,
-            step=0.2,
-            value=1.0,
-        )
-        repetition_penalty = gr.Slider(
-            label="Repetition Penalty",
-            info="Larger value prevents repetition.",
-            minimum=1.0,
-            maximum=5.0,
-            step=0.5,
-            value=1.5,
-        )
-        max_length = gr.Slider(
-            label="Max Length",
-            minimum=20,
-            maximum=512,
-            step=1,
-            value=50,
-        )
-        min_length = gr.Slider(
-            label="Minimum Length",
-            minimum=1,
-            maximum=100,
-            step=1,
-            value=1,
-        )
-        num_beams = gr.Slider(
-            label="Number of Beams",
-            minimum=1,
-            maximum=10,
-            step=1,
-            value=5,
-        )
-        top_p = gr.Slider(
-            label="Top P",
-            info="Used with nucleus sampling.",
-            minimum=0.5,
-            maximum=1.0,
-            step=0.1,
-            value=0.9,
-        )
-
-    gr.Examples(
-        examples=examples,
-        inputs=[image, vqa_input],
-    )
-
-    caption_button.click(
-        fn=generate_caption,
-        inputs=[
-            image,
-            text_decoding_method,
-            temperature,
-            length_penalty,
-            repetition_penalty,
-            max_length,
-            min_length,
-            num_beams,
-            top_p,
-        ],
-        outputs=caption_output,
-        api_name="caption",
-    )
-
-    chat_inputs = [
-        image,
-        vqa_input,
-        text_decoding_method,
-        temperature,
-        length_penalty,
-        repetition_penalty,
-        max_length,
-        min_length,
-        num_beams,
-        top_p,
-        history_orig,
-        history_qa,
-    ]
-    chat_outputs = [
-        chatbot,
-        history_orig,
-        history_qa,
-    ]
-    vqa_input.submit(
-        fn=chat,
-        inputs=chat_inputs,
-        outputs=chat_outputs,
-    ).success(
-        fn=lambda: "",
-        outputs=vqa_input,
-        queue=False,
-        api_name=False,
-    )
-    chat_button.click(
-        fn=chat,
-        inputs=chat_inputs,
-        outputs=chat_outputs,
-        api_name="chat",
-    ).success(
-        fn=lambda: "",
-        outputs=vqa_input,
-        queue=False,
-        api_name=False,
-    )
-    clear_chat_button.click(
-        fn=lambda: ("", [], [], []),
-        inputs=None,
-        outputs=[
-            vqa_input,
-            chatbot,
-            history_orig,
-            history_qa,
-        ],
-        queue=False,
-        api_name="clear",
-    )
-    image.change(
-        fn=lambda: ("", [], [], []),
-        inputs=None,
-        outputs=[
-            caption_output,
-            chatbot,
-            history_orig,
-            history_qa,
-        ],
-        queue=False,
-    )
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch()
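Before the replacement side of the diff, note how the removed chat helper threads dialogue history into a single BLIP-2 prompt: each user turn is appended as "Question: {text} Answer:" and all turns are joined with spaces. A minimal standalone sketch of that construction (the history values here are hypothetical):

    # Sketch of the removed prompt construction, using hypothetical history values.
    history_qa = ["Question: What is in the photo? Answer:", "A red brick house."]
    history_qa.append("Question: How could someone get out of the house? Answer:")
    prompt = " ".join(history_qa)
    # prompt == "Question: What is in the photo? Answer: A red brick house.
    #            Question: How could someone get out of the house? Answer:"

The replacement version of app.py follows.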
 
 import os
 import string
 import gradio as gr
 import PIL.Image
 import spaces
 import torch
 from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration
 
+# Style constants
+CUSTOM_CSS = """
+.container {
+    max-width: 1000px;
+    margin: auto;
+    padding: 2rem;
+    background: linear-gradient(to bottom right, #ffffff, #f8f9fa);
+    border-radius: 15px;
+    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+}
+
+.title {
+    font-size: 2.5rem;
+    color: #1a73e8;
+    text-align: center;
+    margin-bottom: 2rem;
+    font-weight: bold;
+}
+
+.tab-nav {
+    background: #f8f9fa;
+    border-radius: 10px;
+    padding: 0.5rem;
+    margin-bottom: 1rem;
+}
+
+.input-box {
+    border: 2px solid #e0e0e0;
+    border-radius: 8px;
+    transition: all 0.3s ease;
+}
+
+.input-box:focus {
+    border-color: #1a73e8;
+    box-shadow: 0 0 0 2px rgba(26, 115, 232, 0.2);
+}
+
+.button-primary {
+    background: #1a73e8;
+    color: white;
+    padding: 0.75rem 1.5rem;
+    border-radius: 8px;
+    border: none;
+    cursor: pointer;
+    transition: all 0.3s ease;
+}
+
+.button-primary:hover {
+    background: #1557b0;
+    transform: translateY(-1px);
+}
+
+.output-box {
+    background: #ffffff;
+    border-radius: 8px;
+    padding: 1rem;
+    margin-top: 1rem;
+    border: 1px solid #e0e0e0;
+}
+
+.chatbot-message {
+    padding: 1rem;
+    margin: 0.5rem 0;
+    border-radius: 8px;
+    background: #f8f9fa;
+}
+
+.advanced-settings {
+    background: #ffffff;
+    border-radius: 8px;
+    padding: 1rem;
+    margin-top: 1rem;
+}
+
+.slider-container {
+    padding: 0.5rem;
+    background: #f8f9fa;
+    border-radius: 6px;
+}
+"""
+
+DESCRIPTION = """
+<div class="title">
+    🖼️ BLIP-2 Visual Intelligence System
+</div>
+<p style='text-align: center; color: #666;'>
+    Advanced AI system for image understanding and natural conversation
+</p>
+"""
 
 if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p style='color: #dc3545;'>Running on CPU 🥶 This demo requires a GPU to function properly.</p>"
+
+# The model setup section stays the same...
+
+def create_interface():
+    with gr.Blocks(css=CUSTOM_CSS) as demo:
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Group(elem_classes="container"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    image = gr.Image(
+                        type="pil",
+                        label="Upload Image",
+                        elem_classes="input-box"
+                    )
+
+                with gr.Column(scale=2):
+                    with gr.Tabs(elem_classes="tab-nav"):
+                        with gr.Tab(label="✨ Image Captioning"):
+                            caption_button = gr.Button(
+                                "Generate Caption",
+                                elem_classes="button-primary"
+                            )
+                            caption_output = gr.Textbox(
+                                label="Generated Caption",
+                                elem_classes="output-box"
+                            )
+
+                        with gr.Tab(label="💭 Visual Q&A"):
+                            chatbot = gr.Chatbot(
+                                elem_classes="chatbot-message"
+                            )
+                            vqa_input = gr.Textbox(
+                                placeholder="Ask me anything about the image...",
+                                elem_classes="input-box"
+                            )
+
+                            with gr.Row():
+                                clear_button = gr.Button(
+                                    "Clear Chat",
+                                    elem_classes="button-secondary"
+                                )
+                                submit_button = gr.Button(
+                                    "Send Message",
+                                    elem_classes="button-primary"
+                                )
+
+            with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
+                # Advanced settings controls...
                 with gr.Row():
+                    with gr.Column():
+                        text_decoding_method = gr.Radio(
+                            choices=["Beam search", "Nucleus sampling"],
+                            value="Nucleus sampling",
+                            label="Decoding Method"
+                        )
+                        temperature = gr.Slider(
+                            minimum=0.5,
+                            maximum=1.0,
+                            value=1.0,
+                            label="Temperature",
+                            elem_classes="slider-container"
+                        )
+                    with gr.Column():
+                        length_penalty = gr.Slider(
+                            minimum=-1.0,
+                            maximum=2.0,
+                            value=1.0,
+                            label="Length Penalty",
+                            elem_classes="slider-container"
+                        )
+                        repetition_penalty = gr.Slider(
+                            minimum=1.0,
+                            maximum=5.0,
+                            value=1.5,
+                            label="Repetition Penalty",
+                            elem_classes="slider-container"
+                        )
+
+        # Wire up the event handlers...
+
+    return demo
 
 if __name__ == "__main__":
+    demo = create_interface()
+    demo.queue(max_size=10).launch()
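The new revision elides both the model setup ("stays the same...") and the event wiring ("Wire up the event handlers..."). A minimal sketch of how that wiring might be completed inside the with gr.Blocks(...) block of create_interface(), assuming generate_caption and chat from the previous revision are retained; the lambda adapters and the gr.State histories are illustrative assumptions, not part of this commit:

    # Hypothetical wiring for the elided handler section. Sliders dropped from
    # the new UI (max_length, min_length, num_beams, top_p) fall back to the
    # defaults of generate_caption and chat from the previous revision.
    history_orig = gr.State(value=[])
    history_qa = gr.State(value=[])

    caption_button.click(
        fn=lambda img, dm, t, lp, rp: generate_caption(
            img, decoding_method=dm, temperature=t, length_penalty=lp, repetition_penalty=rp
        ),
        inputs=[image, text_decoding_method, temperature, length_penalty, repetition_penalty],
        outputs=caption_output,
    )
    submit_button.click(
        fn=lambda img, txt, dm, t, lp, rp, ho, hq: chat(
            img, txt, decoding_method=dm, temperature=t, length_penalty=lp,
            repetition_penalty=rp, history_orig=ho, history_qa=hq
        ),
        inputs=[image, vqa_input, text_decoding_method, temperature,
                length_penalty, repetition_penalty, history_orig, history_qa],
        outputs=[chatbot, history_orig, history_qa],
    )
    clear_button.click(
        fn=lambda: ("", [], [], []),
        inputs=None,
        outputs=[vqa_input, chatbot, history_orig, history_qa],
        queue=False,
    )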