seawolf2357 committed on
Commit
da87199
·
verified ·
1 Parent(s): 18d18c8

Create app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +349 -0
app-backup.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""Gradio Space: multimodal (image/video) chat demo for Gemma 3 27B IT."""

import os
import re
import tempfile
from collections.abc import Iterator
from threading import Thread

import cv2
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

# Checkpoint is overridable via the MODEL_ID env var.
model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
# padding_side="left" — presumably for decoder-only batched generation; confirm.
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)

# Maximum number of images allowed across the entire conversation (history + new message).
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
24
+
25
+
26
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Return ``(image_count, video_count)`` for newly uploaded file paths.

    A path ending in ``.mp4`` counts as a video; every other path counts
    as an image.
    """
    videos = sum(1 for p in paths if p.endswith(".mp4"))
    return len(paths) - videos, videos
35
+
36
+
37
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Return ``(image_count, video_count)`` for file uploads in chat history.

    Only user items whose ``content`` is not a plain string are file uploads;
    the first element of such content is the file path.
    """
    images = 0
    videos = 0
    for entry in history:
        is_upload = entry["role"] == "user" and not isinstance(entry["content"], str)
        if not is_upload:
            continue
        path = entry["content"][0]
        if path.endswith(".mp4"):
            videos += 1
        else:
            images += 1
    return images, videos
48
+
49
+
50
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check the media-upload rules for the whole conversation.

    Shows a ``gr.Warning`` and returns ``False`` on the first violated rule;
    returns ``True`` when all constraints are satisfied.
    """
    new_images, new_videos = count_files_in_new_message(message["files"])
    old_images, old_videos = count_files_in_history(history)
    total_images = old_images + new_images
    total_videos = old_videos + new_videos

    if total_videos > 1:
        gr.Warning("Only one video is supported.")
        return False
    if total_videos == 1 and total_images > 0:
        gr.Warning("Mixing images and videos is not allowed.")
        return False
    if total_videos == 1 and "<image>" in message["text"]:
        gr.Warning("Using <image> tags with video files is not supported.")
        return False
    # TODO: Add frame count validation for videos similar to image count limits # noqa: FIX002, TD002, TD003
    if total_videos == 0 and total_images > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
        return False
    # Each <image> tag must correspond to exactly one newly uploaded image.
    if "<image>" in message["text"] and message["text"].count("<image>") != new_images:
        gr.Warning("The number of <image> tags in the text does not match the number of images.")
        return False
    return True
73
+
74
+
75
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    """Sample frames from a video at roughly 3 frames per second.

    Args:
        video_path: Path to the video file (mp4).

    Returns:
        A list of ``(PIL image, timestamp_seconds)`` tuples; empty when the
        file cannot be read or reports no frames/fps.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # OpenCV reports 0 fps / 0 frames for unreadable files; the original
        # code would then divide by zero (timestamp) or pass step=0 to range().
        if fps <= 0 or total_frames <= 0:
            return []

        # ~3 samples per second; clamp to 1 so fps < 3 never yields step 0.
        frame_interval = max(int(fps / 3), 1)
        frames: list[tuple[Image.Image, float]] = []

        for i in range(0, total_frames, frame_interval):
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if success:
                # OpenCV decodes as BGR; PIL expects RGB.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image)
                timestamp = round(i / fps, 2)
                frames.append((pil_image, timestamp))
        return frames
    finally:
        # Release the capture even if decoding raises.
        vidcap.release()
94
+
95
+
96
def process_video(video_path: str) -> list[dict]:
    """Turn a video into chat-content parts.

    Each sampled frame becomes a ``Frame <timestamp>:`` text part followed by
    an image part pointing at a temp PNG on disk (kept with ``delete=False``
    so the processor can load it later).
    """
    content: list[dict] = []
    for pil_image, timestamp in downsample_video(video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
        content.append({"type": "text", "text": f"Frame {timestamp}:"})
        content.append({"type": "image", "url": temp_file.name})
    logger.debug(f"{content=}")
    return content
107
+
108
+
109
def process_interleaved_images(message: dict) -> list[dict]:
    """Build content parts for text containing inline ``<image>`` placeholders.

    Splits ``message["text"]`` on ``<image>`` tags and replaces the i-th tag
    with the i-th uploaded file; surrounding text becomes stripped text parts.

    Fix: the original also appended the empty/whitespace-only strings that
    ``re.split`` yields around adjacent, leading, or trailing tags as
    ``{"type": "text", "text": ""}`` parts; those carry no information and
    are now skipped.
    """
    logger.debug(f"{message['files']=}")
    parts = re.split(r"(<image>)", message["text"])
    logger.debug(f"{parts=}")

    content = []
    image_index = 0
    for part in parts:
        logger.debug(f"{part=}")
        if part == "<image>":
            content.append({"type": "image", "url": message["files"][image_index]})
            logger.debug(f"file: {message['files'][image_index]}")
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        # else: empty/whitespace-only split artifact — drop it.
    logger.debug(f"{content=}")
    return content
128
+
129
+
130
def process_new_user_message(message: dict) -> list[dict]:
    """Translate a new gradio multimodal message into chat-content parts."""
    files = message["files"]
    text = message["text"]

    # Text-only message.
    if not files:
        return [{"type": "text", "text": text}]

    # A single mp4 upload: prompt text followed by sampled video frames.
    if files[0].endswith(".mp4"):
        return [{"type": "text", "text": text}, *process_video(files[0])]

    # Inline image placement via <image> tags inside the text.
    if "<image>" in text:
        return process_interleaved_images(message)

    # Default: the prompt text first, then every uploaded image in order.
    image_parts = [{"type": "image", "url": path} for path in files]
    return [{"type": "text", "text": text}, *image_parts]
144
+
145
+
146
def process_history(history: list[dict]) -> list[dict]:
    """Convert gradio chat history into the transformers chat-message format.

    Consecutive user items (text strings and file-upload tuples) are merged
    into a single user turn; each assistant item closes the pending user turn.

    NOTE(review): user content at the tail of the history, with no assistant
    reply after it, is dropped — presumably because the live message is
    appended separately by the caller; confirm before reusing elsewhere.
    """
    messages: list[dict] = []
    pending_user: list[dict] = []

    for entry in history:
        role, content = entry["role"], entry["content"]
        if role == "assistant":
            if pending_user:
                messages.append({"role": "user", "content": pending_user})
                pending_user = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": content}]})
            continue
        if isinstance(content, str):
            pending_user.append({"type": "text", "text": content})
        else:
            # File uploads arrive as a tuple/list whose first element is the path.
            pending_user.append({"type": "image", "url": content[0]})
    return messages
162
+
163
+
164
@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a model reply for the new message plus prior chat history.

    Yields the accumulated response text after each decoded token. Yields a
    single empty string and stops when the media constraints are violated.
    """
    if not validate_media_constraints(message, history):
        yield ""
        return

    # Assemble the full conversation: optional system turn, history, new turn.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    # Generate on a worker thread; the streamer feeds tokens back to this one.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generation_thread = Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens},
    )
    generation_thread.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output
197
+
198
+
199
# Clickable example prompts for the ChatInterface: each entry is one message
# dict with its prompt text and the asset files it references.
examples = [
    [
        {
            "text": "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations.",
            "files": [],
        }
    ],
    [
        {
            "text": "Write the matplotlib code to generate the same bar chart.",
            "files": ["assets/additional-examples/barchart.png"],
        }
    ],
    [
        {
            "text": "What is odd about this video?",
            "files": ["assets/additional-examples/tmp.mp4"],
        }
    ],
    [
        {
            "text": "I already have this supplement <image> and I want to buy this one <image>. Any warnings I should know about?",
            "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
        }
    ],
    [
        {
            "text": "Write a poem inspired by the visual elements of the images.",
            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
        }
    ],
    [
        {
            "text": "Compose a short musical piece inspired by the visual elements of the images.",
            "files": [
                "assets/sample-images/07-1.png",
                "assets/sample-images/07-2.png",
                "assets/sample-images/07-3.png",
                "assets/sample-images/07-4.png",
            ],
        }
    ],
    [
        {
            "text": "Write a short story about what might have happened in this house.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "Create a short story based on the sequence of images.",
            "files": [
                "assets/sample-images/09-1.png",
                "assets/sample-images/09-2.png",
                "assets/sample-images/09-3.png",
                "assets/sample-images/09-4.png",
                "assets/sample-images/09-5.png",
            ],
        }
    ],
    [
        {
            "text": "Describe the creatures that would live in this world.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "Read text in the image.",
            "files": ["assets/additional-examples/1.png"],
        }
    ],
    [
        {
            "text": "When is this ticket dated and how much did it cost?",
            "files": ["assets/additional-examples/2.png"],
        }
    ],
    [
        {
            "text": "Read the text in the image into markdown.",
            "files": ["assets/additional-examples/3.png"],
        }
    ],
    [
        {
            "text": "Evaluate this integral.",
            "files": ["assets/additional-examples/4.png"],
        }
    ],
    [
        {
            "text": "caption this image",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            "text": "What's the sign says?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "Compare and contrast the two images.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "List all the objects in the image and their colors.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
    [
        {
            "text": "Describe the atmosphere of the scene.",
            "files": ["assets/sample-images/05.png"],
        }
    ],
]
320
+
321
# Markdown/HTML shown above the chat: logo plus a short feature summary.
DESCRIPTION = """\
<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />

This is a demo of Gemma 3 27B it, a vision language model with outstanding performance on a wide range of tasks.
You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input.
"""

# Chat UI wiring: streams run()'s output; accepts images and .mp4 alongside text.
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
    textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True),
    multimodal=True,
    additional_inputs=[
        # Extra controls mapped to run()'s system_prompt / max_new_tokens params.
        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    stop_btn=False,
    title="Gemma 3 27B IT",
    description=DESCRIPTION,
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),  # (frequency, age) in seconds for cached-file cleanup
)

if __name__ == "__main__":
    demo.launch()