import gradio as gr
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from pdf2image import convert_from_path

# Helper Functions
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """
    Returns an HTML snippet for a thin animated progress bar with a label.
    Colors can be customized; the defaults are shared by both models in this demo.
    """
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''

def downsample_video(video_path):
    """
    Downsamples a video file by extracting 10 evenly spaced frames.
    Returns a list of tuples (PIL.Image, timestamp).
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames

# Model and Processor Setup
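# Both checkpoints are loaded onto the GPU at startup: float16 for the 2B Qwen2VL OCR model,
# bfloat16 for the 7B docscopeOCR model.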
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QV_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()

DOCSCOPEOCR_MODEL_ID = "prithivMLmods/docscopeOCR-7B-050425-exp"
docscopeocr_processor = AutoProcessor.from_pretrained(DOCSCOPEOCR_MODEL_ID, trust_remote_code=True)
docscopeocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    DOCSCOPEOCR_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()

# Main Inference Function
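# @spaces.GPU requests a GPU allocation for the duration of each call when running on ZeroGPU Spaces.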
@spaces.GPU
def model_inference(text, files, history, use_docscopeocr):
    if not text and not files:
        yield "Error: Please input a text query or provide files (images, videos, PDFs)."
        return

    # Process files: images, videos, PDFs
    image_list = []
    for idx, file in enumerate(files or []):
        # gr.File may return plain path strings or tempfile-like objects depending on the Gradio version.
        path = file if isinstance(file, str) else file.name
        if path.lower().endswith(".pdf"):
            try:
                pdf_images = convert_from_path(path)
                for page_num, img in enumerate(pdf_images, start=1):
                    label = f"PDF {idx+1} Page {page_num}:"
                    image_list.append((label, img))
            except Exception as e:
                yield f"Error converting PDF: {str(e)}"
                return
        elif path.lower().endswith((".mp4", ".avi", ".mov")):
            frames = downsample_video(path)
            if not frames:
                yield "Error: Could not extract frames from the video."
                return
            for frame, timestamp in frames:
                label = f"Video {idx+1} Frame {timestamp}:"
                image_list.append((label, frame))
        else:
            try:
                img = load_image(path)
                label = f"Image {idx+1}:"
                image_list.append((label, img))
            except Exception as e:
                yield f"Error loading image: {str(e)}"
                return

    # Build content list
    content = [{"type": "text", "text": text or ""}]  # fall back to an empty string when only files are provided
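    # Interleave a short text label before each image so the model can reference a specific page or frame.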
    for label, img in image_list:
        content.append({"type": "text", "text": label})
        content.append({"type": "image", "image": img})

    messages = [{"role": "user", "content": content}]

    # Select processor and model
    if use_docscopeocr:
        processor = docscopeocr_processor
        model = docscopeocr_model
        model_name = "DocScopeOCR"
    else:
        processor = qwen_processor
        model = qwen_model
        model_name = "Qwen2VL OCR"

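    # Render the chat template into a single prompt string and gather the images in the order they appear.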
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    all_images = [item["image"] for item in content if item["type"] == "image"]
    inputs = processor(
        text=[prompt_full],
        images=all_images if all_images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

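    # Stream decoded text as it is generated, skipping the prompt and special tokens.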
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
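    # Run generation in a background thread; the streamer feeds partial text back to the UI loop below.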
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    yield progress_bar_html(f"Processing with {model_name}")
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer

# Gradio Interface
def chat_interface(text, files, use_docscopeocr, history):
    # Note: this wrapper is not wired into the Blocks UI below, which calls model_inference directly.
    if not text and not files:
        return "Error: Please input a text query or provide files."
    return model_inference(text, files, history, use_docscopeocr)

examples = [
    {"text": "OCR the Text in the Image", "files": ["rolm/1.jpeg"]},
    {"text": "Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]},
    {"text": "OCR the Image", "files": ["rolm/3.jpeg"]},
    {"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]},
]

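# UI: query textbox, multi-file upload, model toggle, chat history, and submit/stop controls.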
with gr.Blocks(theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **DocScope OCR `VL/OCR`**")
    with gr.Row():
        text_input = gr.Textbox(label="Query Input", placeholder="Input your query here.")
        file_input = gr.File(label="Upload Files", file_count="multiple", file_types=["image", "video", "pdf"])
        use_docscopeocr = gr.Checkbox(label="Use DocScopeOCR", value=True, info="Check to use DocScopeOCR, uncheck to use Qwen2VL OCR")
    chat = gr.Chatbot(type="messages")
    submit_btn = gr.Button("Submit")
    stop_btn = gr.Button("Stop Generation")

    def submit(text, files, use_docscopeocr, history):
        history = history or []
        history.append({"role": "user", "content": text})
        return history, gr.update(interactive=False), gr.update(interactive=True)

    def generate(history, text, files, use_docscopeocr):
        history = history or []
        # Stream into a single assistant message instead of appending a new one per chunk.
        history.append({"role": "assistant", "content": ""})
        for response in model_inference(text, files, history, use_docscopeocr):
            history[-1]["content"] = response
            yield history

    # Record the user turn first, then stream the reply, then re-enable the buttons.
    submit_event = submit_btn.click(
        submit, [text_input, file_input, use_docscopeocr, chat], [chat, submit_btn, stop_btn]
    )
    gen_event = submit_event.then(generate, [chat, text_input, file_input, use_docscopeocr], chat)
    gen_event.then(
        lambda: (gr.update(interactive=True), gr.update(interactive=False)),
        None, [submit_btn, stop_btn],
    )
    # The stop button cancels the in-flight generation event.
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[gen_event])

demo.launch(debug=True, ssr_mode=False)