Keemoz0 committed
Commit a6fc7d1 · 1 Parent(s): 470e893

Grab column boxes and OCR the text in them

Files changed (2):
  1. app.py +45 -6
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,33 +2,72 @@ import gradio as gr
 from huggingface_hub import hf_hub_download
 from PIL import Image
 import torch
+import pytesseract
 from transformers import AutoImageProcessor, AutoModelForObjectDetection
 
 # Load the processor and model for table structure recognition
 processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
 model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
 
-# Define the inference function
+# Check if GPU is available and use it; otherwise, use CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+# Define the inference and OCR function
 def predict(image):
     # Preprocess the input image
-    inputs = processor(images=image, return_tensors="pt")
+    inputs = processor(images=image, return_tensors="pt").to(device)
 
     # Perform object detection using the model
     with torch.no_grad():
         outputs = model(**inputs)
 
-    # Extract bounding boxes and class labels
+    # Extract bounding boxes and filter for columns
     predicted_boxes = outputs.pred_boxes[0].cpu().numpy()  # First image
     predicted_classes = outputs.logits.argmax(-1).cpu().numpy()  # Class predictions
 
-    # Return the bounding boxes for display
-    return {"boxes": predicted_boxes.tolist(), "classes": predicted_classes.tolist()}
+    # Prepare OCR results
+    ocr_results = []
+
+    image_width, image_height = image.size  # Get original image dimensions
+
+    # Iterate over detected boxes and perform OCR on columns
+    for box in predicted_boxes:
+        # Unpack the normalized bounding box (x_min, y_min, x_max, y_max)
+        x_min, y_min, x_max, y_max = box
+
+        # Calculate width and height (denormalize)
+        width = x_max - x_min
+        height = y_max - y_min
+
+        # Filter for columns based on aspect ratio (height > width)
+        if height / width > 2:  # A threshold for vertical aspect ratio (adjust if needed)
+            # Convert normalized coordinates to pixel values
+            left = int(x_min * image_width)
+            top = int(y_min * image_height)
+            right = int(x_max * image_width)
+            bottom = int(y_max * image_height)
+
+            # Crop the image to the bounding box area
+            cropped_image = image.crop((left, top, right, bottom))
+
+            # Perform OCR on the cropped image
+            ocr_text = pytesseract.image_to_string(cropped_image)
+
+            # Append OCR result for this box
+            ocr_results.append({
+                "box": [left, top, right, bottom],
+                "text": ocr_text
+            })
+
+    # Return OCR results
+    return {"ocr_results": ocr_results}
 
 # Set up the Gradio interface
 interface = gr.Interface(
     fn=predict,  # The function that gets called when an image is uploaded
     inputs=gr.Image(type="pil"),  # Image input (as PIL image)
-    outputs="json",  # Outputting a JSON with the boxes and classes
+    outputs="json",  # Outputting a JSON with the OCR results
 )
 
 # Launch the Gradio app
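For context on the box format: Table Transformer is a DETR-style detector, so outputs.pred_boxes holds normalized (center_x, center_y, width, height) values rather than corner coordinates, and the argmax over the logits also includes a "no object" class that the loop above does not filter out. Below is a minimal sketch of an alternative post-processing step built on the processor's post_process_object_detection helper and the checkpoint's id2label map; the 0.7 threshold, the exact "table column" label string, the --psm setting, and the extract_column_text helper name are assumptions to verify against the model config, not part of this commit.

import pytesseract
import torch

def extract_column_text(image, outputs, processor, model):
    # Rescale the normalized (cx, cy, w, h) query boxes to pixel-space
    # (x_min, y_min, x_max, y_max) and drop low-confidence detections.
    target_sizes = torch.tensor([image.size[::-1]])  # PIL size is (width, height)
    detections = processor.post_process_object_detection(
        outputs, threshold=0.7, target_sizes=target_sizes
    )[0]

    ocr_results = []
    for score, label, box in zip(detections["scores"], detections["labels"], detections["boxes"]):
        # Keep only boxes the model itself classified as columns, instead of
        # inferring "column-ness" from the aspect ratio of the raw box.
        if model.config.id2label[label.item()] != "table column":  # assumed label name
            continue
        left, top, right, bottom = (int(v) for v in box.tolist())
        cropped = image.crop((left, top, right, bottom))
        # --psm 6 treats the crop as a single uniform block of text (optional tweak)
        text = pytesseract.image_to_string(cropped, config="--psm 6")
        ocr_results.append({"box": [left, top, right, bottom],
                            "score": float(score),
                            "text": text})
    return ocr_results

Called inside predict as ocr_results = extract_column_text(image, outputs, processor, model), this would slot into the same {"ocr_results": ...} return value the commit already produces.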
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch
 transformers
 gradio
 Pillow
-timm
+timm
+pytesseract
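Since pytesseract only wraps the system tesseract executable, the pip dependency alone does not make OCR work at runtime. Assuming the Space runs on the standard Debian-based image (an assumption about the deployment setup, not something this commit shows), the binary would typically be pulled in through a packages.txt next to requirements.txt:

tesseract-ocr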