MohamedRashad committed on
Commit
008db80
·
1 Parent(s): c4e5804

Refactor inference and process_image functions: streamline parameters and enhance layout processing

Browse files
Files changed (1) hide show
  1. app.py +26 -45
app.py CHANGED
@@ -284,6 +284,7 @@ processing_results = {
284
  'markdown_content': None,
285
  'raw_output': None,
286
  }
 
287
  def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
288
  """Run inference on an image with the given prompt"""
289
  try:
@@ -356,7 +357,6 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
356
 
357
  def process_image(
358
  image: Image.Image,
359
- prompt_mode: str,
360
  min_pixels: Optional[int] = None,
361
  max_pixels: Optional[int] = None
362
  ) -> Dict[str, Any]:
@@ -366,48 +366,42 @@ def process_image(
366
  if min_pixels is not None or max_pixels is not None:
367
  image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
368
 
369
- # Run inference
370
  raw_output = inference(image, prompt)
371
 
372
  # Process results based on prompt mode
373
  result = {
374
  'original_image': image,
375
  'raw_output': raw_output,
376
- 'prompt_mode': prompt_mode,
377
  'processed_image': image,
378
  'layout_result': None,
379
  'markdown_content': None
380
  }
381
 
382
- # For layout analysis prompts, try to parse JSON and create visualizations
383
- if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en']:
 
 
 
 
 
384
  try:
385
- # Try to parse JSON output
386
- layout_data = json.loads(raw_output)
387
- result['layout_result'] = layout_data
388
-
389
- # Create visualization with bounding boxes
390
- try:
391
- processed_image = draw_layout_on_image(image, layout_data)
392
- result['processed_image'] = processed_image
393
- except Exception as e:
394
- print(f"Error drawing layout: {e}")
395
- result['processed_image'] = image
396
-
397
- # Generate markdown if text is available
398
- if prompt_mode == 'prompt_layout_all_en':
399
- try:
400
- markdown_content = layoutjson2md(image, layout_data, text_key='text')
401
- result['markdown_content'] = markdown_content
402
- except Exception as e:
403
- print(f"Error generating markdown: {e}")
404
- result['markdown_content'] = raw_output
405
-
406
- except json.JSONDecodeError:
407
- print("Failed to parse JSON output, using raw output")
408
  result['markdown_content'] = raw_output
409
- else:
410
- # For OCR prompts, use raw output as markdown
 
411
  result['markdown_content'] = raw_output
412
 
413
  return result
@@ -418,7 +412,6 @@ def process_image(
418
  return {
419
  'original_image': image,
420
  'raw_output': f"Error processing image: {str(e)}",
421
- 'prompt_mode': prompt_mode,
422
  'processed_image': image,
423
  'layout_result': None,
424
  'markdown_content': f"Error processing image: {str(e)}"
@@ -707,7 +700,7 @@ def create_gradio_interface():
707
  except Exception as e:
708
  return f'<div class="model-status status-error">❌ Error: {str(e)}</div>'
709
 
710
- def process_document(file_path, prompt_mode_val, max_tokens, min_pix, max_pix):
711
  """Process the uploaded document"""
712
  global pdf_cache
713
 
@@ -750,7 +743,6 @@ def create_gradio_interface():
750
  for i, img in enumerate(pdf_cache["images"]):
751
  result = process_image(
752
  img,
753
- prompt_mode_val,
754
  min_pixels=int(min_pix) if min_pix else None,
755
  max_pixels=int(max_pix) if max_pix else None
756
  )
@@ -776,7 +768,6 @@ def create_gradio_interface():
776
  # Process single image
777
  result = process_image(
778
  image,
779
- prompt_mode_val,
780
  min_pixels=int(min_pix) if min_pix else None,
781
  max_pixels=int(max_pix) if max_pix else None
782
  )
@@ -804,10 +795,6 @@ def create_gradio_interface():
804
  f'<div class="model-status status-error">❌ {error_msg}</div>'
805
  )
806
 
807
- def update_prompt_display(mode):
808
- """Update the prompt display when mode changes"""
809
- return prompt
810
-
811
  def handle_file_upload(file_path):
812
  """Handle file upload and show preview"""
813
  if not file_path:
@@ -871,15 +858,9 @@ def create_gradio_interface():
871
  outputs=[image_preview, page_info, markdown_output]
872
  )
873
 
874
- prompt_mode.change(
875
- update_prompt_display,
876
- inputs=[prompt_mode],
877
- outputs=[prompt_display]
878
- )
879
-
880
  process_btn.click(
881
  process_document,
882
- inputs=[file_input, prompt_mode, max_new_tokens, min_pixels, max_pixels],
883
  outputs=[processed_image, markdown_output, raw_output, json_output, model_status]
884
  )
885
 
 
284
  'markdown_content': None,
285
  'raw_output': None,
286
  }
287
+ @spaces.gpu
288
  def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
289
  """Run inference on an image with the given prompt"""
290
  try:
 
357
 
358
  def process_image(
359
  image: Image.Image,
 
360
  min_pixels: Optional[int] = None,
361
  max_pixels: Optional[int] = None
362
  ) -> Dict[str, Any]:
 
366
  if min_pixels is not None or max_pixels is not None:
367
  image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
368
 
369
+ # Run inference with the default prompt
370
  raw_output = inference(image, prompt)
371
 
372
  # Process results based on prompt mode
373
  result = {
374
  'original_image': image,
375
  'raw_output': raw_output,
 
376
  'processed_image': image,
377
  'layout_result': None,
378
  'markdown_content': None
379
  }
380
 
381
+ # Try to parse JSON and create visualizations (since we're doing layout analysis)
382
+ try:
383
+ # Try to parse JSON output
384
+ layout_data = json.loads(raw_output)
385
+ result['layout_result'] = layout_data
386
+
387
+ # Create visualization with bounding boxes
388
  try:
389
+ processed_image = draw_layout_on_image(image, layout_data)
390
+ result['processed_image'] = processed_image
391
+ except Exception as e:
392
+ print(f"Error drawing layout: {e}")
393
+ result['processed_image'] = image
394
+
395
+ # Generate markdown from layout data
396
+ try:
397
+ markdown_content = layoutjson2md(image, layout_data, text_key='text')
398
+ result['markdown_content'] = markdown_content
399
+ except Exception as e:
400
+ print(f"Error generating markdown: {e}")
 
 
 
 
 
 
 
 
 
 
 
401
  result['markdown_content'] = raw_output
402
+
403
+ except json.JSONDecodeError:
404
+ print("Failed to parse JSON output, using raw output")
405
  result['markdown_content'] = raw_output
406
 
407
  return result
 
412
  return {
413
  'original_image': image,
414
  'raw_output': f"Error processing image: {str(e)}",
 
415
  'processed_image': image,
416
  'layout_result': None,
417
  'markdown_content': f"Error processing image: {str(e)}"
 
700
  except Exception as e:
701
  return f'<div class="model-status status-error">❌ Error: {str(e)}</div>'
702
 
703
+ def process_document(file_path, max_tokens, min_pix, max_pix):
704
  """Process the uploaded document"""
705
  global pdf_cache
706
 
 
743
  for i, img in enumerate(pdf_cache["images"]):
744
  result = process_image(
745
  img,
 
746
  min_pixels=int(min_pix) if min_pix else None,
747
  max_pixels=int(max_pix) if max_pix else None
748
  )
 
768
  # Process single image
769
  result = process_image(
770
  image,
 
771
  min_pixels=int(min_pix) if min_pix else None,
772
  max_pixels=int(max_pix) if max_pix else None
773
  )
 
795
  f'<div class="model-status status-error">❌ {error_msg}</div>'
796
  )
797
 
 
 
 
 
798
  def handle_file_upload(file_path):
799
  """Handle file upload and show preview"""
800
  if not file_path:
 
858
  outputs=[image_preview, page_info, markdown_output]
859
  )
860
 
 
 
 
 
 
 
861
  process_btn.click(
862
  process_document,
863
+ inputs=[file_input, max_new_tokens, min_pixels, max_pixels],
864
  outputs=[processed_image, markdown_output, raw_output, json_output, model_status]
865
  )
866