import os
import uuid
import time
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer
)
# Constants
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "1024"))  # generation cap; 1024 is an assumed default
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load multimodal processor and model (Cosmos-Reason1-7B)
MODEL_ID = "nvidia/Cosmos-Reason1-7B"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
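
# A minimal non-streaming smoke test for the loaded model; "example.jpg" is a
# placeholder path, not a file shipped with this Space:
#
#   img = Image.open("example.jpg")
#   msgs = [{"role": "user", "content": [{"type": "image", "image": img},
#                                        {"type": "text", "text": "Describe this image."}]}]
#   text = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
#   batch = processor(text=[text], images=[img], return_tensors="pt").to(device)
#   ids = model.generate(**batch, max_new_tokens=64)
#   print(processor.batch_decode(ids[:, batch["input_ids"].shape[1]:],
#                                skip_special_tokens=True)[0])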
def downsample_video(video_path: str, num_frames: int = 10):
    """Uniformly sample `num_frames` frames, returning (PIL image, timestamp-in-seconds) pairs."""
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against an unreported FPS
    frames = []
    if total <= 0:
        vidcap.release()
        return frames
    idxs = np.linspace(0, total - 1, num_frames, dtype=int)
    for i in idxs:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, img = vidcap.read()
        if ok:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            frames.append((Image.fromarray(rgb), round(i / fps, 2)))
    vidcap.release()
    return frames
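
# Quick usage sketch for downsample_video; "sample.mp4" is illustrative only:
#
#   for pil_img, ts in downsample_video("sample.mp4", num_frames=4):
#       print(pil_img.size, f"frame at {ts}s")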
def progress_bar_html(label: str) -> str:
    """Return an animated HTML progress bar with the given label."""
    return f'''<div style="display:flex; align-items:center;">
    <span style="margin-right:10px; font-size:14px;">{label}</span>
    <div style="width:110px; height:5px; background:#B0E0E6; border-radius:2px; overflow:hidden;">
        <div style="width:100%; height:100%; background:#00FFFF; animation:load 1.5s linear infinite;"></div>
    </div>
</div>
<style>@keyframes load{{0%{{transform:translateX(-100%)}}100%{{transform:translateX(100%)}}}}</style>'''
@spaces.GPU
def generate(message, history, files: list[str] | None = None):
    # gr.ChatInterface(multimodal=True) passes the user turn as a dict with
    # "text" and "files"; merge those with any uploads from the gr.File input.
    prompt = message.get("text", "") if isinstance(message, dict) else message
    extra = message.get("files", []) if isinstance(message, dict) else []
    files = list(extra) + list(files or [])
    # Determine mode from file extensions
    is_video = any(f.lower().endswith((".mp4", ".avi", ".mov")) for f in files)
    is_image = any(f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")) for f in files)
    if is_video:
        yield progress_bar_html("Processing video with cosmos-reason1")
        video = files[0]
        frames = downsample_video(video)
        # Interleave timestamped frames with the prompt in a chat-template message
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ]
        frame_paths = []
        for img, ts in frames:
            path = f"frame_{uuid.uuid4().hex}.png"
            img.save(path)
            frame_paths.append(path)
            messages[1]["content"].extend([
                {"type": "text", "text": f"Frame {ts}:"},
                {"type": "image", "url": path}
            ])
        inputs = processor.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_dict=True, return_tensors="pt",
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
        for path in frame_paths:
            os.remove(path)  # frames are embedded in `inputs`; clean up temp files
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        Thread(target=model.generate,
               kwargs={**inputs, "streamer": streamer, "max_new_tokens": MAX_NEW_TOKENS}).start()
        buffer = ""
        for txt in streamer:
            buffer += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
        return
    if is_image:
        yield progress_bar_html("Processing image with cosmos-reason1")
        imgs = [Image.open(f) for f in files]
        messages = [{
            "role": "user",
            "content": [*[{"type": "image", "image": i} for i in imgs],
                        {"type": "text", "text": prompt}]
        }]
        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt_full], images=imgs,
            return_tensors="pt", padding=True,
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        Thread(target=model.generate,
               kwargs={**inputs, "streamer": streamer, "max_new_tokens": MAX_NEW_TOKENS}).start()
        out = ""
        for txt in streamer:
            out += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield out
        return
    # No valid media
    yield "Please upload at least one image or a video for inference."
def main():
    demo = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.File(label="Upload Images/Videos", file_types=["image", "video"], file_count="multiple")
        ],
        description="# **cosmos-reason1 by nvidia**",
        # multimodal=True requires a MultimodalTextbox rather than a plain Textbox
        textbox=gr.MultimodalTextbox(label="Prompt", file_types=["image", "video"], file_count="multiple"),
        cache_examples=False,
        type="messages",
        multimodal=True,
        stop_btn="Stop Generation"
    )
    demo.queue(max_size=10).launch(share=True)
if __name__ == "__main__":
    main()