import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load public OCR models
MODEL_ID_V = "nanonets/Nanonets-OCR-s"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device).eval()

MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID_X, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device).eval()

MODEL_ID_M = "reducto/RolmOCR"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device).eval()

MODEL_ID_W = "prithivMLmods/Lh41-1042-Magellanic-7B-0711"
processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_W, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device).eval()
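# All four checkpoints are Qwen2-VL-family models, so they share the same
# chat-template / processor interface used by the generation helpers below.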
def downsample_video(video_path):
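    """Sample 10 evenly spaced frames from a video.

    Returns a list of (PIL.Image, timestamp_seconds) tuples; timestamps are
    derived from the frame index and the container's reported FPS.
    """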
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back when FPS metadata is missing
    frames = []
    for i in np.linspace(0, total - 1, 10, dtype=int):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, img = vidcap.read()
        if ok:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV is BGR; PIL expects RGB
            frames.append((Image.fromarray(img), round(i / fps, 2)))
    vidcap.release()
    return frames
@spaces.GPU  # ZeroGPU Spaces allocate a GPU only inside functions decorated like this
def generate_image(model_name, text, image, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
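    """Stream an answer for a single image + prompt from the selected model,
    yielding (raw_text, markdown_text) pairs as tokens arrive.
    """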
    mapping = {
        "Nanonets-OCR-s": (processor_v, model_v),
        "Qwen2-VL-OCR-2B": (processor_x, model_x),
        "RolmOCR-7B": (processor_m, model_m),
        "Lh41-1042-Magellanic-7B-0711": (processor_w, model_w),
    }
    if model_name not in mapping:
        yield "Invalid model selected.", "Invalid model."
        return
    processor, model = mapping[model_name]
    if image is None:
        yield "Please upload an image.", ""
        return
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text}]}]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to(device)
    # Stream tokens from a background thread so the UI can update incrementally.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    out = ""
    for token in streamer:
        out += token.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield out, out
@spaces.GPU
def generate_video(model_name, text, video_path, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
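    """Stream an answer for a prompt over 10 sampled video frames,
    yielding (raw_text, markdown_text) pairs as tokens arrive.
    """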
    mapping = {
        "Nanonets-OCR-s": (processor_v, model_v),
        "Qwen2-VL-OCR-2B": (processor_x, model_x),
        "RolmOCR-7B": (processor_m, model_m),
        "Lh41-1042-Magellanic-7B-0711": (processor_w, model_w),
    }
    if model_name not in mapping:
        yield "Invalid model selected.", "Invalid model."
        return
    processor, model = mapping[model_name]
    if video_path is None:
        yield "Please upload a video.", ""
        return
    frames = downsample_video(video_path)
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": text}]},
    ]
    # Interleave each sampled frame with its timestamp in the user turn.
    for img, ts in frames:
        messages[1]["content"].append({"type": "text", "text": f"Frame {ts}:"})
        messages[1]["content"].append({"type": "image", "image": img})
    # return_dict=True is required so the result can be unpacked into generate() kwargs.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    out = ""
    for token in streamer:
        out += token.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield out, out
# Examples
image_examples = [
    ["Extract the content", "images/4.png"],
    ["Explain the scene", "images/3.jpg"],
    ["Perform OCR on the image", "images/1.jpg"],
]
video_examples = [
    ["Explain the Ad in Detail", "videos/1.mp4"],
]
css = """
.submit-btn { background-color: #2980b9 !important; color: white !important; }
.submit-btn:hover { background-color: #3498db !important; }
.canvas-output { border: 2px solid #4682B4; border-radius: 10px; padding: 20px; }
"""
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **Multimodal OCR**")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    img_q = gr.Textbox(label="Query Input", placeholder="Enter prompt")
                    img_up = gr.Image(type="pil", label="Upload Image")
                    img_btn = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(examples=image_examples, inputs=[img_q, img_up])
                with gr.TabItem("Video Inference"):
                    vid_q = gr.Textbox(label="Query Input")
                    vid_up = gr.Video(label="Upload Video")
                    vid_btn = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(examples=video_examples, inputs=[vid_q, vid_up])
        with gr.Column(elem_classes="canvas-output"):
            gr.Markdown("## Output")
            out_raw = gr.Textbox(interactive=False, lines=2, show_copy_button=True)
            with gr.Accordion("Formatted Output", open=False):
                out_md = gr.Markdown()
    model_choice = gr.Radio(
        choices=["Nanonets-OCR-s", "Qwen2-VL-OCR-2B", "RolmOCR-7B", "Lh41-1042-Magellanic-7B-0711"],
        label="Select Model",
        value="Nanonets-OCR-s",
    )
    # Shared decoding controls, defined once and reused by both handlers.
    max_new_tokens = gr.Slider(1, MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max new tokens")
    temperature = gr.Slider(0.1, 4.0, value=0.6, step=0.1, label="Temperature")
    top_p = gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    top_k = gr.Slider(1, 1000, value=50, step=1, label="Top-k")
    repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.05, label="Repetition penalty")

    img_btn.click(
        generate_image,
        inputs=[model_choice, img_q, img_up, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[out_raw, out_md],
    )
    vid_btn.click(
        generate_video,
        inputs=[model_choice, vid_q, vid_up, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[out_raw, out_md],
    )
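
# The queue bounds the number of pending requests; mcp_server=True additionally exposes
# the handlers as MCP tools (assumes a Gradio release with MCP support installed).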
if __name__ == "__main__":
    demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)