# DocScope-R1 / app.py
import os
import uuid
import time
import asyncio
from threading import Thread
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
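# Assumed runtime dependencies (not pinned in this file): gradio, spaces, torch,
# numpy, pillow, opencv-python, and a transformers release with Qwen2.5-VL support.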
# Constants
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load multimodal processor and model (Cosmos-Reason1)
MODEL_ID = "nvidia/Cosmos-Reason1-7B"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device).eval()
def downsample_video(video_path: str, num_frames: int = 10):
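    """Sample `num_frames` evenly spaced frames from the video.

    Returns a list of (PIL.Image, timestamp-in-seconds) tuples.
    """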
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    idxs = np.linspace(0, total - 1, num_frames, dtype=int)
    frames = []
    for i in idxs:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, img = vidcap.read()
        if ok:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pil = Image.fromarray(rgb)
            # Guard against videos that report an FPS of 0
            timestamp = round(i / fps, 2) if fps else 0.0
            frames.append((pil, timestamp))
    vidcap.release()
    return frames
def progress_bar_html(label: str) -> str:
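    """Return an HTML snippet rendering an animated, indeterminate progress bar with the given label."""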
    return f'''<div style="display:flex; align-items:center;">
      <span style="margin-right:10px; font-size:14px;">{label}</span>
      <div style="width:110px; height:5px; background:#B0E0E6; border-radius:2px; overflow:hidden;">
        <div style="width:100%; height:100%; background:#00FFFF; animation:load 1.5s linear infinite;"></div>
      </div>
    </div>
    <style>@keyframes load{{0%{{transform:translateX(-100%)}}100%{{transform:translateX(100%)}}}}</style>'''
@spaces.GPU
def generate(message, history=None, files: list[str] | None = None):
    # gr.ChatInterface(multimodal=True) delivers the user turn as a dict with
    # "text" and "files"; merge any files attached there with the gr.File input.
    if isinstance(message, dict):
        prompt = message.get("text", "")
        files = list(files or []) + list(message.get("files") or [])
    else:
        prompt = message
        files = files or []
    # Determine mode from the uploaded file extensions
    is_video = any(f.lower().endswith(('.mp4', '.avi', '.mov')) for f in files)
    is_image = any(f.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')) for f in files)

    if is_video:
        yield progress_bar_html("Processing video with cosmos-reason1")
        video = files[0]
        frames = downsample_video(video)
        # Build messages: the user prompt followed by one timestamped image entry per sampled frame
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ]
        for img, ts in frames:
            path = f"frame_{uuid.uuid4().hex}.png"
            img.save(path)
            messages[1]["content"].extend([
                {"type": "text", "text": f"Frame {ts}:"},
                {"type": "image", "url": path}
            ])
        inputs = processor.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_dict=True, return_tensors="pt",
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # max_new_tokens=1024 is an assumed cap; the original relied on the model's default length limit.
        Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 1024}).start()
        buffer = ""
        for txt in streamer:
            buffer += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
        return
    if is_image:
        yield progress_bar_html("Processing image with cosmos-reason1")
        imgs = [Image.open(f) for f in files]
        messages = [{
            "role": "user",
            "content": [*[{"type": "image", "image": i} for i in imgs], {"type": "text", "text": prompt}],
        }]
        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt_full], images=imgs,
            return_tensors="pt", padding=True,
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # max_new_tokens=1024 is an assumed cap, mirroring the video branch.
        Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 1024}).start()
        out = ""
        for txt in streamer:
            out += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield out
        return
    # No valid media uploaded
    yield "Please upload at least one image or a video for inference."
def main():
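    """Build the Gradio chat interface and launch the app."""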
    demo = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.File(label="Upload Images/Videos", file_types=["image", "video"], file_count="multiple")
        ],
        description="# **cosmos-reason1 by nvidia**",
        # multimodal=True expects a MultimodalTextbox rather than a plain Textbox
        textbox=gr.MultimodalTextbox(label="Prompt"),
        cache_examples=False,
        type="messages",
        multimodal=True,
        stop_btn="Stop Generation",
    )
    demo.queue(max_size=10).launch(share=True)
if __name__ == "__main__":
    main()