rishiraj committed (verified)
Commit 5f95db3 · 1 Parent(s): 558c701

Delete pdf-extractor

pdf-extractor/pdf_extractor.py DELETED
@@ -1,52 +0,0 @@
-from typing import List, Union, Optional
-import json
-from indexify_extractor_sdk import Content, Extractor, Feature
-from pydantic import BaseModel, Field
-from .utils.tt_module import get_tables
-import fitz
-import tempfile
-
-class PDFExtractorConfig(BaseModel):
-    output_types: List[str] = Field(default_factory=lambda: ["text", "image", "table"])
-
-class PDFExtractor(Extractor):
-    name = "tensorlake/pdf-extractor"
-    description = "PDF Extractor for Texts, Images & Tables"
-    system_dependencies = ["poppler-utils"]
-    input_mime_types = ["application/pdf"]
-
-    def __init__(self):
-        super(PDFExtractor, self).__init__()
-
-    def extract(self, content: Content, params: PDFExtractorConfig) -> List[Union[Feature, Content]]:
-        contents = []
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as inputtmpfile:
-            inputtmpfile.write(content.data)
-            inputtmpfile.flush()
-            doc = fitz.open(inputtmpfile.name)
-
-            for i in range(len(doc)):
-                page = doc[i]
-
-                if "text" in params.output_types:
-                    page_text = page.get_text()
-                    feature = Feature.metadata(value={"type": "text", "page": i+1})
-                    contents.append(Content.from_text(page_text, features=[feature]))
-
-                if "image" in params.output_types:
-                    image_list = page.get_images()
-                    for img in image_list:
-                        xref = img[0]
-                        pix = fitz.Pixmap(doc, xref)
-                        if not pix.colorspace.name in (fitz.csGRAY.name, fitz.csRGB.name):
-                            pix = fitz.Pixmap(fitz.csRGB, pix)
-                        feature = Feature.metadata({"type": "image", "page": i+1})
-                        contents.append(Content(content_type="image/png", data=pix.tobytes(), features=[feature]))
-
-            if "table" in params.output_types:
-                tables = get_tables(content.data)
-                for page, content in tables.items():
-                    feature = Feature.metadata({"type": "table", "page": int(page)})
-                    contents.append(Content(content_type="application/json", data=json.dumps(content), features=[feature]))
-
-        return contents
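For context, the deleted extractor was driven through the indexify_extractor_sdk Content/Extractor interface. Below is a minimal usage sketch, assuming the SDK and the extractor's dependencies (PyMuPDF, poppler-utils, the table-transformer models) are installed; the file name sample.pdf and the import path are hypothetical and depend on how the package is laid out.

# Hypothetical usage sketch of the deleted PDFExtractor (not part of the commit).
from indexify_extractor_sdk import Content

from pdf_extractor import PDFExtractor, PDFExtractorConfig  # import path is an assumption

with open("sample.pdf", "rb") as f:  # sample.pdf is a placeholder input
    pdf_bytes = f.read()

extractor = PDFExtractor()
config = PDFExtractorConfig(output_types=["text", "table"])  # skip image extraction
outputs = extractor.extract(
    Content(content_type="application/pdf", data=pdf_bytes),
    params=config,
)

# Each returned item is a Content object carrying a {"type": ..., "page": ...} metadata feature.
for item in outputs:
    print(item.content_type)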
 
pdf-extractor/utils/tt_module.py DELETED
@@ -1,230 +0,0 @@
-from transformers import AutoModelForObjectDetection
-import torch
-from pdf2image import convert_from_bytes
-from torchvision import transforms
-from transformers import TableTransformerForObjectDetection
-import numpy as np
-import easyocr
-from tqdm.auto import tqdm
-
-model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
-model.config.id2label
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model.to(device)
-structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
-structure_model.to(device)
-reader = easyocr.Reader(['en'])
-
-def pdf_to_img(pdf_path):
-    image_list = []
-    images = convert_from_bytes(pdf_path)
-    for i in range(len(images)):
-        image = images[i].convert("RGB")
-        image_list.append(image)
-    return image_list
-
-class MaxResize(object):
-    def __init__(self, max_size=800):
-        self.max_size = max_size
-
-    def __call__(self, image):
-        width, height = image.size
-        current_max_size = max(width, height)
-        scale = self.max_size / current_max_size
-        resized_image = image.resize((int(round(scale*width)), int(round(scale*height))))
-
-        return resized_image
-
-def box_cxcywh_to_xyxy(x):
-    x_c, y_c, w, h = x.unbind(-1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
-    return torch.stack(b, dim=1)
-
-def rescale_bboxes(out_bbox, size):
-    img_w, img_h = size
-    b = box_cxcywh_to_xyxy(out_bbox)
-    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
-    return b
-
-def outputs_to_objects(outputs, img_size, id2label):
-    m = outputs.logits.softmax(-1).max(-1)
-    pred_labels = list(m.indices.detach().cpu().numpy())[0]
-    pred_scores = list(m.values.detach().cpu().numpy())[0]
-    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
-    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
-
-    objects = []
-    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
-        class_label = id2label[int(label)]
-        if not class_label == 'no object':
-            objects.append({'label': class_label, 'score': float(score),
-                            'bbox': [float(elem) for elem in bbox]})
-
-    return objects
-
-def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
-    """
-    Process the bounding boxes produced by the table detection model into
-    cropped table images and cropped tokens.
-    """
-
-    table_crops = []
-    for obj in objects:
-        if obj['score'] < class_thresholds[obj['label']]:
-            continue
-
-        cropped_table = {}
-
-        bbox = obj['bbox']
-        bbox = [bbox[0]-padding, bbox[1]-padding, bbox[2]+padding, bbox[3]+padding]
-
-        cropped_img = img.crop(bbox)
-
-        table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
-        for token in table_tokens:
-            token['bbox'] = [token['bbox'][0]-bbox[0],
-                             token['bbox'][1]-bbox[1],
-                             token['bbox'][2]-bbox[0],
-                             token['bbox'][3]-bbox[1]]
-
-        # If table is predicted to be rotated, rotate cropped image and tokens/words:
-        if obj['label'] == 'table rotated':
-            cropped_img = cropped_img.rotate(270, expand=True)
-            for token in table_tokens:
-                bbox = token['bbox']
-                bbox = [cropped_img.size[0]-bbox[3]-1,
-                        bbox[0],
-                        cropped_img.size[0]-bbox[1]-1,
-                        bbox[2]]
-                token['bbox'] = bbox
-
-        cropped_table['image'] = cropped_img
-        cropped_table['tokens'] = table_tokens
-
-        table_crops.append(cropped_table)
-
-    return table_crops
-
-def get_cell_coordinates_by_row(table_data):
-    # Extract rows and columns
-    rows = [entry for entry in table_data if entry['label'] == 'table row']
-    columns = [entry for entry in table_data if entry['label'] == 'table column']
-
-    # Sort rows and columns by their Y and X coordinates, respectively
-    rows.sort(key=lambda x: x['bbox'][1])
-    columns.sort(key=lambda x: x['bbox'][0])
-
-    # Function to find cell coordinates
-    def find_cell_coordinates(row, column):
-        cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
-        return cell_bbox
-
-    # Generate cell coordinates and count cells in each row
-    cell_coordinates = []
-
-    for row in rows:
-        row_cells = []
-        for column in columns:
-            cell_bbox = find_cell_coordinates(row, column)
-            row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
-
-        # Sort cells in the row by X coordinate
-        row_cells.sort(key=lambda x: x['column'][0])
-
-        # Append row information to cell_coordinates
-        cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
-
-    # Sort rows from top to bottom
-    cell_coordinates.sort(key=lambda x: x['row'][1])
-
-    return cell_coordinates
-
-def apply_ocr(cell_coordinates, cropped_table):
-    # let's OCR row by row
-    data = dict()
-    max_num_columns = 0
-    for idx, row in enumerate(tqdm(cell_coordinates)):
-        row_text = []
-        for cell in row["cells"]:
-            # crop cell out of image
-            cell_image = np.array(cropped_table.crop(cell["cell"]))
-            # apply OCR
-            result = reader.readtext(np.array(cell_image))
-            if len(result) > 0:
-                # print([x[1] for x in list(result)])
-                text = " ".join([x[1] for x in result])
-                row_text.append(text)
-
-        if len(row_text) > max_num_columns:
-            max_num_columns = len(row_text)
-
-        data[idx] = row_text
-
-    print("Max number of columns:", max_num_columns)
-
-    # pad rows which don't have max_num_columns elements
-    # to make sure all rows have the same number of columns
-    for row, row_data in data.copy().items():
-        if len(row_data) != max_num_columns:
-            row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))]
-        data[row] = row_data
-
-    return data
-
-def get_tables(pdf_path):
-    image_list = pdf_to_img(pdf_path)
-    data_dict = {}
-    for index, image in enumerate(image_list):
-        detection_transform = transforms.Compose([
-            MaxResize(800),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        ])
-
-        pixel_values = detection_transform(image).unsqueeze(0)
-        pixel_values = pixel_values.to(device)
-
-        with torch.no_grad():
-            outputs = model(pixel_values)
-
-        id2label = model.config.id2label
-        id2label[len(model.config.id2label)] = "no object"
-
-        objects = outputs_to_objects(outputs, image.size, id2label)
-
-        tokens = []
-        detection_class_thresholds = {
-            "table": 0.5,
-            "table rotated": 0.5,
-            "no object": 10
-        }
-        crop_padding = 10
-
-        tables_crops = objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=0)
-
-        for table_index, table_crop in enumerate(tables_crops):
-            cropped_table = table_crop['image'].convert("RGB")
-
-            structure_transform = transforms.Compose([
-                MaxResize(1000),
-                transforms.ToTensor(),
-                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-            ])
-
-            pixel_values = structure_transform(cropped_table).unsqueeze(0)
-            pixel_values = pixel_values.to(device)
-
-            with torch.no_grad():
-                outputs = structure_model(pixel_values)
-
-            structure_id2label = structure_model.config.id2label
-            structure_id2label[len(structure_id2label)] = "no object"
-
-            cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)
-            if cells[0]['score'] > 0.95:
-                cell_coordinates = get_cell_coordinates_by_row(cells)
-
-                data = apply_ocr(cell_coordinates, cropped_table)
-                data_dict[f"{index+1}_{table_index+1}"] = data
-
-    return data_dict
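For reference, get_tables was the module's entry point: it takes raw PDF bytes and returns a dict keyed by "<page>_<table index>" strings, each value mapping a row index to that row's list of OCR'd cell strings. A small standalone sketch, assuming the module's dependencies are installed; the file name report.pdf and the import path are placeholders.

# Hypothetical standalone use of the deleted tt_module (not part of the commit).
from tt_module import get_tables  # import path is an assumption

with open("report.pdf", "rb") as f:  # report.pdf is a placeholder input
    pdf_bytes = f.read()

tables = get_tables(pdf_bytes)
# Keys look like "3_1" (page 3, first detected table); each value maps
# a row index to the list of cell texts for that row.
for key, rows in tables.items():
    page, table_index = key.split("_")
    print(f"page {page}, table {table_index}: {len(rows)} rows")
    for _, cells in sorted(rows.items()):
        print(" | ".join(cells))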