import gradio as gr
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
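
# UI helper. Gradio's chatbot renders HTML in chat messages, so yielding this
# snippet displays an animated "working" bar until the first streamed tokens
# replace it.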
def progress_bar_html(label: str, primary_color: str = "#FF4500", secondary_color: str = "#FFA07A") -> str:
    """
    Returns an HTML snippet for a thin animated progress bar with a label.
    Colors can be customized; the defaults give the orange bar used by this app.
    """
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''
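
# Video preprocessing: sample a fixed number of frames so arbitrarily long
# clips fit in the model's context window.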
def downsample_video(video_path):
    """
    Downsamples a video file by extracting 10 evenly spaced frames.
    Returns a list of (PIL.Image, timestamp-in-seconds) tuples; the list is
    empty if the video cannot be read.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            # OpenCV decodes to BGR; convert to RGB before building the PIL image.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
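
# Illustrative return shape (timestamps depend on the clip's fps):
#   downsample_video("clip.mp4") -> [(<PIL.Image>, 0.0), (<PIL.Image>, 1.37), ...]

# Load both OCR checkpoints up front and keep them resident on the GPU.
# At fp16/bf16 precision, two 7B models take roughly 28 GB of VRAM for the
# weights alone, so this assumes a large-memory GPU.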
# coreOCR (Qwen2-VL backbone), loaded in fp16.
COREOCR_MODEL_ID = "prithivMLmods/coreOCR-7B-050325-preview"
coreocr_processor = AutoProcessor.from_pretrained(COREOCR_MODEL_ID, trust_remote_code=True)
coreocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    COREOCR_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()

# docscopeOCR (Qwen2.5-VL backbone), loaded in bf16.
DOCSCOPE_MODEL_ID = "prithivMLmods/docscopeOCR-7B-050425-exp"
docscope_processor = AutoProcessor.from_pretrained(DOCSCOPE_MODEL_ID, trust_remote_code=True)
docscope_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    DOCSCOPE_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()
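
# Chat handler. @spaces.GPU requests a GPU on Hugging Face Spaces for the
# duration of each call, and torch.no_grad() disables gradient tracking
# during inference.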
@spaces.GPU
@torch.no_grad()
def model_inference(message, history, use_coreocr):
    text = message["text"].strip()
    files = message.get("files", [])

    if not text and not files:
        yield "Error: Please input a text query or provide image or video files."
        return

    # Collect (label, PIL.Image) pairs from the uploads; each video is
    # expanded into 10 labeled, timestamped frames.
    image_list = []
    for idx, file in enumerate(files):
        if file.lower().endswith((".mp4", ".avi", ".mov")):
            frames = downsample_video(file)
            if not frames:
                yield "Error: Could not extract frames from the video."
                return
            for frame, timestamp in frames:
                label = f"Video {idx+1} Frame {timestamp}:"
                image_list.append((label, frame))
        else:
            try:
                img = load_image(file)
                label = f"Image {idx+1}:"
                image_list.append((label, img))
            except Exception as e:
                yield f"Error loading image: {str(e)}"
                return

    # Interleave the text query with each labeled image in a single user turn.
    content = [{"type": "text", "text": text}]
    for label, img in image_list:
        content.append({"type": "text", "text": label})
        content.append({"type": "image", "image": img})

    messages = [{"role": "user", "content": content}]
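
    # Illustrative shape for a single image upload:
    #   [{"role": "user", "content": [
    #       {"type": "text", "text": "Explain the scene"},
    #       {"type": "text", "text": "Image 1:"},
    #       {"type": "image", "image": <PIL.Image>},
    #   ]}]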

    # Select the checkpoint requested via the "Use CoreOCR" checkbox.
    if use_coreocr:
        processor = coreocr_processor
        model = coreocr_model
        model_name = "CoreOCR"
    else:
        processor = docscope_processor
        model = docscope_model
        model_name = "DocScopeOCR"
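
    # apply_chat_template turns the structured turn into the model's prompt
    # format (for Qwen-style templates, <|im_start|>/<|im_end|> markers with
    # image placeholder tokens); the processor call below then pairs those
    # placeholders with the actual PIL images.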
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    all_images = [item["image"] for item in content if item["type"] == "image"]
    inputs = processor(
        text=[prompt_full],
        images=all_images if all_images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # Run generation on a background thread so tokens can be streamed back
    # to the UI as they arrive.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield progress_bar_html(f"Processing with {model_name}")
    for new_text in streamer:
        buffer += new_text
        # Drop any end-of-turn marker that slips past skip_special_tokens.
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer
    thread.join()
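
# Example prompts surfaced in the UI; the referenced files are assumed to
# live in an example/ directory next to this script.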
examples = [
    [{"text": "Validate the worksheet answers", "files": ["example/image1.png"]}],
    [{"text": "Explain the scene", "files": ["example/image2.jpg"]}],
    [{"text": "Fill the correct numbers", "files": ["example/image3.png"]}],
]

demo = gr.ChatInterface(
    fn=model_inference,
    description="# **CoreOCR `VL/OCR`**",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Input your query and optionally upload image(s) or video(s). Select the model using the checkbox.",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    theme="bethecloud/storj_theme",
    additional_inputs=[gr.Checkbox(label="Use CoreOCR", value=True, info="Check to use CoreOCR, uncheck to use DocScopeOCR")],
)

# debug=True prints tracebacks to the console; ssr_mode=False disables
# Gradio's server-side rendering.
demo.launch(debug=True, ssr_mode=False)