Keemoz0 commited on
Commit
34b6cd1
·
1 Parent(s): b8be403

change class name from column to table column

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -1,10 +1,7 @@
1
  import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
- from PIL import Image
4
- import torch
5
  from transformers import AutoImageProcessor, AutoModelForObjectDetection
 
6
 
7
- gr.load("models/microsoft/table-transformer-structure-recognition").launch()
8
  # Load the processor and model for table structure recognition
9
  processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
10
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
@@ -20,11 +17,19 @@ def predict(image):
20
 
21
  # Extract bounding boxes and class labels
22
  predicted_boxes = outputs.pred_boxes[0].cpu().numpy() # First image
23
- predicted_classes = outputs.logits.argmax(-1).cpu().numpy() # Class predictions
 
24
  class_names = model.config.id2label # Get the class name mapping
25
- print(class_names)
26
- # Return the bounding boxes for display
27
- return {"boxes": predicted_boxes.tolist(), "classes": predicted_classes.tolist()}
 
 
 
 
 
 
 
28
 
29
  # Set up the Gradio interface
30
  interface = gr.Interface(
 
1
  import gradio as gr
 
 
 
2
  from transformers import AutoImageProcessor, AutoModelForObjectDetection
3
+ import torch
4
 
 
5
  # Load the processor and model for table structure recognition
6
  processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
7
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
 
17
 
18
  # Extract bounding boxes and class labels
19
  predicted_boxes = outputs.pred_boxes[0].cpu().numpy() # First image
20
+ predicted_class_logits = outputs.logits[0].cpu().numpy() # Class logits for the first image
21
+ predicted_classes = predicted_class_logits.argmax(-1) # Get class predictions
22
  class_names = model.config.id2label # Get the class name mapping
23
+
24
+ # Filter predictions to only include columns based on class name
25
+ column_boxes = []
26
+ for idx, class_id in enumerate(predicted_classes):
27
+ class_name = class_names[class_id]
28
+ if "table column" in class_name.lower(): # Check if the class name contains 'column'
29
+ column_boxes.append(predicted_boxes[idx])
30
+
31
+ # Return the bounding boxes for columns
32
+ return {"boxes": column_boxes, "classes": ["table column"] * len(column_boxes)}
33
 
34
  # Set up the Gradio interface
35
  interface = gr.Interface(