prithivMLmods committed on
Commit 129f25d · verified · 1 Parent(s): 566263b

Update app.py

Files changed (1)
  1. app.py +287 -127
app.py CHANGED
@@ -18,227 +18,387 @@ from transformers import (
     AutoProcessor,
     TextIteratorStreamer,
 )
 
-# --- Constants and Model Setup ---
-MAX_INPUT_TOKEN_LENGTH = 4096
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# The detailed prompt to instruct the model to generate structured JSON
 prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
 
 1. Bbox format: [x1, y1, x2, y2]
 2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
 3. Text Extraction & Formatting Rules:
     - Picture: For the 'Picture' category, the text field should be omitted.
     - Formula: Format its text as LaTeX.
-    - Table: For tables, provide the content in a structured format within the JSON.
     - All Others (Text, Title, etc.): Format their text as Markdown.
 4. Constraints:
     - The output text must be the original text from the image, with no translation.
     - All layout elements must be sorted according to human reading order.
-5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
 """
 
 # Load models
 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
 processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
 model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
 processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
 model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_C = "nanonets/Nanonets-OCR-s"
 processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
 model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_G = "echo840/MonkeyOCR"
 SUBFOLDER = "Recognition"
 processor_g = AutoProcessor.from_pretrained(
-    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
 )
 model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
 ).to(device).eval()
 
-# --- Utility Functions ---
 
-def layoutjson2md(layout_data: List[Dict]) -> str:
-    """Converts the structured JSON layout data into formatted Markdown."""
     markdown_lines = []
     try:
         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
         for item in sorted_items:
             category = item.get('category', '')
-            text = item.get('text', '')
-
-            if not text:
                 continue
-
-            if category == 'Title':
                 markdown_lines.append(f"# {text}\n")
             elif category == 'Section-header':
                 markdown_lines.append(f"## {text}\n")
             elif category == 'Table':
-                # Check if the text is a dictionary representing a structured table
-                if isinstance(text, dict) and 'header' in text and 'rows' in text:
-                    header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
-                    separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
-                    rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
-                    markdown_lines.append(header)
-                    markdown_lines.append(separator)
-                    markdown_lines.extend(rows)
-                    markdown_lines.append("\n")
-                else:
-                    # Fallback for unstructured table text
                     markdown_lines.append(f"{text}\n")
             else:
                 markdown_lines.append(f"{text}\n")
-
     except Exception as e:
         print(f"Error converting to markdown: {e}")
-        return "### Error converting JSON to Markdown."
     return "\n".join(markdown_lines)
 
-# --- Core Application Logic ---
-
 @spaces.GPU
-def process_document_stream(model_name: str, image: Image.Image, text_prompt: str, max_new_tokens: int):
-    """
-    Main generator function that streams raw model output and then processes it into
-    formatted Markdown and structured JSON for the UI.
-    """
-    if image is None:
-        yield "Please upload an image.", "Please upload an image.", None
-        return
-
-    # Select the model and processor
-    if model_name == "Camel-Doc-OCR-062825": processor, model = processor_m, model_m
-    elif model_name == "Megalodon-OCR-Sync-0713": processor, model = processor_t, model_t
-    elif model_name == "Nanonets-OCR-s": processor, model = processor_c, model_c
-    elif model_name == "MonkeyOCR-Recognition": processor, model = processor_g, model_g
-    else:
-        yield "Invalid model selected.", "Invalid model selected.", None
-        return
-
-    # Prepare model inputs
-    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
-    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
-
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-
-    # Start generation in a separate thread
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    # Stream raw output to the UI
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
-        # Yield the raw stream and placeholders for the final results
-        yield buffer, "⏳ Formatting Markdown...", {"status": "processing"}
-
-    # After streaming is complete, process the final buffer
     try:
-        # Extract the JSON object from the buffer
-        json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
-        if not json_match:
-            raise json.JSONDecodeError("JSON object not found in the model's output.", buffer, 0)
-
-        json_str = json_match.group(1)
-        layout_data = json.loads(json_str)
-
-        # Convert the parsed JSON to formatted markdown
-        markdown_content = layoutjson2md(layout_data)
-
-        # Yield the final, complete results
-        yield buffer, markdown_content, layout_data
-
-    except json.JSONDecodeError as e:
-        print(f"JSON parsing failed: {e}")
-        error_md = "❌ **Error:** Failed to parse JSON from the model's output.\n\nSee the raw output stream for details."
-        error_json = {"error": "JSONDecodeError", "details": str(e), "raw_output": buffer}
-        yield buffer, error_md, error_json
     except Exception as e:
-        print(f"An unexpected error occurred: {e}")
-        yield buffer, f"❌ An unexpected error occurred: {e}", None
 
-# --- Gradio UI Definition ---
 
 def create_gradio_interface():
-    """Builds and returns the Gradio web interface."""
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
-    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
-    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
         gr.HTML("""
         <div class="title" style="text-align: center">
         <h1>Dot<span style="color: red;">●</span><strong></strong>OCR Comparator</h1>
         <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
-        Advanced Vision-Language Model for Image Layout Analysis
         </p>
         </div>
         """)
-
         with gr.Row():
-            # --- Left Column (Inputs) ---
             with gr.Column(scale=1):
                 model_choice = gr.Radio(
                     choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
                     label="Select Model",
                     value="Camel-Doc-OCR-062825"
                 )
-                image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                 with gr.Accordion("Advanced Settings", open=False):
-                    max_new_tokens = gr.Slider(minimum=1000, maximum=8192, value=4096, step=256, label="Max New Tokens")
-
                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
-
-            # --- Right Column (Outputs) ---
             with gr.Column(scale=2):
                 with gr.Tabs():
                     with gr.Tab("📝 Extracted Content"):
-                        raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=15, show_copy_button=True)
-                        with gr.Accordion("(Formatted Result)", open=True):
-                            markdown_output = gr.Markdown(label="Formatted Markdown (from JSON)")
-
                     with gr.Tab("📋 Layout Analysis Results"):
-                        json_output = gr.JSON(label="Structured Layout Data (JSON)", value=None)
-
-        # --- Event Handlers ---
-        def clear_all_outputs():
-            """Resets all input and output fields to their default state."""
-            return None, "Raw output will appear here.", "Formatted results will appear here.", None
-
-        # Connect the process button to the main generator function
         process_btn.click(
-            fn=process_document_stream,
-            inputs=[model_choice, image_input, gr.Textbox(value=prompt, visible=False), max_new_tokens],
-            outputs=[raw_output_stream, markdown_output, json_output]
         )
-
-        # Connect the clear button
         clear_btn.click(
-            clear_all_outputs,
-            outputs=[image_input, raw_output_stream, markdown_output, json_output]
         )
-
     return demo
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
     AutoProcessor,
     TextIteratorStreamer,
 )
+from qwen_vl_utils import process_vision_info
 
+# Constants
+MIN_PIXELS = 3136
+MAX_PIXELS = 11289600
+IMAGE_FACTOR = 28
+MAX_INPUT_TOKEN_LENGTH = 2048
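A quick note on the new constants, as a minimal sanity sketch; reading IMAGE_FACTOR as the Qwen2.5-VL vision patch multiple is an assumption, not something the commit states:

```python
# Hedged sanity checks on the pixel-budget constants (assumed derivation).
# MIN_PIXELS = 3136 is a 56x56 image; MAX_PIXELS = 11289600 is 3360x3360.
# Both sides sit on the 28-pixel grid that IMAGE_FACTOR enforces.
assert 56 * 56 == 3136
assert 3360 * 3360 == 11289600
assert 56 % 28 == 0 and 3360 % 28 == 0
```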
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Prompts
 prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
 
 1. Bbox format: [x1, y1, x2, y2]
+
 2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
+
 3. Text Extraction & Formatting Rules:
     - Picture: For the 'Picture' category, the text field should be omitted.
     - Formula: Format its text as LaTeX.
+    - Table: Format its text as HTML.
     - All Others (Text, Title, etc.): Format their text as Markdown.
+
 4. Constraints:
     - The output text must be the original text from the image, with no translation.
     - All layout elements must be sorted according to human reading order.
+
+5. Final Output: The entire output must be a single JSON object.
 """
 
 # Load models
 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
 processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
 model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_M,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
 processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
 model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_T,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_C = "nanonets/Nanonets-OCR-s"
 processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
 model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_C,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_G = "echo840/MonkeyOCR"
 SUBFOLDER = "Recognition"
 processor_g = AutoProcessor.from_pretrained(
+    MODEL_ID_G,
+    trust_remote_code=True,
+    subfolder=SUBFOLDER
 )
 model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_G,
+    trust_remote_code=True,
+    subfolder=SUBFOLDER,
+    torch_dtype=torch.float16
 ).to(device).eval()
 
+# Utility functions
+def round_by_factor(number: int, factor: int) -> int:
+    return round(number / factor) * factor
+
+def smart_resize(
+    height: int,
+    width: int,
+    factor: int = 28,
+    min_pixels: int = 3136,
+    max_pixels: int = 11289600,
+):
+    if max(height, width) / min(height, width) > 200:
+        raise ValueError(f"Aspect ratio too extreme: {max(height, width) / min(height, width)}")
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = round_by_factor(height / beta, factor)
+        w_bar = round_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = round_by_factor(height * beta, factor)
+        w_bar = round_by_factor(width * beta, factor)
+    return h_bar, w_bar
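A minimal usage sketch of smart_resize under its defaults; the numbers were worked out by hand from the code above:

```python
# smart_resize keeps dimensions on the 28-pixel grid and within the pixel budget.
h, w = smart_resize(1080, 1920)      # a 1080p page scan
assert (h, w) == (1092, 1932)        # 39*28 and 69*28; already within budget
h, w = smart_resize(8000, 8000)      # oversized input is scaled down
assert h * w <= 11289600 and h % 28 == 0 and w % 28 == 0
```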
+
+def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
+    if isinstance(image_input, str):
+        if image_input.startswith(("http://", "https://")):
+            response = requests.get(image_input)
+            image = Image.open(BytesIO(response.content)).convert('RGB')
+        else:
+            image = Image.open(image_input).convert('RGB')
+    elif isinstance(image_input, Image.Image):
+        image = image_input.convert('RGB')
+    else:
+        raise ValueError(f"Invalid image input type: {type(image_input)}")
+    if min_pixels or max_pixels:
+        min_pixels = min_pixels or MIN_PIXELS
+        max_pixels = max_pixels or MAX_PIXELS
+        height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
+        image = image.resize((width, height), Image.LANCZOS)
+    return image
+
+def is_arabic_text(text: str) -> bool:
+    if not text:
+        return False
+    header_pattern = r'^#{1,6}\s+(.+)$'
+    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
+    content_text = []
+    for line in text.split('\n'):
+        line = line.strip()
+        if not line:
+            continue
+        header_match = re.match(header_pattern, line, re.MULTILINE)
+        if header_match:
+            content_text.append(header_match.group(1))
+            continue
+        if re.match(paragraph_pattern, line, re.MULTILINE):
+            content_text.append(line)
+    if not content_text:
+        return False
+    combined_text = ' '.join(content_text)
+    arabic_chars = 0
+    total_chars = 0
+    for char in combined_text:
+        if char.isalpha():
+            total_chars += 1
+            if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
+                arabic_chars += 1
+    return total_chars > 0 and (arabic_chars / total_chars) > 0.5
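A brief sketch of the heuristic: headers and plain paragraphs are collected, then the share of Arabic letters is compared against 50%. The sample strings are illustrative:

```python
# The detector only fires when more than half of the letters are Arabic.
assert is_arabic_text("# مرحبا\nهذا نص عربي") is True
assert is_arabic_text("# Hello\nPlain English text") is False
assert is_arabic_text("") is False   # empty input short-circuits
```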
 
+def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
+    import base64
+    from io import BytesIO
     markdown_lines = []
     try:
         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
         for item in sorted_items:
             category = item.get('category', '')
+            text = item.get(text_key, '')
+            bbox = item.get('bbox', [])
+            if category == 'Picture':
+                if bbox and len(bbox) == 4:
+                    try:
+                        x1, y1, x2, y2 = bbox
+                        x1, y1 = max(0, int(x1)), max(0, int(y1))
+                        x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
+                        if x2 > x1 and y2 > y1:
+                            cropped_img = image.crop((x1, y1, x2, y2))
+                            buffer = BytesIO()
+                            cropped_img.save(buffer, format='PNG')
+                            img_data = base64.b64encode(buffer.getvalue()).decode()
+                            markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
+                        else:
+                            markdown_lines.append("![Image](Image region detected)\n")
+                    except Exception as e:
+                        print(f"Error processing image region: {e}")
+                        markdown_lines.append("![Image](Image detected)\n")
+                else:
+                    markdown_lines.append("![Image](Image detected)\n")
+            elif not text:
                 continue
+            elif category == 'Title':
                 markdown_lines.append(f"# {text}\n")
             elif category == 'Section-header':
                 markdown_lines.append(f"## {text}\n")
+            elif category == 'Text':
+                markdown_lines.append(f"{text}\n")
+            elif category == 'List-item':
+                markdown_lines.append(f"- {text}\n")
             elif category == 'Table':
+                if text.strip().startswith('<'):
                     markdown_lines.append(f"{text}\n")
+                else:
+                    markdown_lines.append(f"**Table:** {text}\n")
+            elif category == 'Formula':
+                if text.strip().startswith('$') or '\\' in text:
+                    markdown_lines.append(f"$$ \n{text}\n $$\n")
+                else:
+                    markdown_lines.append(f"**Formula:** {text}\n")
+            elif category == 'Caption':
+                markdown_lines.append(f"*{text}*\n")
+            elif category == 'Footnote':
+                markdown_lines.append(f"^{text}^\n")
+            elif category in ['Page-header', 'Page-footer']:
+                continue
             else:
                 markdown_lines.append(f"{text}\n")
+            markdown_lines.append("")
     except Exception as e:
         print(f"Error converting to markdown: {e}")
+        return str(layout_data)
     return "\n".join(markdown_lines)
 
 @spaces.GPU
+def inference(model_name: str, image: Image.Image, text: str, max_new_tokens: int = 1024) -> str:
     try:
+        if model_name == "Camel-Doc-OCR-062825":
+            processor = processor_m
+            model = model_m
+        elif model_name == "Megalodon-OCR-Sync-0713":
+            processor = processor_t
+            model = model_t
+        elif model_name == "Nanonets-OCR-s":
+            processor = processor_c
+            model = model_c
+        elif model_name == "MonkeyOCR-Recognition":
+            processor = processor_g
+            model = model_g
+        else:
+            raise ValueError(f"Invalid model selected: {model_name}")
+
+        if image is None:
+            yield "Please upload an image.", "Please upload an image."
+            return
+
+        messages = [{
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": text},
+            ]
+        }]
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(
+            text=[prompt_full],
+            images=[image],
+            return_tensors="pt",
+            padding=True,
+            truncation=False,
+            max_length=MAX_INPUT_TOKEN_LENGTH
+        ).to(device)
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer, buffer
     except Exception as e:
+        print(f"Error during inference: {e}")
+        traceback.print_exc()
+        yield f"Error during inference: {str(e)}", f"Error during inference: {str(e)}"
 
+def process_image(
+    model_name: str,
+    image: Image.Image,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
+    max_new_tokens: int = 1024
+):
+    try:
+        if min_pixels or max_pixels:
+            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
+        buffer = ""
+        for raw_output, _ in inference(model_name, image, prompt, max_new_tokens):
+            buffer = raw_output
+            yield buffer, None  # Yield raw OCR stream and None for JSON during processing
+        try:
+            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
+            json_str = json_match.group(1) if json_match else buffer
+            layout_data = json.loads(json_str)
+            yield buffer, layout_data  # Final yield with raw OCR and parsed JSON
+        except json.JSONDecodeError:
+            print("Failed to parse JSON output, using raw output")
+            yield buffer, None  # If JSON parsing fails, yield raw OCR with no JSON
+    except Exception as e:
+        print(f"Error processing image: {e}")
+        traceback.print_exc()
+        yield f"Error processing image: {str(e)}", None
 
+def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
+    if not file_path or not os.path.exists(file_path):
+        return None, "No file selected"
+    file_ext = os.path.splitext(file_path)[1].lower()
+    try:
+        if file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
+            image = Image.open(file_path).convert('RGB')
+            return image, "Image loaded"
+        else:
+            return None, f"Unsupported file format: {file_ext}"
+    except Exception as e:
+        print(f"Error loading file: {e}")
+        return None, f"Error loading file: {str(e)}"
 
 def create_gradio_interface():
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
+    .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
+    .process-button {
+        border: none !important;
+        color: white !important;
+        font-weight: bold !important;
+        background-color: blue !important;}
+    .process-button:hover {
+        background-color: darkblue !important;
+        transform: translateY(-2px) !important;
+        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
+    .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
+    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
+    .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
+    .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
         gr.HTML("""
         <div class="title" style="text-align: center">
         <h1>Dot<span style="color: red;">●</span><strong></strong>OCR Comparator</h1>
         <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
+        Advanced vision-language models for image-to-markdown document processing
         </p>
         </div>
         """)
         with gr.Row():
             with gr.Column(scale=1):
                 model_choice = gr.Radio(
                     choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
                     label="Select Model",
                     value="Camel-Doc-OCR-062825"
                 )
+                file_input = gr.File(
+                    label="Upload Image",
+                    file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
+                    type="filepath"
+                )
+                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
                 with gr.Accordion("Advanced Settings", open=False):
+                    max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
+                    min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
+                    max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
             with gr.Column(scale=2):
                 with gr.Tabs():
                     with gr.Tab("📝 Extracted Content"):
+                        output = gr.Textbox(label="Raw OCR Stream", interactive=False, lines=10, show_copy_button=True)
                     with gr.Tab("📋 Layout Analysis Results"):
+                        json_output = gr.JSON(label="Layout Analysis Results", value=None)
+
+        def process_document(model_name, file_path, max_tokens, min_pix, max_pix):
+            try:
+                if not file_path:
+                    yield "Please upload an image.", None
+                    return
+                image, status = load_file_for_preview(file_path)
+                if image is None:
+                    yield status, None
+                    return
+                for raw_output, layout_result in process_image(model_name, image, min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None, max_new_tokens=max_tokens):
+                    yield raw_output, layout_result
+            except Exception as e:
+                error_msg = f"Error processing document: {str(e)}"
+                print(error_msg)
+                traceback.print_exc()
+                yield error_msg, None
+
+        def handle_file_upload(file_path):
+            if not file_path:
+                return None, "No file loaded"
+            image, page_info = load_file_for_preview(file_path)
+            return image, page_info
+
+        def clear_all():
+            return None, None, "No file loaded", None
+
+        file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, output])
         process_btn.click(
+            process_document,
+            inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels],
+            outputs=[output, json_output]
         )
         clear_btn.click(
+            clear_all,
+            outputs=[file_input, image_preview, output, json_output]
         )
     return demo
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
+    demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True, show_error=True)