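"""Gradio demo that streams vision-language inference from nvidia/Cosmos-Reason1-7B
(a Qwen2.5-VL-based model) over uploaded images or a sampled video, yielding
partial generations as they arrive."""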
import os
import uuid
import time
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2

from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer
)

# Constants
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the multimodal processor and model (nvidia/Cosmos-Reason1-7B, Qwen2.5-VL architecture)
MODEL_ID = "nvidia/Cosmos-Reason1-7B"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()


def downsample_video(video_path: str, num_frames: int = 10):
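    """Evenly sample `num_frames` frames from the video at `video_path` and
    return a list of (PIL.Image, timestamp-in-seconds) pairs."""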
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against 0 fps from broken metadata
    idxs = np.linspace(0, max(total - 1, 0), num_frames, dtype=int)
    frames = []
    for i in idxs:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ok, img = vidcap.read()
        if ok:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pil = Image.fromarray(rgb)
            timestamp = round(i / fps, 2)
            frames.append((pil, timestamp))
    vidcap.release()
    return frames


def progress_bar_html(label: str) -> str:
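    """Return a small animated HTML progress bar to stream while the model runs."""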
    return f'''<div style="display:flex; align-items:center;">
  <span style="margin-right:10px; font-size:14px;">{label}</span>
  <div style="width:110px; height:5px; background:#B0E0E6; border-radius:2px; overflow:hidden;">
    <div style="width:100%; height:100%; background:#00FFFF; animation:load 1.5s linear infinite;"></div>
  </div>
</div>
<style>@keyframes load{{0%{{transform:translateX(-100%)}}100%{{transform:translateX(100%)}}}}</style>'''

@spaces.GPU
def generate(message, history=None, files: list[str] | None = None):
    # With multimodal=True, gr.ChatInterface passes {"text": ..., "files": [...]} as
    # the message and the chat history as the second positional argument.
    prompt = message.get("text", "") if isinstance(message, dict) else message
    files = list(files or []) + (list(message.get("files", [])) if isinstance(message, dict) else [])
    # Pick a mode from the uploaded file extensions
    is_video = any(f.lower().endswith(('.mp4', '.avi', '.mov')) for f in files)
    is_image = any(f.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')) for f in files)

    if is_video:
        yield progress_bar_html("Processing video with cosmos-reason1")
        video = next(f for f in files if f.lower().endswith(('.mp4', '.avi', '.mov')))
        frames = downsample_video(video)
        # Build messages
        messages = [
            {"role": "system", "content": [{"type":"text","text":"You are a helpful assistant."}]},
            {"role": "user", "content": [{"type":"text","text": prompt}]}
        ]
        frame_paths = []
        for img, ts in frames:
            path = f"frame_{uuid.uuid4().hex}.png"
            img.save(path)
            frame_paths.append(path)
            messages[1]["content"].extend([
                {"type": "text", "text": f"Frame {ts}:"},
                {"type": "image", "url": path}
            ])
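        # apply_chat_template tokenizes the conversation; with recent transformers
        # versions the processor also loads the frame images referenced by "url".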
        inputs = processor.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_dict=True, return_tensors="pt",
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # Generate in a background thread so we can stream partial text; set
        # max_new_tokens explicitly (the default is very small; 1024 is an arbitrary cap).
        Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 1024}).start()
        buffer = ""
        for txt in streamer:
            buffer += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
        # Remove the temporary frame files now that streaming is done
        for p in frame_paths:
            try:
                os.remove(p)
            except OSError:
                pass
        return

    if is_image:
        yield progress_bar_html("Processing image with cosmos-reason1")
        imgs = [Image.open(f) for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp'))]
        messages = [{
            "role": "user",
            "content": [*[{"type": "image", "image": im} for im in imgs],
                        {"type": "text", "text": prompt}]
        }]
        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt_full], images=imgs,
            return_tensors="pt", padding=True,
            truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # Same background-thread streaming pattern as the video branch
        Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 1024}).start()
        out = ""
        for txt in streamer:
            out += txt.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield out
        return

    # No valid media
    yield "Please upload at least one image or a video for inference."


def main():
    demo = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.File(label="Upload Images/Videos", file_types=["image", "video"], file_count="multiple")
        ],
        description="# **cosmos-reason1 by nvidia**",
        textbox=gr.MultimodalTextbox(label="Prompt", file_types=["image", "video"]),
        cache_examples=False,
        type="messages",
        multimodal=True,
        stop_btn="Stop Generation"
    )
    demo.queue(max_size=10).launch(share=True)

if __name__ == "__main__":
    main()