prithivMLmods committed on
Commit
3d442e6
·
verified ·
1 Parent(s): ff82d30

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -0
app.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import json
3
+ import math
4
+ import os
5
+ import traceback
6
+ from io import BytesIO
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+ import re
9
+ import time
10
+ from threading import Thread
11
+ from io import BytesIO
12
+ import uuid
13
+ import tempfile
14
+
15
+ import gradio as gr
16
+ import requests
17
+ import torch
18
+ from PIL import Image
19
+ import fitz
20
+
21
+ from transformers import (
22
+ Qwen2_5_VLForConditionalGeneration,
23
+ AutoModelForVision2Seq,
24
+ AutoModelForImageTextToText,
25
+ AutoModel,
26
+ AutoProcessor,
27
+ TextIteratorStreamer,
28
+ )
29
+
30
+ from transformers.image_utils import load_image
31
+
32
+ from reportlab.lib.pagesizes import A4
33
+ from reportlab.lib.styles import getSampleStyleSheet
34
+ from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
35
+ from reportlab.lib.units import inch
36
+
37
# --- Constants and Model Setup ---
# Passed as ``max_length`` when the processor tokenizes the prompt.
MAX_INPUT_TOKEN_LENGTH = 4096
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Startup diagnostics describing the CUDA environment.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

print("Using device:", device)


# --- Model Loading ---
def _load_checkpoint(model_cls, model_id, processor_kwargs=None, **model_kwargs):
    """
    Loads the processor and the float16 model for one checkpoint.

    The processor is loaded first, then the model; the model is moved to
    ``device`` and put in eval mode. ``processor_kwargs`` are forwarded only to
    the processor, ``model_kwargs`` only to the model.

    Returns a ``(processor, model)`` tuple.
    """
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, **(processor_kwargs or {})
    )
    model = model_cls.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.float16, **model_kwargs
    )
    return processor, model.to(device).eval()


MODEL_ID_M = "LiquidAI/LFM2-VL-450M"
processor_m, model_m = _load_checkpoint(AutoModelForImageTextToText, MODEL_ID_M)

MODEL_ID_T = "LiquidAI/LFM2-VL-1.6B"
processor_t, model_t = _load_checkpoint(AutoModelForImageTextToText, MODEL_ID_T)

MODEL_ID_C = "HuggingFaceTB/SmolVLM-Instruct-250M"
processor_c, model_c = _load_checkpoint(
    AutoModelForVision2Seq,
    MODEL_ID_C,
    _attn_implementation="flash_attention_2",
)

MODEL_ID_G = "echo840/MonkeyOCR-pro-1.2B"
SUBFOLDER = "Recognition"  # processor and weights both live in this subfolder
processor_g, model_g = _load_checkpoint(
    Qwen2_5_VLForConditionalGeneration,
    MODEL_ID_G,
    processor_kwargs={"subfolder": SUBFOLDER},
    subfolder=SUBFOLDER,
)

MODEL_ID_I = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
processor_i, model_i = _load_checkpoint(
    AutoModelForImageTextToText,
    MODEL_ID_I,
    _attn_implementation="flash_attention_2",
)
85
+
86
+
87
+ # --- PDF Generation and Preview Utility Function ---
88
def generate_and_preview_pdf(image: Image.Image, text_content: str, font_size: int, line_spacing: float, alignment: str, image_size: str):
    """
    Generates a PDF, saves it, and then creates image previews of its pages.

    The PDF is A4 with one-inch margins. It contains the image, scaled to the
    selected width preset, followed by the text as one paragraph per non-empty
    line. Markdown heading markers (``#``) and asterisks are stripped from the
    text before it is laid out.

    Args:
        image: PIL image placed at the top of the document.
        text_content: Text written below the image.
        font_size: Font size in points; any value ``int()`` accepts.
        line_spacing: Multiplier applied to the font size to get the leading.
        alignment: One of "Left", "Center", "Right", "Justified".
        image_size: One of "Small", "Medium", "Large".

    Returns:
        The path to the PDF and a list of paths to the preview images. The
        list is empty if rendering the previews failed.

    Raises:
        gr.Error: If the image is missing or the text is empty/blank.
        KeyError: If ``alignment`` or ``image_size`` is not a known option.
    """
    # Local import so this function does not depend on an extra file-level import.
    from xml.sax.saxutils import escape

    if image is None or not text_content or not text_content.strip():
        raise gr.Error("Cannot generate PDF. Image or text content is missing.")

    # --- 1. Generate the PDF ---
    temp_dir = tempfile.gettempdir()
    pdf_filename = os.path.join(temp_dir, f"output_{uuid.uuid4()}.pdf")
    doc = SimpleDocTemplate(
        pdf_filename,
        pagesize=A4,
        rightMargin=inch, leftMargin=inch,
        topMargin=inch, bottomMargin=inch
    )
    styles = getSampleStyleSheet()
    style_normal = styles["Normal"]
    style_normal.fontSize = int(font_size)
    style_normal.leading = int(font_size) * float(line_spacing)
    style_normal.alignment = {"Left": 0, "Center": 1, "Right": 2, "Justified": 4}[alignment]

    story = []

    img_buffer = BytesIO()
    image.save(img_buffer, format='PNG')
    img_buffer.seek(0)

    page_width, page_height = A4
    available_width = page_width - 2 * inch
    # Leave 24pt of slack below the margins: this is meant to cover the
    # frame's inner padding (6pt per side by default in ReportLab — confirm for
    # the installed version) plus a small safety margin.
    max_img_height = page_height - 2 * inch - 24
    image_widths = {
        "Small": available_width * 0.3,
        "Medium": available_width * 0.6,
        "Large": available_width * 0.9,
    }
    img_width = image_widths[image_size]
    img_height = image.height * (img_width / image.width)
    if img_height > max_img_height:
        # A tall, narrow image would not fit on a page and the build would
        # fail; shrink it proportionally so it fits.
        img_width *= max_img_height / img_height
        img_height = max_img_height
    story.append(RLImage(img_buffer, width=img_width, height=img_height))
    story.append(Spacer(1, 12))

    cleaned_text = re.sub(r'#+\s*', '', text_content).replace("*", "")

    for para in cleaned_text.split('\n'):
        if para.strip():
            # Paragraph parses XML-like markup, so literal "<", ">" and "&"
            # from the model output must be escaped to be rendered as text.
            story.append(Paragraph(escape(para), style_normal))

    doc.build(story)

    # --- 2. Render PDF pages as images for preview ---
    preview_images = []
    try:
        # The context manager closes the document even if rendering fails.
        with fitz.open(pdf_filename) as pdf_doc:
            for page_num, page in enumerate(pdf_doc):
                pix = page.get_pixmap(dpi=150)
                preview_img_path = os.path.join(temp_dir, f"preview_{uuid.uuid4()}_p{page_num}.png")
                pix.save(preview_img_path)
                preview_images.append(preview_img_path)
    except Exception as e:
        # Best effort: the PDF itself is still returned without previews.
        print(f"Error generating PDF preview: {e}")

    return pdf_filename, preview_images
153
+
154
+
155
+ # --- Core Application Logic ---
156
@spaces.GPU
def process_document_stream(
    model_name: str,
    image: Image.Image,
    prompt_input: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float
):
    """
    Main generator function that handles model inference tasks with advanced generation parameters.

    Generation runs in a background thread and the decoded text is streamed
    back as it is produced.

    Args:
        model_name: Display name of the model to run (one of the names in the
            dispatch table below).
        image: PIL image passed to the model.
        prompt_input: User query sent alongside the image.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        ``(raw_text, markdown_text)`` tuples. While streaming, both items hold
        the text accumulated so far; for invalid input the first item is a
        message for the user and the second is empty.

    Raises:
        gr.Error: If generation fails inside the worker thread.
    """
    if image is None:
        yield "Please upload an image.", ""
        return
    if not prompt_input or not prompt_input.strip():
        yield "Please enter a prompt.", ""
        return

    available_models = {
        "LFM2-VL-450M": (processor_m, model_m),
        "LFM2-VL-1.6B": (processor_t, model_t),
        "SmolVLM-Instruct-250M": (processor_c, model_c),
        "MonkeyOCR-pro-1.2B": (processor_g, model_g),
        "SmolVLM2-2.2B-Instruct": (processor_i, model_i),
    }
    selected = available_models.get(model_name)
    if selected is None:
        yield "Invalid model selected.", ""
        return
    processor, model = selected

    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt_input}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        # UI sliders may deliver floats; these two must be integers.
        "max_new_tokens": int(max_new_tokens),
        "temperature": temperature,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": repetition_penalty,
        "do_sample": True
    }

    generation_errors = []

    def _run_generation():
        """Run generate(); on failure record the error and unblock the consumer."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            traceback.print_exc()
            generation_errors.append(exc)
            # Without this the loop below would wait on the streamer forever,
            # because a failed generate() never signals the end of the stream.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer

    thread.join()
    if generation_errors:
        raise gr.Error(f"Generation failed: {generation_errors[0]}")

    yield buffer, buffer
213
+
214
+
215
+ # --- Gradio UI Definition ---
216
def create_gradio_interface():
    """
    Builds and returns the Gradio web interface.

    The layout has an input column (model selector, query, image, generation
    and PDF export settings, action buttons) and an output column with three
    tabs: the raw streamed text, a Markdown rendering of it, and the PDF
    download/preview. Three click handlers are wired up: processing the image,
    generating the PDF, and clearing all inputs and outputs.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    #gallery { min-height: 400px; }
    """
    # NOTE(review): the nesting of the layout below was reconstructed from a
    # listing that had lost its indentation — confirm against the deployed app,
    # in particular that the PDF export settings belong inside the accordion.
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>Tiny VLMs Lab🧪</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                Advanced Vision-Language Model for Image Content and Layout Extraction
            </p>
        </div>
        """)

        with gr.Row():
            # Left Column (Inputs)
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(
                    choices=["LFM2-VL-1.6B", "LFM2-VL-450M", "SmolVLM-Instruct-250M", "SmolVLM2-2.2B-Instruct", "MonkeyOCR-pro-1.2B"],
                    label="Select Model", value="LFM2-VL-1.6B"
                )
                prompt_input = gr.Textbox(label="Query Input", placeholder="✦︎ Enter your query")
                image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])

                with gr.Accordion("Advanced Settings", open=False):
                    max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")
                    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                    top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                    top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                    repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)

                    gr.Markdown("### PDF Export Settings")
                    font_size = gr.Dropdown(choices=["8", "10", "12", "14", "16", "18"], value="12", label="Font Size")
                    line_spacing = gr.Dropdown(choices=[1.0, 1.15, 1.5, 2.0], value=1.15, label="Line Spacing")
                    alignment = gr.Dropdown(choices=["Left", "Center", "Right", "Justified"], value="Justified", label="Text Alignment")
                    image_size = gr.Dropdown(choices=["Small", "Medium", "Large"], value="Medium", label="Image Size in PDF")

                process_btn = gr.Button("🚀 Process Image", variant="primary", elem_classes=["process-button"], size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")

            # Right Column (Outputs)
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("📝 Extracted Content"):
                        # The hint is a placeholder, not a value, so it can
                        # never be exported to the PDF as if it were content.
                        raw_output_stream = gr.Textbox(
                            label="Raw Model Output Stream",
                            interactive=False,
                            lines=15,
                            show_copy_button=True,
                            placeholder="Raw output will appear here.",
                        )
                        with gr.Row():
                            gr.Examples(
                                examples=["examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png", "examples/5.png"],
                                inputs=image_input, label="Examples"
                            )
                        gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/OCR-Comparator/discussions)")

                    with gr.Tab("📰 README.md"):
                        with gr.Accordion("(Result.md)", open=True):
                            markdown_output = gr.Markdown()

                    with gr.Tab("📋 PDF Preview"):
                        generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
                        pdf_output_file = gr.File(label="Download Generated PDF", interactive=False)
                        pdf_preview_gallery = gr.Gallery(label="PDF Page Preview", show_label=True, elem_id="gallery", columns=2, object_fit="contain", height="auto")

        # Event Handlers
        def clear_all_outputs():
            """Reset image, query, raw text, Markdown, PDF file and gallery."""
            # The raw-output box is cleared to an empty string; its hint text
            # is shown through the textbox placeholder instead.
            return None, "", "", "", None, None

        process_btn.click(
            fn=process_document_stream,
            inputs=[model_choice, image_input, prompt_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[raw_output_stream, markdown_output]
        )

        generate_pdf_btn.click(
            fn=generate_and_preview_pdf,
            inputs=[image_input, raw_output_stream, font_size, line_spacing, alignment, image_size],
            outputs=[pdf_output_file, pdf_preview_gallery]
        )

        clear_btn.click(
            clear_all_outputs,
            outputs=[image_input, prompt_input, raw_output_stream, markdown_output, pdf_output_file, pdf_preview_gallery]
        )
    return demo
302
+
303
if __name__ == "__main__":
    # Build the UI, enable request queuing (at most 50 pending), then serve.
    demo = create_gradio_interface()
    queued_demo = demo.queue(max_size=50)
    queued_demo.launch(share=True, ssr_mode=False, show_error=True)