Keemoz0 commited on
Commit
f58ee97
·
1 Parent(s): e272d63

Show only columns

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -1,10 +1,7 @@
1
  import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
- from PIL import Image
4
- import torch
5
  from transformers import AutoImageProcessor, AutoModelForObjectDetection
 
6
 
7
- gr.load("models/microsoft/table-transformer-structure-recognition").launch()
8
  # Load the processor and model for table structure recognition
9
  processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
10
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
@@ -22,8 +19,12 @@ def predict(image):
22
  predicted_boxes = outputs.pred_boxes[0].cpu().numpy() # First image
23
  predicted_classes = outputs.logits.argmax(-1).cpu().numpy() # Class predictions
24
 
25
- # Return the bounding boxes for display
26
- return {"boxes": predicted_boxes.tolist(), "classes": predicted_classes.tolist()}
 
 
 
 
27
 
28
  # Set up the Gradio interface
29
  interface = gr.Interface(
@@ -33,4 +34,4 @@ interface = gr.Interface(
33
  )
34
 
35
  # Launch the Gradio app
36
- interface.launch()
 
1
  import gradio as gr
 
 
 
2
  from transformers import AutoImageProcessor, AutoModelForObjectDetection
3
+ import torch
4
 
 
5
  # Load the processor and model for table structure recognition
6
  processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
7
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
 
19
  predicted_boxes = outputs.pred_boxes[0].cpu().numpy() # First image
20
  predicted_classes = outputs.logits.argmax(-1).cpu().numpy() # Class predictions
21
 
22
+ # Filter predictions to only include columns
23
+ column_class_id = 1 # Assuming class ID 1 corresponds to columns, adjust if needed
24
+ column_boxes = predicted_boxes[predicted_classes == column_class_id]
25
+
26
+ # Return the bounding boxes for columns
27
+ return {"boxes": column_boxes.tolist(), "classes": ["column"] * len(column_boxes)}
28
 
29
  # Set up the Gradio interface
30
  interface = gr.Interface(
 
34
  )
35
 
36
  # Launch the Gradio app
37
+ interface.launch()