prithivMLmods committed
Commit ad2fb93 · verified · 1 Parent(s): 4048cd0

Update app.py

Files changed (1)
  1. app.py +73 -57
app.py CHANGED
@@ -26,6 +26,9 @@ from transformers import (
 
 # --- Constants and Model Setup ---
 MAX_INPUT_TOKEN_LENGTH = 4096
+# Note: The following line correctly falls back to CPU if CUDA is not available.
+# The "RuntimeError: CUDA driver initialization failed" is an environment issue,
+# meaning the code is being run where a GPU is expected but not found/configured.
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # --- Prompts for Different Tasks ---
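
The added note describes an environment problem rather than a code problem. A small diagnostic sketch (not part of the commit) that confirms which device the Space actually gets at runtime:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")  # prints "cpu" when no usable GPU or driver is present
if device == "cuda":
    print(torch.cuda.get_device_name(0))  # name of the allocated GPU on a GPU runtime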
@@ -78,57 +81,64 @@ model_i = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     MODEL_ID_I, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
-def progress_bar_html(label: str) -> str:
+# --- Utility Functions ---
+def layoutjson2md(layout_data: Any) -> str:
     """
-    Returns an HTML snippet for a thin progress bar with a label.
-    The progress bar is styled as a dark red animated bar.
+    FIXED: Converts the structured JSON from Layout Analysis into formatted Markdown.
+    This version is robust against malformed JSON from the model.
     """
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-        <div style="width: 110px; height: 5px; background-color: #AFEEEE; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #00FFFF; animation: loading 1.5s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-    @keyframes loading {{
-        0% {{ transform: translateX(-100%); }}
-        100% {{ transform: translateX(100%); }}
-    }}
-    </style>
-    '''
-
-# --- Utility Functions ---
-def layoutjson2md(layout_data: List[Dict]) -> str:
-    """Converts the structured JSON from Layout Analysis into formatted Markdown."""
     markdown_lines = []
+
+    # If the model wraps the list in a dictionary, find and extract the list.
+    if isinstance(layout_data, dict):
+        found_list = None
+        for value in layout_data.values():
+            if isinstance(value, list):
+                found_list = value
+                break
+        if found_list is not None:
+            layout_data = found_list
+        else:
+            return "### Error: Could not find a list of layout items in the JSON object."
+
+    if not isinstance(layout_data, list):
+        return f"### Error: Expected a list of layout items, but received type {type(layout_data).__name__}."
+
     try:
-        # Sort items by reading order (top-to-bottom, left-to-right)
-        sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0,0,0,0])[1], x.get('bbox', [0,0,0,0])[0]))
+        # Filter out any non-dictionary items and sort by reading order.
+        valid_items = [item for item in layout_data if isinstance(item, dict)]
+        sorted_items = sorted(valid_items, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
+
         for item in sorted_items:
-            category = item.get('category', '')
+            category = item.get('category', 'Text')  # Default to 'Text' if no category
             text = item.get('text', '')
-            if not text: continue
+            if not text:
+                continue
 
-            if category == 'Title': markdown_lines.append(f"# {text}\n")
-            elif category == 'Section-header': markdown_lines.append(f"## {text}\n")
+            if category == 'Title':
+                markdown_lines.append(f"# {text}\n")
+            elif category == 'Section-header':
+                markdown_lines.append(f"## {text}\n")
             elif category == 'Table':
-                # Handle structured table JSON
                 if isinstance(text, dict) and 'header' in text and 'rows' in text:
                     header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
                     separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
                     rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
                     markdown_lines.extend([header, separator] + rows)
                     markdown_lines.append("\n")
-                else: # Fallback for simple text
+                else:  # Fallback for simple text or malformed tables
                     markdown_lines.append(f"{text}\n")
             else:
-                markdown_lines.append(f"{text}\n")
+                markdown_lines.append(f"{text}\n")
+
     except Exception as e:
         print(f"Error converting to markdown: {e}")
-        return "### Error converting JSON to Markdown."
+        traceback.print_exc()
+        return "### Error: An unexpected error occurred while converting JSON to Markdown."
+
     return "\n".join(markdown_lines)
 
+
 # --- Core Application Logic ---
 @spaces.GPU
 def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
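
For reference, the patched layoutjson2md accepts either a bare list of layout items or a dict that wraps such a list, sorts the items by bounding box, and maps categories to Markdown. A minimal usage sketch with hypothetical data (it assumes Any from typing and traceback are imported at the top of app.py, which this hunk does not show):

# Hypothetical layout JSON in the shape the model is expected to emit.
layout = {
    "elements": [
        {"category": "Title", "bbox": [10, 5, 200, 30], "text": "Quarterly Report"},
        {"category": "Table", "bbox": [10, 40, 200, 120],
         "text": {"header": ["Item", "Cost"], "rows": [["GPU", "2.50"]]}},
    ]
}

# The wrapper dict is unwrapped to its first list value, items are sorted by
# bbox (top-to-bottom, left-to-right), and each category maps to Markdown.
print(layoutjson2md(layout))
# # Quarterly Report
#
# | Item | Cost |
# | --- | --- |
# | GPU | 2.50 |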
@@ -158,11 +168,10 @@ def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
     inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-
+
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
-
     # 4. Stream raw output to the UI in real-time
     buffer = ""
     for new_text in streamer:
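
This hunk only trims blank lines, but it sits on the standard transformers streaming pattern: generate() runs in a worker thread while the main thread drains the TextIteratorStreamer. A generic, standalone sketch of that pattern with a small text-only model (the model choice is illustrative, not one of the Space's VLMs):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The document layout is", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a background thread while we read the streamer.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32})
thread.start()

buffer = ""
for new_text in streamer:  # yields decoded text chunks as they are produced
    buffer += new_text
thread.join()
print(buffer)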
@@ -179,24 +188,32 @@ def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
     try:
         json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
         if not json_match:
-            raise json.JSONDecodeError("JSON object not found in output.", buffer, 0)
-
+            # If no JSON block is found, try to parse the whole buffer as a fallback.
+            try:
+                layout_data = json.loads(buffer)
+                markdown_content = layoutjson2md(layout_data)
+                yield buffer, markdown_content, layout_data
+                return
+            except json.JSONDecodeError:
+                raise ValueError("JSON object not found in the model's output.")
+
         json_str = json_match.group(1)
         layout_data = json.loads(json_str)
         markdown_content = layoutjson2md(layout_data)
-
+
         yield buffer, markdown_content, layout_data
     except Exception as e:
-        error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`"
+        error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`\n\n**Raw Output:**\n```\n{buffer}\n```"
         error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
         yield buffer, error_md, error_json
 
+
 # --- Gradio UI Definition ---
 def create_gradio_interface():
     """Builds and returns the Gradio web interface."""
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
-    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
+    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
     .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
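
The parsing change prefers a fenced ```json block in the streamed buffer and only then tries to parse the whole buffer; if both fail, the outer except renders the error with the raw output attached. A self-contained sketch of that extraction order (extract_layout_json is a hypothetical helper, not a function in app.py):

import json
import re

def extract_layout_json(buffer: str):
    """Return parsed layout JSON from a model response, preferring a fenced block."""
    match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
    if match:
        return json.loads(match.group(1))
    # Fallback: the model may have emitted bare JSON without a code fence.
    try:
        return json.loads(buffer)
    except json.JSONDecodeError:
        raise ValueError("JSON object not found in the model's output.")

print(extract_layout_json('```json\n[{"category": "Text", "text": "hello"}]\n```'))
print(extract_layout_json('[{"category": "Text", "text": "hello"}]'))
# Both calls return: [{'category': 'Text', 'text': 'hello'}]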
@@ -208,15 +225,15 @@ def create_gradio_interface():
             </p>
         </div>
         """)
-
+
         with gr.Row():
             # Left Column (Inputs)
             with gr.Column(scale=1):
                 model_choice = gr.Dropdown(
-                    choices=["Camel-Doc-OCR-080125",
-                             "MonkeyOCR-Recognition",
+                    choices=["Camel-Doc-OCR-080125",
+                             "MonkeyOCR-Recognition",
                              "olmOCR-7B-0725",
-                             "Nanonets-OCR-s",
+                             "Nanonets-OCR-s",
                              "Megalodon-OCR-Sync-0713"
                     ],
                     label="Select Model", value="Nanonets-OCR-s"
@@ -228,7 +245,7 @@ def create_gradio_interface():
                 image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                 with gr.Accordion("Advanced Settings", open=False):
                     max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")
-
+
                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
 
@@ -242,33 +259,32 @@ def create_gradio_interface():
                     examples=["examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png", "examples/5.png"],
                     inputs=image_input,
                     label="Examples"
-                )
-                with gr.Tab("📰 README.md"):
-                    with gr.Accordion("(Formatted Result)", open=True):
-                        markdown_output = gr.Markdown(label="Formatted Markdown")
-
+                )
+                with gr.Tab("📰 Formatted Result"):
+                    markdown_output = gr.Markdown(label="Formatted Markdown")
+
                 with gr.Tab("📋 Layout Analysis Results"):
                     json_output = gr.JSON(label="Structured Layout Data (JSON)")
-
+
         # Event Handlers
         def clear_all_outputs():
             return None, "Raw output will appear here.", "Formatted results will appear here.", None
 
         process_btn.click(
             fn=process_document_stream,
-            inputs=[model_choice,
-                    task_choice,
-                    image_input,
+            inputs=[model_choice,
+                    task_choice,
+                    image_input,
                     max_new_tokens],
-            outputs=[raw_output_stream,
-                     markdown_output,
+            outputs=[raw_output_stream,
+                     markdown_output,
                      json_output]
         )
         clear_btn.click(
             clear_all_outputs,
-            outputs=[image_input,
-                     raw_output_stream,
-                     markdown_output,
+            outputs=[image_input,
+                     raw_output_stream,
+                     markdown_output,
                      json_output]
         )
         return demo
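
The event wiring stays positional: clear_all_outputs returns four values that Gradio assigns, in order, to the four components listed in outputs. A stripped-down sketch of the same wiring with hypothetical component names (outside the Space, for illustration only):

import gradio as gr

with gr.Blocks() as demo:
    image_in = gr.Image(type="pil")
    raw_out = gr.Textbox(label="Raw Output")
    md_out = gr.Markdown()
    json_out = gr.JSON()
    clear = gr.Button("Clear")

    def clear_all():
        # One return value per component in outputs, in the same order.
        return None, "Raw output will appear here.", "Formatted results will appear here.", None

    clear.click(clear_all, outputs=[image_in, raw_out, md_out, json_out])

# demo.launch()  # uncomment to run locally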
 