prithivMLmods committed (verified)
Commit b051d42 · 1 Parent(s): 23ca3e5

Update app.py

Files changed (1): app.py (+105 -148)
app.py CHANGED
@@ -1,6 +1,5 @@
 import spaces
 import json
-import math
 import os
 import traceback
 from io import BytesIO
@@ -15,22 +14,19 @@ import torch
 from PIL import Image
 
 from transformers import (
-    Qwen2VLForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
-    AutoModelForImageTextToText,
     AutoProcessor,
     TextIteratorStreamer,
-    AutoModel,
-    AutoTokenizer,
 )
-
-from transformers.image_utils import load_image
+from reportlab.lib.pagesizes import A4
+from reportlab.lib.styles import getSampleStyleSheet
+from reportlab.lib import colors
+from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
+from reportlab.lib.units import inch
+import uuid
 
 # --- Constants and Model Setup ---
 MAX_INPUT_TOKEN_LENGTH = 4096
-# Note: The following line correctly falls back to CPU if CUDA is not available.
-# Let the environment (e.g., Hugging Face Spaces) determine the device.
-# This avoids conflicts with the CUDA environment setup by the platform.
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
@@ -44,24 +40,6 @@ if torch.cuda.is_available():
 
 print("Using device:", device)
 
-# --- Model Loading ---
-
-# --- Prompts for Different Tasks ---
-layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
-
-1. Bbox format: [x1, y1, x2, y2]
-2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
-3. Text Extraction & Formatting Rules:
-    - For tables, provide the content in a structured JSON format.
-    - For all other elements, provide the plain text.
-4. Constraints:
-    - The output must be the original text from the image.
-    - All layout elements must be sorted according to human reading order.
-5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
-"""
-
-ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."
-
 # --- Model Loading ---
 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-080125"
 processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
@@ -96,78 +74,61 @@ model_i = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     MODEL_ID_I, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
-# --- Utility Functions ---
-def layoutjson2md(layout_data: Any) -> str:
-    """
-    FIXED: Converts the structured JSON from Layout Analysis into formatted Markdown.
-    This version is robust against malformed JSON from the model.
-    """
-    markdown_lines = []
-
-    # If the model wraps the list in a dictionary, find and extract the list.
-    if isinstance(layout_data, dict):
-        found_list = None
-        for value in layout_data.values():
-            if isinstance(value, list):
-                found_list = value
-                break
-        if found_list is not None:
-            layout_data = found_list
-        else:
-            return "### Error: Could not find a list of layout items in the JSON object."
-
-    if not isinstance(layout_data, list):
-        return f"### Error: Expected a list of layout items, but received type {type(layout_data).__name__}."
-
-    try:
-        # Filter out any non-dictionary items and sort by reading order.
-        valid_items = [item for item in layout_data if isinstance(item, dict)]
-        sorted_items = sorted(valid_items, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
-
-        for item in sorted_items:
-            category = item.get('category', 'Text')  # Default to 'Text' if no category
-            text = item.get('text', '')
-            if not text:
-                continue
-
-            if category == 'Title':
-                markdown_lines.append(f"# {text}\n")
-            elif category == 'Section-header':
-                markdown_lines.append(f"## {text}\n")
-            elif category == 'Table':
-                if isinstance(text, dict) and 'header' in text and 'rows' in text:
-                    header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
-                    separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
-                    rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
-                    markdown_lines.extend([header, separator] + rows)
-                    markdown_lines.append("\n")
-                else:  # Fallback for simple text or malformed tables
-                    markdown_lines.append(f"{text}\n")
-            else:
-                markdown_lines.append(f"{text}\n")
-
-    except Exception as e:
-        print(f"Error converting to markdown: {e}")
-        traceback.print_exc()
-        return "### Error: An unexpected error occurred while converting JSON to Markdown."
-
-    return "\n".join(markdown_lines)
+# --- Prompts ---
+ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."
+
+# --- PDF Generation Functions ---
+def generate_pdf(media_path, plain_text, font_size, line_spacing, alignment, image_size):
+    """Generates a PDF document."""
+    filename = f"output_{uuid.uuid4()}.pdf"
+    doc = SimpleDocTemplate(
+        filename,
+        pagesize=A4,
+        rightMargin=inch,
+        leftMargin=inch,
+        topMargin=inch,
+        bottomMargin=inch
+    )
+    styles = getSampleStyleSheet()
+    styles["Normal"].fontSize = int(font_size)
+    styles["Normal"].leading = int(font_size) * line_spacing
+    styles["Normal"].alignment = {
+        "Left": 0,
+        "Center": 1,
+        "Right": 2,
+        "Justified": 4
+    }[alignment]
+
+    story = []
+
+    # Add image with size adjustment
+    image_sizes = {
+        "Small": (200, 200),
+        "Medium": (400, 400),
+        "Large": (600, 600)
+    }
+    img = RLImage(media_path, width=image_sizes[image_size][0], height=image_sizes[image_size][1])
+    story.append(img)
+    story.append(Spacer(1, 12))
+
+    # Add plain text output
+    text = Paragraph(plain_text, styles["Normal"])
+    story.append(text)
+
+    doc.build(story)
+    return filename
 
 # --- Core Application Logic ---
 @spaces.GPU
-def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
+def process_document_stream(model_name: str, image: Image.Image, max_new_tokens: int, font_size: str, line_spacing: float, alignment: str, image_size: str):
     """
-    Main generator function that handles both OCR and Layout Analysis tasks.
+    Main generator function for OCR task, also generating PDF for preview.
     """
     if image is None:
         yield "Please upload an image.", "Please upload an image.", None
         return
 
-    # 1. Select prompt based on user's task choice
-    text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt
-
-    # 2. Select model and processor
+    # Select model and processor
     if model_name == "Camel-Doc-OCR-080125": processor, model = processor_m, model_m
     elif model_name == "Megalodon-OCR-Sync-0713": processor, model = processor_t, model_t
     elif model_name == "Nanonets-OCR-s": processor, model = processor_c, model_c
@@ -177,7 +138,12 @@ def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
         yield "Invalid model selected.", "Invalid model selected.", None
         return
 
-    # 3. Prepare model inputs and streamer
+    # Save image temporarily for PDF generation
+    temp_image_path = f"temp_{uuid.uuid4()}.png"
+    image.save(temp_image_path)
+
+    # Prepare model inputs and streamer
+    text_prompt = ocr_prompt
     messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
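One robustness note on this hunk: the temporary image lands in the working directory and is removed only when the generator runs to completion (see the cleanup at the end of the next hunk), so an exception mid-stream leaves the file behind. A sketch of the standard tempfile/finally pattern, offered as an alternative rather than as what the commit does:

```python
import os
import tempfile

# Let the OS pick a unique path; delete=False so the file survives
# until we remove it ourselves after the PDF builds are done.
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
    temp_image_path = tmp.name
image.save(temp_image_path)

try:
    ...  # model generation, streaming, and PDF builds go here
finally:
    os.remove(temp_image_path)  # runs even if generation raises
```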
@@ -187,41 +153,23 @@
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    # 4. Stream raw output to the UI in real-time
+    # Stream raw output to the UI in real-time
     buffer = ""
     for new_text in streamer:
         buffer += new_text
         buffer = buffer.replace("<|im_end|>", "")
         time.sleep(0.01)
-        yield buffer, "⏳ Processing...", {"status": "streaming"}
-
-    # 5. Post-process the final buffer based on the selected task
-    if task_choice == "Content Extraction":
-        # For OCR, the buffer is the final result.
-        yield buffer, buffer, None
-    else:  # Layout Analysis
-        try:
-            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
-            if not json_match:
-                # If no JSON block is found, try to parse the whole buffer as a fallback.
-                try:
-                    layout_data = json.loads(buffer)
-                    markdown_content = layoutjson2md(layout_data)
-                    yield buffer, markdown_content, layout_data
-                    return
-                except json.JSONDecodeError:
-                    raise ValueError("JSON object not found in the model's output.")
-
-            json_str = json_match.group(1)
-            layout_data = json.loads(json_str)
-            markdown_content = layoutjson2md(layout_data)
-
-            yield buffer, markdown_content, layout_data
-        except Exception as e:
-            error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`\n\n**Raw Output:**\n```\n{buffer}\n```"
-            error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
-            yield buffer, error_md, error_json
+        # Generate PDF with current buffer
+        pdf_file = generate_pdf(temp_image_path, buffer, font_size, line_spacing, alignment, image_size)
+        yield buffer, buffer, pdf_file
+
+    # Final PDF with complete output
+    pdf_file = generate_pdf(temp_image_path, buffer, font_size, line_spacing, alignment, image_size)
+    yield buffer, buffer, pdf_file
+
+    # Clean up temporary image file
+    if os.path.exists(temp_image_path):
+        os.remove(temp_image_path)
 
 # --- Gradio UI Definition ---
 def create_gradio_interface():
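Performance note on the streaming loop above: the full PDF is rebuilt via generate_pdf on every streamed chunk, so a long document triggers hundreds of complete ReportLab builds. A throttled variant is sketched below; the chunk counter and PDF_REFRESH_INTERVAL are hypothetical and not part of this commit:

```python
# Sketch: rebuild the PDF every N chunks instead of on every chunk.
PDF_REFRESH_INTERVAL = 25  # hypothetical tuning knob

buffer, pdf_file, n_chunks = "", None, 0
for new_text in streamer:
    buffer += new_text
    buffer = buffer.replace("<|im_end|>", "")
    n_chunks += 1
    if n_chunks % PDF_REFRESH_INTERVAL == 0:
        pdf_file = generate_pdf(temp_image_path, buffer, font_size,
                                line_spacing, alignment, image_size)
    yield buffer, buffer, pdf_file

# One final build so the download always reflects the complete text.
pdf_file = generate_pdf(temp_image_path, buffer, font_size,
                        line_spacing, alignment, image_size)
yield buffer, buffer, pdf_file
```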
@@ -230,13 +178,15 @@ def create_gradio_interface():
     .main-container { max-width: 1400px; margin: 0 auto; }
     .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
     .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
+    .download-btn { background-color: #35a6d6 !important; color: white !important; }
+    .download-btn:hover { background-color: #22bcff !important; }
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
         gr.HTML("""
         <div class="title" style="text-align: center">
             <h1>Tiny VLMs Lab🧪</h1>
             <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
-                Advanced Vision-Language Model for Image Content and Layout Extraction
+                Advanced Vision-Language Model for Image Content Extraction and PDF Generation
             </p>
         </div>
         """)
@@ -245,23 +195,39 @@ def create_gradio_interface():
             # Left Column (Inputs)
             with gr.Column(scale=1):
                 model_choice = gr.Dropdown(
-                    choices=["Camel-Doc-OCR-080125",
-                             "MonkeyOCR-Recognition",
-                             "olmOCR-7B-0725",
-                             "Nanonets-OCR-s",
-                             "Megalodon-OCR-Sync-0713"
-                             ],
-                    label="Select Model",
+                    choices=[
+                        "Camel-Doc-OCR-080125",
+                        "MonkeyOCR-Recognition",
+                        "olmOCR-7B-0725",
+                        "Nanonets-OCR-s",
+                        "Megalodon-OCR-Sync-0713"
+                    ],
+                    label="Select Model",
                     value="Nanonets-OCR-s"
                 )
-                task_choice = gr.Dropdown(
-                    choices=["Content Extraction",
-                             "Layout Analysis(.json)"],
-                    label="Select Task", value="Content Extraction"
-                )
                 image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                 with gr.Accordion("Advanced Settings", open=False):
                     max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")
+                    font_size = gr.Dropdown(
+                        choices=["8", "10", "12", "14", "16", "18", "20", "22", "24"],
+                        value="16",
+                        label="Font Size"
+                    )
+                    line_spacing = gr.Dropdown(
+                        choices=[0.5, 1.0, 1.15, 1.5, 2.0, 2.5, 3.0],
+                        value=1.5,
+                        label="Line Spacing"
+                    )
+                    alignment = gr.Dropdown(
+                        choices=["Left", "Center", "Right", "Justified"],
+                        value="Justified",
+                        label="Text Alignment"
+                    )
+                    image_size = gr.Dropdown(
+                        choices=["Small", "Medium", "Large"],
+                        value="Medium",
+                        label="Image Size"
+                    )
 
                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
@@ -276,14 +242,13 @@ def create_gradio_interface():
                     examples=["examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png", "examples/5.png"],
                     inputs=image_input,
                     label="Examples"
-                    )
+                )
                 gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/OCR-Comparator/discussions)")
             with gr.Tab("📰 README.md"):
                 with gr.Accordion("(Formatted Result)", open=True):
                     markdown_output = gr.Markdown(label="Formatted Markdown")
-
-            with gr.Tab("📋 Layout Analysis Results"):
-                json_output = gr.JSON(label="Structured Layout Data (JSON)")
+            with gr.Tab("📋 PDF Preview"):
+                pdf_output = gr.File(label="Download PDF", interactive=True)
 
         # Event Handlers
         def clear_all_outputs():
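Naming note: gr.File renders a download entry rather than an embedded document view, so the "📋 PDF Preview" tab effectively offers the generated PDF as a download (consistent with the component's "Download PDF" label) instead of an inline preview.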
@@ -291,20 +256,12 @@ def create_gradio_interface():
 
         process_btn.click(
             fn=process_document_stream,
-            inputs=[model_choice,
-                    task_choice,
-                    image_input,
-                    max_new_tokens],
-            outputs=[raw_output_stream,
-                     markdown_output,
-                     json_output]
+            inputs=[model_choice, image_input, max_new_tokens, font_size, line_spacing, alignment, image_size],
+            outputs=[raw_output_stream, markdown_output, pdf_output]
         )
         clear_btn.click(
-            clear_all_outputs,
-            outputs=[image_input,
-                     raw_output_stream,
-                     markdown_output,
-                     json_output]
+            fn=clear_all_outputs,
+            outputs=[image_input, raw_output_stream, markdown_output, pdf_output]
         )
         return demo
 
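The diff ends at create_gradio_interface's return; the script's entry point is untouched by this commit. For completeness, a typical launch block for a Blocks app of this shape (the queue settings are an assumption, not taken from app.py):

```python
# Hypothetical entry point; not part of this diff.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue(max_size=20).launch(show_error=True)
```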