prithivMLmods committed
Commit e2863bc · verified · 1 Parent(s): b45b66d

Update app.py

Files changed (1)
  1. app.py +419 -419
app.py CHANGED
@@ -1,420 +1,420 @@
+ import os
+ import random
+ import uuid
+ import json
+ import time
+ import asyncio
+ from threading import Thread
+ import base64
+ from io import BytesIO
+ import re
+
+ import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import cv2
+
+ from transformers import (
+     Qwen2VLForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
+     AutoProcessor,
+     TextIteratorStreamer,
+ )
+ from qwen_vl_utils import process_vision_info
+
+ # Constants for text generation
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load Camel-Doc-OCR-062825
+ MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_M,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load ViLaSR-7B
+ MODEL_ID_X = "AntResearchNLP/ViLaSR"
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_X,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load OCRFlux-3B
+ MODEL_ID_T = "ChatDOC/OCRFlux-3B"
+ processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
+ model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_T,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load ShotVL-7B
+ MODEL_ID_S = "Vchitect/ShotVL-7B"
+ processor_s = AutoProcessor.from_pretrained(MODEL_ID_S, trust_remote_code=True)
+ model_s = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_S,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Helper functions for object detection
+ def image_to_base64(image):
+     """Convert a PIL image to a base64-encoded string."""
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+
+ def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+     """Draw bounding boxes on an image."""
+     draw = ImageDraw.Draw(image)
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
+     return image
+
+ def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
+     """Rescale bounding boxes from normalized (1000x1000) to original image dimensions."""
+     x_scale = original_width / scaled_width
+     y_scale = original_height / scaled_height
+     rescaled_boxes = []
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         rescaled_box = [
+             xmin * x_scale,
+             ymin * y_scale,
+             xmax * x_scale,
+             ymax * y_scale
+         ]
+         rescaled_boxes.append(rescaled_box)
+     return rescaled_boxes
+
+ # Default system prompt for object detection
+ default_system_prompt = (
+     "You are a helpful assistant to detect objects in images. When asked to detect elements based on a description, "
+     "you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled "
+     "to 512 by 512 pixels. When there are more than one result, answer with a list of bounding boxes in the form "
+     "of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."
-     "Parse only the boxes; don't write unnecessary content."
+     "Parse only the boxes; don't write unnecessary content. Follow this command strictly at all times."
+ )
+
+ # Function for object detection
+ @spaces.GPU
+ def run_example(image, text_input, system_prompt):
+     """Detect objects in an image and return bounding box annotations."""
+     model = model_x
+     processor = processor_x
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": f"data:image;base64,{image_to_base64(image)}"},
+                 {"type": "text", "text": system_prompt},
+                 {"type": "text", "text": text_input},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to("cuda")
+
+     generated_ids = model.generate(**inputs, max_new_tokens=256)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+     pattern = r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]'
+     matches = re.findall(pattern, str(output_text))
+     parsed_boxes = [[int(num) for num in match] for match in matches]
+     scaled_boxes = rescale_bounding_boxes(parsed_boxes, image.width, image.height)
+     annotated_image = draw_bounding_boxes(image.copy(), scaled_boxes)
+     return output_text[0], str(parsed_boxes), annotated_image
+
+ def downsample_video(video_path):
+     """
+     Downsample a video to evenly spaced frames, returning each as a PIL image with its timestamp.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ @spaces.GPU
+ def generate_image(model_name: str, text: str, image: Image.Image,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generate responses using the selected model for image input.
+     """
+     if model_name == "Camel-Doc-OCR-062825":
+         processor = processor_m
+         model = model_m
+     elif model_name == "ViLaSR-7B":
+         processor = processor_x
+         model = model_x
+     elif model_name == "OCRFlux-3B":
+         processor = processor_t
+         model = model_t
+     elif model_name == "ShotVL-7B":
+         processor = processor_s
+         model = model_s
+     else:
+         yield "Invalid model selected.", "Invalid model selected."
+         return
+
+     if image is None:
+         yield "Please upload an image.", "Please upload an image."
+         return
+
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image},
+             {"type": "text", "text": text},
+         ]
+     }]
+     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(
+         text=[prompt_full],
+         images=[image],
+         return_tensors="pt",
+         padding=True,
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         time.sleep(0.01)
+         yield buffer, buffer
+
+ @spaces.GPU
+ def generate_video(model_name: str, text: str, video_path: str,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generate responses using the selected model for video input.
+     """
+     if model_name == "Camel-Doc-OCR-062825":
+         processor = processor_m
+         model = model_m
+     elif model_name == "ViLaSR-7B":
+         processor = processor_x
+         model = model_x
+     elif model_name == "OCRFlux-3B":
+         processor = processor_t
+         model = model_t
+     elif model_name == "ShotVL-7B":
+         processor = processor_s
+         model = model_s
+     else:
+         yield "Invalid model selected.", "Invalid model selected."
+         return
+
+     if video_path is None:
+         yield "Please upload a video.", "Please upload a video."
+         return
+
+     frames = downsample_video(video_path)
+     messages = [
+         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+         {"role": "user", "content": [{"type": "text", "text": text}]}
+     ]
+     for frame in frames:
+         image, timestamp = frame
+         messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+         messages[1]["content"].append({"type": "image", "image": image})
+     inputs = processor.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,
+         return_dict=True,
+         return_tensors="pt",
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+     }
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         buffer = buffer.replace("<|im_end|>", "")
+         time.sleep(0.01)
+         yield buffer, buffer
+
+ # Define examples for image, video, and object detection inference
+ image_examples = [
+     ["convert this page to doc [text] precisely for markdown.", "images/1.png"],
+     ["convert this page to doc [table] precisely for markdown.", "images/2.png"],
+     ["explain the movie shot in detail.", "images/3.png"],
+     ["fill the correct numbers.", "images/4.png"]
+ ]
+
+ video_examples = [
+     ["explain the ad video in detail.", "videos/1.mp4"],
+     ["explain the video in detail.", "videos/2.mp4"]
+ ]
+
+ object_detection_examples = [
+     ["object/1.png", "detect red and yellow cars."],
+     ["object/2.png", "detect the white cat."]
+ ]
+
+ # Added CSS to style the output area as a "Canvas"
+ css = """
+ .submit-btn {
+     background-color: #2980b9 !important;
+     color: white !important;
+ }
+ .submit-btn:hover {
+     background-color: #3498db !important;
+ }
+ .canvas-output {
+     border: 2px solid #4682B4;
+     border-radius: 10px;
+     padding: 20px;
+ }
+ """
+
+ # Create the Gradio Interface
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+     gr.Markdown("# **[Doc VLMs v2 [Localization]](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
+     with gr.Row():
+         with gr.Column():
+             with gr.Tabs():
+                 with gr.TabItem("Image Inference"):
+                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                     image_upload = gr.Image(type="pil", label="Image")
+                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
+                     gr.Examples(
+                         examples=image_examples,
+                         inputs=[image_query, image_upload]
+                     )
+                 with gr.TabItem("Video Inference"):
+                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                     video_upload = gr.Video(label="Video")
+                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
+                     gr.Examples(
+                         examples=video_examples,
+                         inputs=[video_query, video_upload]
+                     )
+                 with gr.TabItem("Object Detection / Localization"):
+                     with gr.Row():
+                         with gr.Column():
+                             input_img = gr.Image(label="Input Image", type="pil")
+                             system_prompt = gr.Textbox(label="System Prompt", value=default_system_prompt, visible=False)
+                             text_input = gr.Textbox(label="Query Input")
+                             submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
+                         with gr.Column():
+                             model_output_text = gr.Textbox(label="Model Output Text")
+                             parsed_boxes = gr.Textbox(label="Parsed Boxes")
+                             annotated_image = gr.Image(label="Annotated Image")
+
+                     gr.Examples(
+                         examples=object_detection_examples,
+                         inputs=[input_img, text_input],
+                         outputs=[model_output_text, parsed_boxes, annotated_image],
+                         fn=run_example,
+                         cache_examples=True,
+                     )
+
+                     submit_btn.click(
+                         fn=run_example,
+                         inputs=[input_img, text_input, system_prompt],
+                         outputs=[model_output_text, parsed_boxes, annotated_image]
+                     )
+
+             with gr.Accordion("Advanced options", open=False):
+                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+
+         with gr.Column():
+             with gr.Column(elem_classes="canvas-output"):
+                 gr.Markdown("## Result.Md")
+                 output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
+                 markdown_output = gr.Markdown(label="Formatted Result (Result.Md)")
+
+             model_choice = gr.Radio(
+                 choices=["Camel-Doc-OCR-062825", "OCRFlux-3B", "ShotVL-7B", "ViLaSR-7B"],
+                 label="Select Model",
+                 value="Camel-Doc-OCR-062825"
+             )
+
+             gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Doc-VLMs-v2-Localization/discussions)")
+             gr.Markdown("> [Camel-Doc-OCR-062825](https://huggingface.co/prithivMLmods/Camel-Doc-OCR-062825) : camel-doc-ocr-062825 model is a fine-tuned version of qwen2.5-vl-7b-instruct, optimized for document retrieval, content extraction, and analysis recognition. built on top of the qwen2.5-vl architecture, this model enhances document comprehension capabilities.")
+             gr.Markdown("> [OCRFlux-3B](https://huggingface.co/ChatDOC/OCRFlux-3B) : ocrflux-3b model that's fine-tuned from qwen2.5-vl-3b-instruct using our private document datasets and some data from olmocr-mix-0225 dataset. optimized for document retrieval, content extraction, and analysis recognition. the best way to use this model is via the ocrflux toolkit.")
+             gr.Markdown("> [ViLaSR](https://huggingface.co/AntResearchNLP/ViLaSR) : vilasr-7b model as presented in reinforcing spatial reasoning in vision-language models with interwoven thinking and visual drawing. efficient reasoning capabilities.")
+             gr.Markdown("> [ShotVL-7B](https://huggingface.co/Vchitect/ShotVL-7B) : shotvl-7b is a fine-tuned version of qwen2.5-vl-7b-instruct, trained by supervised fine-tuning on the largest and high-quality dataset for cinematic language understanding to date. it currently achieves state-of-the-art performance on shotbench.")
+             gr.Markdown(">⚠️note: all the models in space are not guaranteed to perform well in video inference use cases.")
+
+     image_submit.click(
+         fn=generate_image,
+         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=[output, markdown_output]
+     )
+     video_submit.click(
+         fn=generate_video,
+         inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=[output, markdown_output]
+     )
+
+ if __name__ == "__main__":
      demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)