prithivMLmods committed · verified
Commit 231ec6e
1 Parent(s): b851f94

Update app.py

Files changed (1)
  1. app.py +172 -0
app.py CHANGED
@@ -0,0 +1,172 @@
+ import gradio as gr
+ from transformers.image_utils import load_image
+ from threading import Thread
+ import time
+ import torch
+ import spaces
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from transformers import (
+     Qwen2VLForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
+     AutoProcessor,
+     TextIteratorStreamer,
+ )
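+
+ # This Space serves two vision-language models behind one chat UI: the coreOCR
+ # preview checkpoint (default) and Qwen2.5-VL-32B-Instruct, switched by a checkbox.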
+
+ # Helper Functions
+ def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
+     """
+     Returns an HTML snippet for a thin animated progress bar with a label.
+     Colors can be customized; the purple defaults are shared by both models.
+     """
+     return f'''
+ <div style="display: flex; align-items: center;">
+     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+     <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
+         <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
+     </div>
+ </div>
+ <style>
+ @keyframes loading {{
+     0% {{ transform: translateX(-100%); }}
+     100% {{ transform: translateX(100%); }}
+ }}
+ </style>
+ '''
+
+ def downsample_video(video_path):
+     """
+     Downsamples a video file by extracting 10 evenly spaced frames.
+     Returns a list of (PIL.Image, timestamp) tuples.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     if total_frames <= 0 or fps <= 0:
+         vidcap.release()
+         return frames
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
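+
+ # For reference, downsample_video("clip.mp4") returns something like
+ # [(<PIL.Image>, 0.0), (<PIL.Image>, 1.25), ...] with at most 10 entries.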
+
+ # Model and Processor Setup
+ QV_MODEL_ID = "Qwen/Qwen2.5-VL-32B-Instruct"
+ qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
+ qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     QV_MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.float16,
+ ).to("cuda").eval()
+
+ COREOCR_MODEL_ID = "prithivMLmods/coreOCR-7B-050325-preview"
+ coreocr_processor = AutoProcessor.from_pretrained(COREOCR_MODEL_ID, trust_remote_code=True)
+ coreocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
+     COREOCR_MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+ ).to("cuda").eval()
+
+ # Main Inference Function
+ @spaces.GPU
+ def model_inference(message, history, use_coreocr):
+     text = message["text"].strip()
+     files = message.get("files", [])
+
+     if not text and not files:
+         yield "Error: Please enter a text query or provide image or video files."
+         return
+
+     # Process files: images and videos
+     image_list = []
+     for idx, file in enumerate(files):
+         if file.lower().endswith((".mp4", ".avi", ".mov")):
+             frames = downsample_video(file)
+             if not frames:
+                 yield "Error: Could not extract frames from the video."
+                 return
+             for frame, timestamp in frames:
+                 label = f"Video {idx+1} Frame {timestamp}:"
+                 image_list.append((label, frame))
+         else:
+             try:
+                 img = load_image(file)
+                 label = f"Image {idx+1}:"
+                 image_list.append((label, img))
+             except Exception as e:
+                 yield f"Error loading image: {str(e)}"
+                 return
+
+     # Build the multimodal content list: the query text first, then a text
+     # label followed by the image entry for each collected frame/image.
+     content = [{"type": "text", "text": text}]
+     for label, img in image_list:
+         content.append({"type": "text", "text": label})
+         content.append({"type": "image", "image": img})
+
+     messages = [{"role": "user", "content": content}]
+
+     # Select processor and model
+     if use_coreocr:
+         processor = coreocr_processor
+         model = coreocr_model
+         model_name = "CoreOCR"
+     else:
+         processor = qwen_processor
+         model = qwen_model
+         model_name = "Qwen2.5-VL"
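+
+     # Both processors expose a Qwen-style chat template, so the same
+     # apply_chat_template call below works for either model (an assumption
+     # based on both being Qwen2-VL-family checkpoints).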
+
+     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     all_images = [item["image"] for item in content if item["type"] == "image"]
+     inputs = processor(
+         text=[prompt_full],
+         images=all_images if all_images else None,
+         return_tensors="pt",
+         padding=True,
+     ).to("cuda")
+
+     # Stream tokens: generation runs on a background thread while the
+     # streamer yields decoded text incrementally to the chat UI.
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     yield progress_bar_html(f"Processing with {model_name}")
+     for new_text in streamer:
+         buffer += new_text
+         buffer = buffer.replace("<|im_end|>", "")
+         time.sleep(0.01)
+         yield buffer
+     thread.join()  # make sure generation has finished before returning
+
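+ # For reference, gr.ChatInterface(multimodal=True) passes `message` as a dict
+ # of the form {"text": "...", "files": ["/path/to/file", ...]}.
+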
+ # Gradio Interface
+ examples = [
+     [{"text": "OCR the text in the image", "files": ["example/image1.jpg"]}],
+     [{"text": "Describe the content of the image", "files": ["example/image2.jpg"]}],
+     [{"text": "Extract the image content", "files": ["example/image3.jpg"]}],
+ ]
+
+ demo = gr.ChatInterface(
+     fn=model_inference,
+     description="# **CoreOCR `VL/OCR`**",
+     examples=examples,
+     textbox=gr.MultimodalTextbox(
+         label="Query Input",
+         file_types=["image", "video"],
+         file_count="multiple",
+         placeholder="Enter your query and optionally upload image(s) or video(s). Select the model with the checkbox.",
+     ),
+     stop_btn="Stop Generation",
+     multimodal=True,
+     cache_examples=False,
+     theme="bethecloud/storj_theme",
+     additional_inputs=[gr.Checkbox(label="Use CoreOCR", value=True, info="Check to use CoreOCR; uncheck to use Qwen2.5-VL.")],
+ )
+
+ demo.launch(debug=True, ssr_mode=False)