prithivMLmods committed
Commit 0a7677c · verified · 1 Parent(s): e1ef66b

removed video inference settings

Files changed (1)
  1. app.py +17 -142
app.py CHANGED
@@ -3,7 +3,6 @@ import random
 import uuid
 import json
 import time
-import asyncio
 from threading import Thread
 from typing import Iterable
 
@@ -12,7 +11,6 @@ import spaces
 import torch
 import numpy as np
 from PIL import Image
-import cv2
 
 from transformers import (
     Qwen2VLForConditionalGeneration,
@@ -158,7 +156,6 @@ div.no-padding { padding: 0 !important; }
 """
 
 # Constants for text generation
-MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 # Increased max_length to accommodate more complex inputs, especially with multiple images
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
@@ -208,7 +205,7 @@ model_a = AutoModelForImageTextToText.from_pretrained(
 MODEL_ID_W = "allenai/olmOCR-7B-0725"
 processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
 model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_W,
+    MODEL_ID_W,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
@@ -222,35 +219,9 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-def downsample_video(video_path):
-    """
-    Downsamples the video to evenly spaced frames.
-    Each frame is returned as a PIL image along with its timestamp.
-    """
-    vidcap = cv2.VideoCapture(video_path)
-    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
-    frames = []
-    # Use a maximum of 10 frames to avoid excessive memory usage
-    frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
-    for i in frame_indices:
-        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, image = vidcap.read()
-        if success:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
-            frames.append((pil_image, timestamp))
-    vidcap.release()
-    return frames
 
 @spaces.GPU
-def generate_image(model_name: str, text: str, image: Image.Image,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
+def generate_image(model_name: str, text: str, image: Image.Image):
     """
     Generates responses using the selected model for image input.
     Yields raw text and Markdown-formatted text.
@@ -286,8 +257,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         ]
     }]
     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    # FIX: Set truncation to False to avoid the ValueError
+
     inputs = processor(
         text=[prompt_full],
         images=[image],
@@ -296,7 +266,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
     ).to(device)
 
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": DEFAULT_MAX_NEW_TOKENS}
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
@@ -306,138 +276,43 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         time.sleep(0.01)
         yield buffer, buffer
 
-@spaces.GPU
-def generate_video(model_name: str, text: str, video_path: str,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
-    """
-    Generates responses using the selected model for video input.
-    Yields raw text and Markdown-formatted text.
-    """
-    if model_name == "RolmOCR-7B":
-        processor = processor_m
-        model = model_m
-    elif model_name == "Qwen2-VL-OCR-2B":
-        processor = processor_x
-        model = model_x
-    elif model_name == "Nanonets-OCR2-3B":
-        processor = processor_v
-        model = model_v
-    elif model_name == "Aya-Vision-8B":
-        processor = processor_a
-        model = model_a
-    elif model_name == "olmOCR-7B-0725":
-        processor = processor_w
-        model = model_w
-    else:
-        yield "Invalid model selected.", "Invalid model selected."
-        return
-
-    if video_path is None:
-        yield "Please upload a video.", "Please upload a video."
-        return
-
-    frames_with_ts = downsample_video(video_path)
-    images_for_processor = [frame for frame, ts in frames_with_ts]
 
-    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
-    for frame in images_for_processor:
-        messages[0]["content"].insert(0, {"type": "image"})
-
-    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    inputs = processor(
-        text=[prompt_full],
-        images=images_for_processor,
-        return_tensors="pt",
-        padding=True
-    ).to(device)
-
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "do_sample": True,
-        "temperature": temperature,
-        "top_p": top_p,
-        "top_k": top_k,
-        "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
-        yield buffer, buffer
-
-# Define examples for image and video inference
+# Define examples for image inference
 image_examples = [
-    ["Extract the full page.", "images/ocr.png"],
-    ["Extract the content.", "images/4.png"],
+    ["Extract the full page.", "images/ocr.png"],
+    ["Extract the content.", "images/4.png"],
     ["Convert this page to doc [table] precisely for markdown.", "images/0.png"]
 ]
 
-video_examples = [
-    ["Explain the Ad in Detail.", "videos/1.mp4"],
-]
 
 # Create the Gradio Interface
 with gr.Blocks(css=css, theme=thistle_theme) as demo:
     gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
     with gr.Row():
         with gr.Column(scale=2):
-            with gr.Tabs():
-                with gr.TabItem("Image Inference"):
-                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    image_upload = gr.Image(type="pil", label="Upload Image", height=290)
-                    image_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(
-                        examples=image_examples,
-                        inputs=[image_query, image_upload]
-                    )
-                with gr.TabItem("Video Inference"):
-                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    video_upload = gr.Video(label="Upload Video", height=290)
-                    video_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(
-                        examples=video_examples,
-                        inputs=[video_query, video_upload]
-                    )
-            gr.Markdown("> Only the olmOCR and RolmOCR models currently support video inference (max video length: 30 secs).")
-            with gr.Accordion("Advanced options", open=False):
-                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
+            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
+            image_submit = gr.Button("Submit", variant="primary")
+            gr.Examples(
+                examples=image_examples,
+                inputs=[image_query, image_upload]
+            )
         with gr.Column(scale=3):
             gr.Markdown("## Output", elem_id="output-title")
             output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
             with gr.Accordion("(Result.md)", open=False):
                 markdown_output = gr.Markdown(label="(Result.Md)")
-
+
             model_choice = gr.Radio(
                 choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
                          "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
                 label="Select Model",
                 value="Nanonets-OCR2-3B"
             )
-
+
             image_submit.click(
                 fn=generate_image,
-                inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                outputs=[output, markdown_output]
-            )
-            video_submit.click(
-                fn=generate_video,
-                inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+                inputs=[model_choice, image_query, image_upload],
                 outputs=[output, markdown_output]
             )
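For context, here is a minimal, runnable sketch (not part of this commit) of the image-only interface that remains after the change. A stub generator stands in for the real model call so the layout and the simplified click wiring, `inputs=[model_choice, image_query, image_upload]`, can be exercised without loading any weights; the css/theme, `gr.Examples` block, and `@spaces.GPU` decorator from app.py are omitted.

```python
import time
from typing import Iterator, Tuple

import gradio as gr
from PIL import Image

# Mirrors the constant in app.py; generate_image now uses it directly for
# max_new_tokens instead of reading the value from a slider.
DEFAULT_MAX_NEW_TOKENS = 1024


def generate_image(model_name: str, text: str, image: Image.Image) -> Iterator[Tuple[str, str]]:
    """Stub for the real streaming generator: echoes the query word by word."""
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
    buffer = ""
    for word in f"[{model_name}] {text}".split():
        buffer += word + " "
        time.sleep(0.05)
        yield buffer, buffer  # raw stream and Markdown view receive the same text


with gr.Blocks() as demo:
    gr.Markdown("# **Multimodal OCR**")
    with gr.Row():
        with gr.Column(scale=2):
            # The image tab's widgets now sit directly in the column; the
            # Tabs/TabItem wrapper and the video tab are gone.
            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
            image_submit = gr.Button("Submit", variant="primary")
        with gr.Column(scale=3):
            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11)
            markdown_output = gr.Markdown()
            model_choice = gr.Radio(
                choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
                         "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
                label="Select Model",
                value="Nanonets-OCR2-3B",
            )
    # Only three inputs are wired now; the sampling sliders were removed in this commit.
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload],
        outputs=[output, markdown_output],
    )

if __name__ == "__main__":
    demo.launch()
```

With the Advanced-options sliders gone, the decoding length is pinned to DEFAULT_MAX_NEW_TOKENS inside generate_image rather than passed in from the UI, and sampling falls back to each model's generation defaults.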