SalML committed on
Commit
b9d9b37
·
1 Parent(s): b965b5d

Upload TDTSR.py

Browse files
Files changed (1) hide show
  1. TDTSR.py +332 -0
TDTSR.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ from transformers import DetrFeatureExtractor
4
+ from transformers import DetrForObjectDetection
5
+ import torch
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.patches import Circle, Wedge, Rectangle
8
+ import streamlit as st
9
+ from PIL import Image
10
+ import math
11
+
12
+
13
+ colors = ["red", "blue", "green", "yellow", "orange", "violet"]  # box outline colours, indexed by predicted class id in the plotting helpers
14
+
15
+
16
def table_detector(image, THRESHOLD_PROBA):
    '''
    Table detection using a DEtect-object TRansformer (DETR) checkpoint.

    The "SalML/DETR-table-detection" weights are loaded on every call and
    run on `image` without gradient tracking. Only queries whose best class
    probability (the trailing no-object class is dropped) is strictly above
    THRESHOLD_PROBA are kept.

    Returns a tuple (model, image, kept_probabilities, kept_boxes), where
    the boxes come from the feature extractor's post_process step.
    '''
    extractor = DetrFeatureExtractor(do_resize=True, size=800, max_size=800)
    model_inputs = extractor(image, return_tensors="pt")
    detector = DetrForObjectDetection.from_pretrained("SalML/DETR-table-detection")

    with torch.no_grad():
        outputs = detector(**model_inputs)

    # Per-query class probabilities; the last logit is sliced away before thresholding.
    class_probs = outputs.logits.softmax(-1)[0, :, :-1]
    selected = class_probs.max(-1).values > THRESHOLD_PROBA

    # Rescale boxes to the input image; image.size is reversed here
    # (presumably PIL's (width, height) -> (height, width) -- confirm).
    page_size = torch.tensor(image.size[::-1]).unsqueeze(0)
    all_boxes = extractor.post_process(outputs, page_size)[0]['boxes']

    return (detector, image, class_probs[selected], all_boxes[selected])
40
+
41
+
42
def table_struct_recog(image, THRESHOLD_PROBA):
    '''
    Table structure recognition using a DEtect-object TRansformer (DETR) checkpoint.

    The "SalML/DETR-table-structure-recognition" weights are loaded on every
    call and run on `image` without gradient tracking. Only queries whose
    best class probability (the trailing no-object class is dropped) is
    strictly above THRESHOLD_PROBA are kept.

    Returns a tuple (model, image, kept_probabilities, kept_boxes), where
    the boxes come from the feature extractor's post_process step.
    '''
    extractor = DetrFeatureExtractor(do_resize=True, size=1000, max_size=1000)
    model_inputs = extractor(image, return_tensors="pt")
    recognizer = DetrForObjectDetection.from_pretrained("SalML/DETR-table-structure-recognition")

    with torch.no_grad():
        outputs = recognizer(**model_inputs)

    # Per-query class probabilities; the last logit is sliced away before thresholding.
    class_probs = outputs.logits.softmax(-1)[0, :, :-1]
    selected = class_probs.max(-1).values > THRESHOLD_PROBA

    # Rescale boxes to the input image; image.size is reversed here
    # (presumably PIL's (width, height) -> (height, width) -- confirm).
    table_size = torch.tensor(image.size[::-1]).unsqueeze(0)
    all_boxes = extractor.post_process(outputs, table_size)[0]['boxes']

    return (recognizer, image, class_probs[selected], all_boxes[selected])
64
+
65
def add_margin(pil_img, top=20, right=20, bottom=20, left=20, color=(255,255,255)):
    '''
    Image padding as part of TSR pre-processing to prevent missing table edges.

    Creates a new image in the same mode as `pil_img`, enlarged by the given
    margins and filled with `color`, then pastes the original at (left, top).
    The input image itself is not modified.
    '''
    src_width, src_height = pil_img.size
    canvas_size = (src_width + right + left, src_height + top + bottom)

    padded = Image.new(pil_img.mode, canvas_size, color)
    padded.paste(pil_img, (left, top))
    return padded
75
+
76
def plot_results_detection(c1, model, pil_img, prob, boxes, show_only_cropped=False):
    '''
    Plots the full pillow pdf-page image and adds a rectangle patch for table detection.

    Parameters:
        c1: Streamlit container (anything exposing ``pyplot``) that renders the figure.
        model: object whose ``config.id2label`` maps a class id to its label text.
        pil_img: page image handed to ``imshow``.
        prob: per-detection class probabilities; ``argmax`` selects the class.
        boxes: object with ``tolist()`` yielding (xmin, ymin, xmax, ymax) per detection.
        show_only_cropped: unused; kept so existing callers keep working.

    The figure is created explicitly, passed to ``c1.pyplot`` and closed
    afterwards, rather than relying on the global pyplot figure (which
    Streamlit deprecates) and leaving a 32x20 figure open after every call.
    '''
    fig, ax = plt.subplots(figsize=(32, 20))
    try:
        ax.imshow(pil_img)

        for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
            cl = p.argmax().item()
            # Grow every box by 3px per side so the outline does not sit on the table border.
            xmin, ymin, xmax, ymax = xmin - 3, ymin - 3, xmax + 3, ymax + 3
            # Wrap around the palette so an unexpected class id cannot raise IndexError.
            edge_color = colors[cl % len(colors)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=edge_color, linewidth=3))
            text = f'{model.config.id2label[cl]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))

        ax.axis('off')
        c1.pyplot(fig)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)
95
+
96
+
97
def plot_table_detection(c2, model, pil_img, prob, boxes):
    '''
    Plots only the cropped table(s) from the table detection.

    Parameters:
        c2: Streamlit container (anything exposing ``pyplot``) that renders the figure.
        model: unused; kept so the signature matches the other plot helpers.
        pil_img: page image; must provide ``crop((xmin, ymin, xmax, ymax))``.
        prob: per-detection probabilities; only used to pair up with ``boxes``.
        boxes: object with ``tolist()`` yielding (xmin, ymin, xmax, ymax) per detection.

    Returns:
        list of cropped images, one per detection, each grown by 3px per side.

    Each crop gets its own axes (stacked vertically). Previously every crop
    was drawn with ``imshow`` on the same axes, so crops overlapped each
    other. The figure is passed explicitly to ``c2.pyplot`` and closed afterwards.
    '''
    cropped_img_list = [
        pil_img.crop((xmin - 3, ymin - 3, xmax + 3, ymax + 3))
        for _, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist())
    ]

    # Keep one blank axes when nothing was detected so a figure is still rendered.
    n_axes = max(len(cropped_img_list), 1)
    fig, axes = plt.subplots(n_axes, 1, figsize=(32, 20), squeeze=False)
    try:
        panels = list(axes.flat)
        for panel, cropped_img in zip(panels, cropped_img_list):
            panel.imshow(cropped_img)
        for panel in panels:
            panel.axis('off')
        c2.pyplot(fig)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)

    return cropped_img_list
119
+
120
+
121
def plot_structure(c3, model, pil_img, prob, boxes, class_to_show=0):
    '''
    To plot table pillow image and the TSR bounding boxes on the table.

    Parameters:
        c3: Streamlit container (anything exposing ``pyplot``) that renders the figure.
        model: object whose ``config.id2label`` maps a class id to its label text.
        pil_img: table image handed to ``imshow``.
        prob: per-detection class probabilities; ``argmax`` selects the class.
        boxes: object with ``tolist()`` yielding (xmin, ymin, xmax, ymax) per detection.
        class_to_show: unused; kept so existing callers keep working.

    Returns:
        (rows, cols): dicts keyed 'table row <idx>' / 'table column <idx>',
        where idx is the detection's position in ``prob``/``boxes``, mapping
        to the box grown by 3px per side.

    Every class except 'table' is outlined and labelled. The figure is
    passed explicitly to ``c3.pyplot`` and closed afterwards, rather than
    relying on the global pyplot figure (which Streamlit deprecates).
    '''
    rows = {}
    cols = {}

    fig, ax = plt.subplots(figsize=(32, 20))
    try:
        ax.imshow(pil_img)

        for idx, (p, (xmin, ymin, xmax, ymax)) in enumerate(zip(prob, boxes.tolist())):
            xmin, ymin, xmax, ymax = xmin - 3, ymin - 3, xmax + 3, ymax + 3
            cl = p.argmax().item()
            class_text = model.config.id2label[cl]
            text = f'{class_text}: {p[cl]:0.2f}'

            # The whole-table box would just frame the image, so it is not drawn.
            if class_text != 'table':
                # Wrap around the palette so an unexpected class id cannot raise IndexError.
                edge_color = colors[cl % len(colors)]
                ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                           fill=False, color=edge_color, linewidth=3))
                ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))

            # Only plain rows and columns are collected; header classes are ignored here.
            if class_text == 'table row':
                rows['table row ' + str(idx)] = (xmin, ymin, xmax, ymax)
            elif class_text == 'table column':
                cols['table column ' + str(idx)] = (xmin, ymin, xmax, ymax)

        c3.pyplot(fig)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)

    return rows, cols
161
+
162
+
163
+
164
def sort_table_features(header, row_header, rows, cols):
    '''
    Order the detected table features and drop rows that duplicate a header.

    - Rows are sorted by ymin; rows whose ymin is not below the column
      header's bottom edge minus 10px are discarded, because the header box
      and the first row sometimes overlap.
    - Columns are sorted by xmin, projected row headers by ymin.
    - A row whose ymin lies within 2px of any row header's ymin is skipped.
      The remaining rows are renamed 'table row.<n>' in order.

    Returns (sorted_row_headers, renumbered_rows, sorted_cols).
    Raises KeyError when `header` has no 'header' entry.
    '''
    def _ordered(features, axis):
        # Rebuild every box as a 4-tuple while sorting on the chosen coordinate.
        ranked = sorted(features.items(), key=lambda item: item[1][axis])
        return {name: (xmin, ymin, xmax, ymax) for name, (xmin, ymin, xmax, ymax) in ranked}

    y_cutoff = header['header'][3] - 10

    rows_below_header = {
        name: box for name, box in _ordered(rows, 1).items() if box[1] > y_cutoff
    }
    cols_sorted = _ordered(cols, 0)
    row_headers_sorted = _ordered(row_header, 1)

    header_tops = [box[1] for box in row_headers_sorted.values()]
    kept_boxes = [
        box
        for box in rows_below_header.values()
        if not any(math.isclose(box[1], top, abs_tol=2) for top in header_tops)
    ]

    renumbered = {'table row.' + str(n): box for n, box in enumerate(kept_boxes)}

    return row_headers_sorted, _ordered(renumbered, 1), cols_sorted
190
+
191
+
192
def sort_table_featuresv2(rows, cols):
    '''
    Return copies of `rows` ordered top-to-bottom (by ymin) and of `cols`
    ordered left-to-right (by xmin). Every box is rebuilt as a 4-tuple
    (xmin, ymin, xmax, ymax); the input dicts are left untouched.
    '''
    rows_by_top = sorted(rows.items(), key=lambda item: item[1][1])
    rows_ = {name: (x0, y0, x1, y1) for name, (x0, y0, x1, y1) in rows_by_top}

    cols_by_left = sorted(cols.items(), key=lambda item: item[1][0])
    cols_ = {name: (x0, y0, x1, y1) for name, (x0, y0, x1, y1) in cols_by_left}

    return rows_, cols_
198
+
199
def individual_table_features(pil_img, header, row_header, rows, cols):
    '''
    Attach the cropped image of every feature to its bounding box.

    Each of the four dicts is updated IN PLACE: a value
    (xmin, ymin, xmax, ymax) becomes (xmin, ymin, xmax, ymax, cropped_img),
    where cropped_img is pil_img.crop of that box. The same four dict
    objects are returned as a tuple.
    '''
    for features in (header, row_header, rows, cols):
        for name, box in features.items():
            xmin, ymin, xmax, ymax = box
            snippet = pil_img.crop((xmin, ymin, xmax, ymax))
            # Overwriting an existing key keeps the dict size fixed, so this is safe mid-iteration.
            features[name] = (xmin, ymin, xmax, ymax, snippet)

    return header, row_header, rows, cols
223
+
224
def individual_table_featuresv2(pil_img, rows, cols):
    '''
    Attach the cropped image of every row and column to its bounding box.

    Both dicts are updated IN PLACE: a value (xmin, ymin, xmax, ymax)
    becomes (xmin, ymin, xmax, ymax, cropped_img), where cropped_img is
    pil_img.crop of that box. The same two dict objects are returned.
    '''
    for features in (rows, cols):
        for name, box in features.items():
            xmin, ymin, xmax, ymax = box
            snippet = pil_img.crop((xmin, ymin, xmax, ymax))
            # Overwriting an existing key keeps the dict size fixed, so this is safe mid-iteration.
            features[name] = (xmin, ymin, xmax, ymax, snippet)

    return rows, cols
239
+
240
def plot_table_features(c2, header, row_header, rows, cols):
    '''
    Walk every feature of the four dicts and unpack its cropped image.

    Nothing is rendered at present: `c2` is never used and the unpacked
    image is discarded. The only observable effect is a ValueError when a
    value is not a 5-item (xmin, ymin, xmax, ymax, image) sequence.
    Returns None.
    '''
    for features in (header, row_header, rows, cols):
        for _name, entry in features.items():
            _, _, _, _, pil_img = entry
253
+
254
+
255
def master_row_set(header, row_header, rows, cols):
    '''
    Merge header, row headers and rows into one dict ordered top-to-bottom.

    Later dicts win on duplicate keys (header < row_header < rows). Entries
    are sorted by ymin and rebuilt as (xmin, ymin, xmax, ymax, img) tuples.
    `cols` is accepted but not used.
    '''
    combined = {}
    for group in (header, row_header, rows):
        combined.update(group)

    top_to_bottom = sorted(combined.items(), key=lambda item: item[1][1])
    return {
        name: (xmin, ymin, xmax, ymax, img)
        for name, (xmin, ymin, xmax, ymax, img) in top_to_bottom
    }
260
+
261
+
262
+
263
+
264
def object_to_cells(master_row, cols):
    '''
    Cut every row image of `master_row` into per-column cell images.

    Handling depends on the row key:
      * keys starting with 'header table row': the whole row image is stored
        unsplit under '<key>.<row index>'.
      * the key 'header': one entry per column, stored under
        'header.<running header cell index>'.
      * keys starting with 'table row': a list with one cell image per
        column, stored under '<key>.<row index>'.
    Any other key is ignored. The row index is shared by projected row
    headers and plain rows.

    Cell x-bounds are the column's xmin - 19 and xmax - 20, in the row
    image's own coordinates.
    NOTE(review): these offsets look tied to the 20px padding added before
    structure recognition -- confirm.
    '''
    def _split(strip, height):
        # One crop per column, spanning the full height of the row image.
        return [
            strip.crop((col_xmin - 19, 0, col_xmax - 20, height))
            for col_xmin, _, col_xmax, _, _ in cols.values()
        ]

    cells_img = {}
    header_count = 0
    row_count = 0

    for name, entry in master_row.items():
        if name[:16] == 'header table row':
            _, _, _, _, unsplit_img = entry
            cells_img[name + '.' + str(row_count)] = unsplit_img
            row_count += 1
            continue

        if name == 'header':
            _, top, _, bottom, strip = entry
            for cell in _split(strip, bottom - top):
                cells_img[name + '.' + str(header_count)] = cell
                header_count += 1
            continue

        if name[:9] == 'table row':
            _, top, _, bottom, strip = entry
            cells_img[name + '.' + str(row_count)] = _split(strip, bottom - top)
            row_count += 1

    return cells_img
309
+
310
+
311
def object_to_cellsv2(master_row, cols):
    '''
    Cut every row image of `master_row` into per-column cell images.

    For each row (in dict order) a list with one crop per column is stored
    under '<row key>.<row index>', so the result has one entry per row.

    Cell x-bounds are the column's xmin - 19 and xmax - 20, in the row
    image's own coordinates; the crop spans the full row height (ymax - ymin).
    NOTE(review): these offsets look tied to the 20px padding added before
    structure recognition -- confirm.
    '''
    cells_img = {}

    for row_idx, (name, entry) in enumerate(master_row.items()):
        _, top, _, bottom, strip = entry
        height = bottom - top
        cells_img[name + '.' + str(row_idx)] = [
            strip.crop((col_xmin - 19, 0, col_xmax - 20, height))
            for col_xmin, _, col_xmax, _, _ in cols.values()
        ]

    return cells_img