prithivMLmods committed on
Commit 96f7759 · verified · 1 Parent(s): fe83899

Update app.py

Files changed (1)
  1. app.py +277 -141
app.py CHANGED
@@ -6,28 +6,29 @@ import traceback
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Tuple
 import re
-from threading import Thread
 import time

 import gradio as gr
 import requests
 import torch
-from PIL import Image, ImageDraw, ImageFont
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     AutoProcessor,
     TextIteratorStreamer,
 )

 # Constants
 MIN_PIXELS = 3136
 MAX_PIXELS = 11289600
 IMAGE_FACTOR = 28
-MAX_INPUT_TOKEN_LENGTH = 4096
 device = "cuda" if torch.cuda.is_available() else "cpu"

-# Prompt for Layout Analysis
-prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

 1. Bbox format: [x1, y1, x2, y2]

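A note on the constants above: 3136 and 11289600 are the usual Qwen2.5-VL pixel bounds, and IMAGE_FACTOR = 28 is the effective vision patch stride that image sides get snapped to (the smart_resize helper added later in this diff does the snapping). A quick, purely illustrative check of what those bounds mean:

    # 3136 is the smallest 28-aligned square (56 x 56);
    # 11289600 is the largest allowed area (3360 x 3360).
    assert (2 * 28) ** 2 == 3136
    assert (120 * 28) ** 2 == 11289600
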
@@ -46,74 +47,145 @@ prompt = """Please output the layout information from the PDF image, including e
 5. Final Output: The entire output must be a single JSON object.
 """

-# Load Models
 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
 processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
 model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()

 MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
 processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
 model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()

 MODEL_ID_C = "nanonets/Nanonets-OCR-s"
 processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
 model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()

 MODEL_ID_G = "echo840/MonkeyOCR"
 SUBFOLDER = "Recognition"
 processor_g = AutoProcessor.from_pretrained(
-    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
 )
 model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
 ).to(device).eval()

-
 # Utility functions
 def is_arabic_text(text: str) -> bool:
-    """Check if text contains mostly Arabic characters."""
     if not text:
         return False
-    # Simplified check for Arabic characters in the given text
     arabic_chars = 0
     total_chars = 0
-    for char in text:
         if char.isalpha():
             total_chars += 1
-            if '\u0600' <= char <= '\u06FF':
                 arabic_chars += 1
     return total_chars > 0 and (arabic_chars / total_chars) > 0.5

 def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
-    """Convert layout JSON to markdown format."""
     import base64
     from io import BytesIO
     markdown_lines = []
     try:
-        # Sort items by reading order (top to bottom, left to right)
         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
         for item in sorted_items:
             category = item.get('category', '')
             text = item.get(text_key, '')
             bbox = item.get('bbox', [])
-
             if category == 'Picture':
                 if bbox and len(bbox) == 4:
                     try:
-                        x1, y1, x2, y2 = [int(coord) for coord in bbox]
-                        cropped_img = image.crop((x1, y1, x2, y2))
-                        buffer = BytesIO()
-                        cropped_img.save(buffer, format='PNG')
-                        img_data = base64.b64encode(buffer.getvalue()).decode()
-                        markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                     except Exception as e:
-                        markdown_lines.append("![Image](Image region detected)\n")
             elif not text:
                 continue
             elif category == 'Title':
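
For orientation, the "single JSON object" the prompt asks for is consumed by the code as a list of layout elements. A hypothetical, minimal payload would look like:

    [
      {"bbox": [72, 54, 523, 92], "category": "Title", "text": "Quarterly Report"},
      {"bbox": [72, 110, 523, 418], "category": "Text", "text": "Revenue grew 12% ..."},
      {"bbox": [72, 440, 523, 700], "category": "Picture", "text": ""}
    ]

The bbox values and text here are made up; the category names are the ones layoutjson2md branches on.
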
@@ -124,105 +196,161 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
                 markdown_lines.append(f"{text}\n")
             elif category == 'List-item':
                 markdown_lines.append(f"- {text}\n")
-            elif category == 'Table' and text.strip().startswith('<'):
-                markdown_lines.append(f"{text}\n")
-            elif category == 'Formula' and (text.strip().startswith('$') or '\\' in text):
-                markdown_lines.append(f"$$\n{text}\n$$\n")
             elif category == 'Caption':
                 markdown_lines.append(f"*{text}*\n")
             elif category == 'Footnote':
-                markdown_lines.append(f"^{text}^\n")
-            elif category not in ['Page-header', 'Page-footer']:
                 markdown_lines.append(f"{text}\n")
     except Exception as e:
         print(f"Error converting to markdown: {e}")
-        return f"### Error converting to Markdown\n\n```\n{str(layout_data)}\n```"
     return "\n".join(markdown_lines)

-
 @spaces.GPU
-def generate_and_process(model_name: str, image: Image.Image, max_new_tokens: int):
-    """
-    Generates a response using streaming, then processes the final output.
-    Yields updates for the raw stream, final markdown, and JSON output.
-    """
-    if image is None:
-        yield "Please upload an image.", "Please upload an image.", None
-        return
-
-    # 1. Select Model and Processor
-    if model_name == "Camel-Doc-OCR-062825":
-        processor, model = processor_m, model_m
-    elif model_name == "Megalodon-OCR-Sync-0713":
-        processor, model = processor_t, model_t
-    elif model_name == "Nanonets-OCR-s":
-        processor, model = processor_c, model_c
-    elif model_name == "MonkeyOCR-Recognition":
-        processor, model = processor_g, model_g
-    else:
-        yield "Invalid model selected.", "Invalid model selected.", None
-        return
-
-    # 2. Prepare inputs for the model
-    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
-    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(
-        text=[prompt_full],
-        images=[image],
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-        max_length=MAX_INPUT_TOKEN_LENGTH
-    ).to(device)
-
-    # 3. Stream the generation
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    buffer = ""
-    # Initial placeholder yield
-    yield buffer, "⏳ Generating response...", None
-
-    for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)  # Small delay for smoother streaming
-        yield buffer, "⏳ Generating response...", None
-
-    # 4. Process the final buffer content
     try:
-        json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
-        json_str = json_match.group(1) if json_match else buffer
-        layout_data = json.loads(json_str)
-
-        markdown_content = layoutjson2md(image, layout_data)
-
-        # Final yield with all processed content
-        yield buffer, markdown_content, layout_data
-
-    except json.JSONDecodeError:
-        error_msg = "❌ Failed to parse JSON from model output."
-        yield buffer, error_msg, {"error": "JSONDecodeError", "raw_output": buffer}
     except Exception as e:
-        error_msg = f" An error occurred during post-processing: {e}"
-        yield buffer, error_msg, {"error": str(e), "raw_output": buffer}
-

 def create_gradio_interface():
-    """Create the Gradio interface."""
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
     .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
     .process-button {
-        border: none !important; color: white !important; font-weight: bold !important;
-        background-color: blue !important;
-    }
     .process-button:hover {
-        background-color: darkblue !important; transform: translateY(-2px) !important;
-        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
-    }
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
         gr.HTML("""
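
Both the removed generate_and_process and its replacement (inference, later in this diff) use the stock transformers streaming idiom: model.generate runs on a background thread while the caller drains a TextIteratorStreamer. Reduced to a minimal sketch (names are placeholders, not this app's code):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, inputs, max_new_tokens=256):
        # generate() pushes tokens into the streamer; iterating the
        # streamer on this thread yields decoded text incrementally.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
        Thread(target=model.generate, kwargs=kwargs).start()
        buffer = ""
        for piece in streamer:
            buffer += piece
            yield buffer
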
@@ -233,64 +361,72 @@ def create_gradio_interface():
         </p>
         </div>
         """)
-
-        # Keep track of the uploaded image
-        image_state = gr.State(None)
-
         with gr.Row():
-            # Left column - Input and controls
             with gr.Column(scale=1):
                 model_choice = gr.Radio(
                     choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
                     label="Select Model",
                     value="Camel-Doc-OCR-062825"
                 )
-                file_input = gr.Image(
                     label="Upload Image",
-                    type="pil",
-                    sources=['upload']
                 )
                 with gr.Accordion("Advanced Settings", open=False):
                     max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
-
                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
-
-            # Right column - Results
             with gr.Column(scale=2):
                 with gr.Tabs():
                     with gr.Tab("📝 Extracted Content"):
-                        output_stream = gr.Textbox(label="Raw Output Stream", interactive=False, lines=10, show_copy_button=True)
-                        with gr.Accordion("(Formatted Result)", open=True):
-                            markdown_output = gr.Markdown(label="Formatted Result (Result.md)")
-
                     with gr.Tab("📋 Layout JSON"):
-                        json_output = gr.JSON(label="Layout Analysis Results (JSON)", value=None)
-
-        # Event Handlers
-        def handle_file_upload(image):
-            """Store the uploaded image in the state."""
-            return image
-
         def clear_all():
-            """Clear all data and reset the interface."""
-            return None, None, "Click 'Process Document' to see extracted content...", None, None
-
-        file_input.upload(handle_file_upload, inputs=[file_input], outputs=[image_state])
-
         process_btn.click(
-            generate_and_process,
-            inputs=[model_choice, image_state, max_new_tokens],
-            outputs=[output_stream, markdown_output, json_output]
         )
-
         clear_btn.click(
             clear_all,
-            outputs=[file_input, image_state, markdown_output, json_output, output_stream]
         )
-
     return demo

 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)

 from io import BytesIO
 from typing import Any, Dict, List, Optional, Tuple
 import re
 import time
+from threading import Thread

 import gradio as gr
 import requests
 import torch
+from PIL import Image
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     AutoProcessor,
     TextIteratorStreamer,
 )
+from qwen_vl_utils import process_vision_info

 # Constants
 MIN_PIXELS = 3136
 MAX_PIXELS = 11289600
 IMAGE_FACTOR = 28
+MAX_INPUT_TOKEN_LENGTH = 2048
 device = "cuda" if torch.cuda.is_available() else "cpu"

+# Prompts
+prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

 1. Bbox format: [x1, y1, x2, y2]

 5. Final Output: The entire output must be a single JSON object.
 """

+# Load models
 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
 processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
 model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_M,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
 ).to(device).eval()

 MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
 processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
 model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_T,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
 ).to(device).eval()

 MODEL_ID_C = "nanonets/Nanonets-OCR-s"
 processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
 model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_C,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
 ).to(device).eval()

 MODEL_ID_G = "echo840/MonkeyOCR"
 SUBFOLDER = "Recognition"
 processor_g = AutoProcessor.from_pretrained(
+    MODEL_ID_G,
+    trust_remote_code=True,
+    subfolder=SUBFOLDER
 )
 model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_G,
+    trust_remote_code=True,
+    subfolder=SUBFOLDER,
+    torch_dtype=torch.float16
 ).to(device).eval()

 # Utility functions
+def round_by_factor(number: int, factor: int) -> int:
+    return round(number / factor) * factor
+
+def smart_resize(
+    height: int,
+    width: int,
+    factor: int = 28,
+    min_pixels: int = 3136,
+    max_pixels: int = 11289600,
+):
+    if max(height, width) / min(height, width) > 200:
+        raise ValueError(f"Aspect ratio too extreme: {max(height, width) / min(height, width)}")
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = round_by_factor(height / beta, factor)
+        w_bar = round_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = round_by_factor(height * beta, factor)
+        w_bar = round_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
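
smart_resize snaps both sides to multiples of factor, then rescales if the resulting area falls outside [min_pixels, max_pixels]. A worked example (values computed by hand, not taken from the app):

    # round(1000 / 28) * 28 = 1008 on both sides, and
    # 1008 * 1008 = 1,016,064 already lies inside [3136, 11289600].
    assert smart_resize(1000, 1000) == (1008, 1008)

Note that smart_resize calls math.sqrt, so the file relies on an import math among the imports not shown in this hunk.
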
+def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
+    if isinstance(image_input, str):
+        if image_input.startswith(("http://", "https://")):
+            response = requests.get(image_input)
+            image = Image.open(BytesIO(response.content)).convert('RGB')
+        else:
+            image = Image.open(image_input).convert('RGB')
+    elif isinstance(image_input, Image.Image):
+        image = image_input.convert('RGB')
+    else:
+        raise ValueError(f"Invalid image input type: {type(image_input)}")
+    if min_pixels or max_pixels:
+        min_pixels = min_pixels or MIN_PIXELS
+        max_pixels = max_pixels or MAX_PIXELS
+        height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
+        image = image.resize((width, height), Image.LANCZOS)
+    return image
+
 def is_arabic_text(text: str) -> bool:
     if not text:
         return False
+    header_pattern = r'^#{1,6}\s+(.+)$'
+    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
+    content_text = []
+    for line in text.split('\n'):
+        line = line.strip()
+        if not line:
+            continue
+        header_match = re.match(header_pattern, line, re.MULTILINE)
+        if header_match:
+            content_text.append(header_match.group(1))
+            continue
+        if re.match(paragraph_pattern, line, re.MULTILINE):
+            content_text.append(line)
+    if not content_text:
+        return False
+    combined_text = ' '.join(content_text)
     arabic_chars = 0
     total_chars = 0
+    for char in combined_text:
         if char.isalpha():
             total_chars += 1
+            if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
                 arabic_chars += 1
     return total_chars > 0 and (arabic_chars / total_chars) > 0.5

 def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
     import base64
     from io import BytesIO
     markdown_lines = []
     try:
         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
         for item in sorted_items:
             category = item.get('category', '')
             text = item.get(text_key, '')
             bbox = item.get('bbox', [])
             if category == 'Picture':
                 if bbox and len(bbox) == 4:
                     try:
+                        x1, y1, x2, y2 = bbox
+                        x1, y1 = max(0, int(x1)), max(0, int(y1))
+                        x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
+                        if x2 > x1 and y2 > y1:
+                            cropped_img = image.crop((x1, y1, x2, y2))
+                            buffer = BytesIO()
+                            cropped_img.save(buffer, format='PNG')
+                            img_data = base64.b64encode(buffer.getvalue()).decode()
+                            markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
+                        else:
+                            markdown_lines.append("![Image](Image region detected)\n")
                     except Exception as e:
+                        print(f"Error processing image region: {e}")
+                        markdown_lines.append("![Image](Image detected)\n")
+                else:
+                    markdown_lines.append("![Image](Image detected)\n")
             elif not text:
                 continue
             elif category == 'Title':

                 markdown_lines.append(f"{text}\n")
             elif category == 'List-item':
                 markdown_lines.append(f"- {text}\n")
+            elif category == 'Table':
+                if text.strip().startswith('<'):
+                    markdown_lines.append(f"{text}\n")
+                else:
+                    markdown_lines.append(f"**Table:** {text}\n")
+            elif category == 'Formula':
+                if text.strip().startswith('$') or '\\' in text:
+                    markdown_lines.append(f"$$ \n{text}\n $$\n")
+                else:
+                    markdown_lines.append(f"**Formula:** {text}\n")
             elif category == 'Caption':
                 markdown_lines.append(f"*{text}*\n")
             elif category == 'Footnote':
+                markdown_lines.append(f"^{text}^\n")
+            elif category in ['Page-header', 'Page-footer']:
+                continue
+            else:
                 markdown_lines.append(f"{text}\n")
+            markdown_lines.append("")
     except Exception as e:
         print(f"Error converting to markdown: {e}")
+        return str(layout_data)
     return "\n".join(markdown_lines)
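
The sort key (y1, x1) gives a rough top-to-bottom, then left-to-right reading order. With made-up boxes:

    items = [
        {'bbox': [300, 50, 500, 80]},   # top-right
        {'bbox': [50, 50, 250, 80]},    # top-left
        {'bbox': [50, 120, 500, 200]},  # lower block
    ]
    ordered = sorted(items, key=lambda x: (x['bbox'][1], x['bbox'][0]))
    # -> top-left, top-right, lower block
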

 @spaces.GPU
+def inference(model_name: str, image: Image.Image, text: str, max_new_tokens: int = 1024) -> str:
     try:
+        if model_name == "Camel-Doc-OCR-062825":
+            processor = processor_m
+            model = model_m
+        elif model_name == "Megalodon-OCR-Sync-0713":
+            processor = processor_t
+            model = model_t
+        elif model_name == "Nanonets-OCR-s":
+            processor = processor_c
+            model = model_c
+        elif model_name == "MonkeyOCR-Recognition":
+            processor = processor_g
+            model = model_g
+        else:
+            raise ValueError(f"Invalid model selected: {model_name}")
+
+        if image is None:
+            yield "Please upload an image.", "Please upload an image."
+            return
+
+        messages = [{
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": text},
+            ]
+        }]
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(
+            text=[prompt_full],
+            images=[image],
+            return_tensors="pt",
+            padding=True,
+            truncation=False,
+            max_length=MAX_INPUT_TOKEN_LENGTH
+        ).to(device)
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer, buffer
     except Exception as e:
+        print(f"Error during inference: {e}")
+        traceback.print_exc()
+        yield f"Error during inference: {str(e)}", f"Error during inference: {str(e)}"
+
+def process_image(
+    model_name: str,
+    image: Image.Image,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
+    max_new_tokens: int = 1024
+) -> Dict[str, Any]:
+    try:
+        if min_pixels or max_pixels:
+            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
+        result = {
+            'original_image': image,
+            'raw_output': "",
+            'layout_result': None,
+            'markdown_content': None
+        }
+        buffer = ""
+        for raw_output, _ in inference(model_name, image, prompt, max_new_tokens):
+            buffer = raw_output
+            result['raw_output'] = buffer
+            yield result
+        try:
+            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
+            json_str = json_match.group(1) if json_match else buffer
+            layout_data = json.loads(json_str)
+            result['layout_result'] = layout_data
+            try:
+                markdown_content = layoutjson2md(image, layout_data, text_key='text')
+                result['markdown_content'] = markdown_content
+            except Exception as e:
+                print(f"Error generating markdown: {e}")
+                result['markdown_content'] = buffer
+        except json.JSONDecodeError:
+            print("Failed to parse JSON output, using raw output")
+            result['markdown_content'] = buffer
+        yield result
+    except Exception as e:
+        print(f"Error processing image: {e}")
+        traceback.print_exc()
+        result = {
+            'original_image': image,
+            'raw_output': f"Error processing image: {str(e)}",
+            'layout_result': None,
+            'markdown_content': f"Error processing image: {str(e)}"
+        }
+        yield result
+
+def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
+    if not file_path or not os.path.exists(file_path):
+        return None, "No file selected"
+    file_ext = os.path.splitext(file_path)[1].lower()
+    try:
+        if file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
+            image = Image.open(file_path).convert('RGB')
+            return image, "Image loaded"
+        else:
+            return None, f"Unsupported file format: {file_ext}"
+    except Exception as e:
+        print(f"Error loading file: {e}")
+        return None, f"Error loading file: {str(e)}"

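process_image is itself a generator: it re-yields each partial buffer from inference and then a final parsed result, and Gradio streams whatever a click handler yields, as long as the queue is enabled. The pattern in isolation (component names here are illustrative, not the app's):

    import gradio as gr

    def slow_count(n):
        # Each yield becomes an incremental UI update.
        for i in range(1, int(n) + 1):
            yield f"step {i}/{int(n)}"

    with gr.Blocks() as sketch:
        steps = gr.Number(value=5, label="Steps")
        box = gr.Textbox(label="Progress")
        gr.Button("Run").click(slow_count, inputs=[steps], outputs=[box])

    # sketch.queue().launch()  # queue() is what lets generator handlers stream
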
 def create_gradio_interface():
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
     .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
     .process-button {
+        border: none !important;
+        color: white !important;
+        font-weight: bold !important;
+        background-color: blue !important;}
     .process-button:hover {
+        background-color: darkblue !important;
+        transform: translateY(-2px) !important;
+        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
+    .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
+    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
+    .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
+    .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
         gr.HTML("""

         </p>
         </div>
         """)
         with gr.Row():
             with gr.Column(scale=1):
                 model_choice = gr.Radio(
                     choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
                     label="Select Model",
                     value="Camel-Doc-OCR-062825"
                 )
+                file_input = gr.File(
                     label="Upload Image",
+                    file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
+                    type="filepath"
                 )
+                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
                 with gr.Accordion("Advanced Settings", open=False):
                     max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
+                    min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
+                    max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
             with gr.Column(scale=2):
                 with gr.Tabs():
                     with gr.Tab("📝 Extracted Content"):
+                        output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2, show_copy_button=True)
+                        with gr.Accordion("(Result.md)", open=False):
+                            markdown_output = gr.Markdown(label="Formatted Result (Result.Md)")
                     with gr.Tab("📋 Layout JSON"):
+                        json_output = gr.JSON(label="Layout Analysis Results", value=None)
+        def process_document(model_name, file_path, max_tokens, min_pix, max_pix):
+            try:
+                if not file_path:
+                    return "Please upload an image.", "Please upload an image.", None
+                image, status = load_file_for_preview(file_path)
+                if image is None:
+                    return status, status, None
+                for result in process_image(model_name, image, min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None, max_new_tokens=max_tokens):
+                    raw_output = result['raw_output']
+                    markdown_content = result['markdown_content'] or raw_output
+                    if is_arabic_text(markdown_content):
+                        markdown_update = gr.update(value=markdown_content, rtl=True)
+                    else:
+                        markdown_update = markdown_content
+                    yield raw_output, markdown_update, result['layout_result']
+            except Exception as e:
+                error_msg = f"Error processing document: {str(e)}"
+                print(error_msg)
+                traceback.print_exc()
+                yield error_msg, error_msg, None
+        def handle_file_upload(file_path):
+            if not file_path:
+                return None, "No file loaded"
+            image, page_info = load_file_for_preview(file_path)
+            return image, page_info
         def clear_all():
+            return None, None, "No file loaded", "", "Click 'Process Document' to see extracted content...", None
+        file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, output])
         process_btn.click(
+            process_document,
+            inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels],
+            outputs=[output, markdown_output, json_output]
         )
         clear_btn.click(
             clear_all,
+            outputs=[file_input, image_preview, output, markdown_output, json_output]
         )
     return demo

 if __name__ == "__main__":
     demo = create_gradio_interface()
+    demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True, show_error=True)
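
queue(max_size=10) caps the number of queued requests at ten and is also what lets the generator handlers above stream partial output to the browser; binding to 0.0.0.0 on port 7860 matches how a Hugging Face Space container expects the app to be served.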