SalML committed on
Commit 19d52ee · 1 Parent(s): 37dcf5d

Delete TDTSR.py

Files changed (1)
  1. TDTSR.py +0 -332
TDTSR.py DELETED
@@ -1,332 +0,0 @@
- import os
- import cv2
- from transformers import DetrFeatureExtractor
- from transformers import DetrForObjectDetection
- import torch
- import matplotlib.pyplot as plt
- from matplotlib.patches import Circle, Wedge, Rectangle
- import streamlit as st
- from PIL import Image
- import math
-
-
- colors = ["red", "blue", "green", "yellow", "orange", "violet"]
-
-
- def table_detector(image, THRESHOLD_PROBA):
-     '''
-     Table detection using DEtect-object TRansformer pre-trained on 1 million tables
-
-     '''
-
-     feature_extractor = DetrFeatureExtractor(do_resize=True, size=800, max_size=800)
-     encoding = feature_extractor(image, return_tensors="pt")
-     # encoding.keys()
-     model = DetrForObjectDetection.from_pretrained("SalML/DETR-table-detection")
-     # SalML\DETR-table-detection
-     with torch.no_grad():
-         outputs = model(**encoding)
-
-     # keep only predictions of queries with 0.9+ confidence (excluding no-object class)
-     probas = outputs.logits.softmax(-1)[0, :, :-1]
-     keep = probas.max(-1).values > THRESHOLD_PROBA
-
-     # rescale bounding boxes
-     target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
-     postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
-     bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
-
-     return (model, image, probas[keep], bboxes_scaled)
-
-
- def table_struct_recog(image, THRESHOLD_PROBA):
-     '''
-     Table structure recognition using DEtect-object TRansformer pre-trained on 1 million tables
-     '''
-
-     feature_extractor = DetrFeatureExtractor(do_resize=True, size=1000, max_size=1000)
-     encoding = feature_extractor(image, return_tensors="pt")
-
-     model = DetrForObjectDetection.from_pretrained("SalML/DETR-table-structure-recognition")
-     with torch.no_grad():
-         outputs = model(**encoding)
-
-     # keep only predictions of queries with 0.9+ confidence (excluding no-object class)
-     probas = outputs.logits.softmax(-1)[0, :, :-1]
-     keep = probas.max(-1).values > THRESHOLD_PROBA
-
-     # rescale bounding boxes
-     target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
-     postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
-     bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
-
-     return (model, image, probas[keep], bboxes_scaled)
-
- def add_margin(pil_img, top=20, right=20, bottom=20, left=20, color=(255,255,255)):
-     '''
-     Image padding as part of TSR pre-processing to prevent missing table edges
-     '''
-     width, height = pil_img.size
-     new_width = width + right + left
-     new_height = height + top + bottom
-     result = Image.new(pil_img.mode, (new_width, new_height), color)
-     result.paste(pil_img, (left, top))
-     return result
-
- def plot_results_detection(c1, model, pil_img, prob, boxes, show_only_cropped=False):
-     '''
-     Plots the full pillow pdf-page image and adds a rectangle patch for table detection
-     '''
-
-     plt.figure(figsize=(32,20))
-     plt.imshow(pil_img)
-     ax = plt.gca()
-
-     for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
-
-         cl = p.argmax()
-         xmin, ymin, xmax, ymax = xmin-3, ymin-3, xmax+3, ymax+3
-         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=colors[cl.item()], linewidth=3))
-         text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
-         ax.text(xmin, ymin, text, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))
-     plt.axis('off')
-     plt.show()
-     c1.pyplot()
-
-
- def plot_table_detection(c2, model, pil_img, prob, boxes):
-     '''
-     Plots only the cropped table(s) from the table detection
-     '''
-
-     plt.figure(figsize=(32,20))
-     ax = plt.gca()
-     cropped_img_list = []
-
-     for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
-
-         xmin, ymin, xmax, ymax = xmin-3, ymin-3, xmax+3, ymax+3
-         cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
-         cropped_img_list.append(cropped_img)
-
-     for cropped_img in cropped_img_list:
-         plt.imshow(cropped_img)
-
-     plt.axis('off')
-     plt.show()
-     c2.pyplot()
-     return cropped_img_list
-
-
- def plot_structure(c3, model, pil_img, prob, boxes, class_to_show=0):
-     '''
-     To plot table pillow image and the TSR bounding boxes on the table
-     '''
-     plt.figure(figsize=(32,20))
-     plt.imshow(pil_img)
-     ax = plt.gca()
-     rows = {}
-     cols = {}
-     header = {}
-     row_header = {}
-     idx = 0
-
-     for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
-
-         xmin, ymin, xmax, ymax = xmin-3, ymin-3, xmax+3, ymax+3
-         cl = p.argmax()
-         class_text = model.config.id2label[cl.item()]
-         text = f'{class_text}: {p[cl]:0.2f}'
-         # st.write(class_text)
-         if class_text != 'table':
-
-             ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=colors[cl.item()], linewidth=3))
-             ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
-
-         # if class_text == 'table column header':
-         #     header['header'] = (xmin, ymin, xmax, ymax)
-         if class_text == 'table row':
-             rows['table row '+str(idx)] = (xmin, ymin, xmax, ymax)
-         if class_text == 'table column':
-             cols['table column '+str(idx)] = (xmin, ymin, xmax, ymax)
-         # if class_text == 'table projected row header':
-         #     row_header['header table row'+str(idx)] = (xmin, ymin, xmax, ymax)
-
-         idx += 1
-
-     plt.show()
-     c3.pyplot()
-     # return header, row_header, rows, cols
-     return rows, cols
-
-
-
- def sort_table_features(header, row_header, rows, cols):
-     # Sometimes the header and first row overlap, and we need the header bbox not to have first row's bbox inside the headers bbox
-     y_header = header['header'][3] - 10
-     rows_ = {table_feature : (xmin, ymin, xmax, ymax) for table_feature, (xmin, ymin, xmax, ymax) in sorted(rows.items(), key=lambda tup: tup[1][1]) if ymin > y_header}
-     cols_ = {table_feature : (xmin, ymin, xmax, ymax) for table_feature, (xmin, ymin, xmax, ymax) in sorted(cols.items(), key=lambda tup: tup[1][0])}
-
-     row_header_ = {table_feature : (xmin, ymin, xmax, ymax) for table_feature, (xmin, ymin, xmax, ymax) in sorted(row_header.items(), key=lambda tup: tup[1][1])}
-
-     new_row = {}
-     idx = 0
-
-     for k1, v1 in rows_.items():
-         save_row = True
-         row_xmin, row_ymin, row_xmax, row_ymax = v1
-         for k2, v2 in row_header_.items():
-             header_row_xmin, header_row_ymin, header_row_xmax, header_row_ymax = v2
-             # table row and header table row are within 2 pixel range, skip saving the row
-             if math.isclose(row_ymin, header_row_ymin, abs_tol=2):
-                 save_row = False
-         if save_row:
-             new_row['table row.'+str(idx)] = (row_xmin, row_ymin, row_xmax, row_ymax)
-             idx += 1
-
-     new_row_ = {table_feature : (xmin, ymin, xmax, ymax) for table_feature, (xmin, ymin, xmax, ymax) in sorted(new_row.items(), key=lambda tup: tup[1][1])}
-
-     return row_header_, new_row_, cols_
-
-
- def sort_table_featuresv2(rows, cols):
-     # Sometimes the header and first row overlap, and we need the header bbox not to have first row's bbox inside the headers bbox
-     rows_ = {table_feature : (xmin, ymin, xmax, ymax) for table_feature, (xmin, ymin, xmax, ymax) in sorted(rows.items(), key=lambda tup: tup[1][1])}
-     cols_ = {table_feature : (xmin, ymin, xmax, ymax) for table_feature, (xmin, ymin, xmax, ymax) in sorted(cols.items(), key=lambda tup: tup[1][0])}
-
-     return rows_, cols_
-
- def individual_table_features(pil_img, header, row_header, rows, cols):
-
-     for k, v in header.items():
-         xmin, ymin, xmax, ymax = v
-         cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
-         header[k] = xmin, ymin, xmax, ymax, cropped_img
-
-     for k, v in row_header.items():
-         xmin, ymin, xmax, ymax = v
-         cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
-         row_header[k] = xmin, ymin, xmax, ymax, cropped_img
-
-     for k, v in rows.items():
-         xmin, ymin, xmax, ymax = v
-         cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
-         rows[k] = xmin, ymin, xmax, ymax, cropped_img
-
-
-     for k, v in cols.items():
-         xmin, ymin, xmax, ymax = v
-         cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
-         cols[k] = xmin, ymin, xmax, ymax, cropped_img
-
-     return header, row_header, rows, cols
-
- def individual_table_featuresv2(pil_img, rows, cols):
-
-
-     for k, v in rows.items():
-         xmin, ymin, xmax, ymax = v
-         cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
-         rows[k] = xmin, ymin, xmax, ymax, cropped_img
-
-
-     for k, v in cols.items():
-         xmin, ymin, xmax, ymax = v
-         cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
-         cols[k] = xmin, ymin, xmax, ymax, cropped_img
-
-     return rows, cols
-
- def plot_table_features(c2, header, row_header, rows, cols):
-
-     for k, v in header.items():
-         _, _, _, _, pil_img = v
-
-     for k, v in row_header.items():
-         _, _, _, _, pil_img = v
-
-     for k, v in rows.items():
-         _, _, _, _, pil_img = v
-
-     for k, v in cols.items():
-         _, _, _, _, pil_img = v
-
-
- def master_row_set(header, row_header, rows, cols):
-     master_row = {**header, **row_header, **rows}
-     master_row_ = {table_feature : (xmin, ymin, xmax, ymax, img) for table_feature, (xmin, ymin, xmax, ymax, img) in sorted(master_row.items(), key=lambda tup: tup[1][1])}
-
-     return master_row_
-
-
-
-
- def object_to_cells(master_row, cols):
-     '''
-     Iterates to every row, be it header/simple row/header table row, cuts rows into cells and saves images in dictionary where length of dictionary = total rows
-     '''
-     cells_img = {}
-     header_idx = 0
-     row_idx = 0
-     for k_row, v_row in master_row.items():
-
-         if k_row[:16] == 'header table row':
-
-             _, _, _, _, row_header_img = v_row
-             cells_img[k_row+'.'+str(row_idx)] = row_header_img
-             row_idx += 1
-
-         elif k_row == 'header':
-
-             _, ymin, _, ymax, header_img = v_row
-
-             xa, ya, xb, yb = 0, 0, 0, ymax-ymin
-             for k_col, v_col in cols.items():
-                 xmin_col, _, xmax_col, _, col_img = v_col
-                 xa = xmin_col-19
-                 xb = xmax_col-20
-
-                 header_img_cropped = header_img.crop((xa, ya, xb, yb))
-                 cells_img[k_row+'.'+str(header_idx)] = header_img_cropped
-                 header_idx += 1
-
-
-         elif k_row[:9] == 'table row':
-
-             xmin, ymin, xmax, ymax, row_img = v_row
-             xa, ya, xb, yb = 0, 0, 0, ymax-ymin
-             row_img_list = []
-             for k_col, v_col in cols.items():
-                 xmin_col, _, xmax_col, _, col_img = v_col
-                 xa = xmin_col-19
-                 xb = xmax_col-20
-                 row_img_cropped = row_img.crop((xa, ya, xb, yb))
-                 row_img_list.append(row_img_cropped)
-             cells_img[k_row+'.'+str(row_idx)] = row_img_list
-             row_idx += 1
-
-     return cells_img
-
-
- def object_to_cellsv2(master_row, cols):
-     '''
-     Iterates to every row, be it header/simple row/header table row, cuts rows into cells and saves images in dictionary where length of dictionary = total rows
-     '''
-     cells_img = {}
-     header_idx = 0
-     row_idx = 0
-     for k_row, v_row in master_row.items():
-
-         xmin, ymin, xmax, ymax, row_img = v_row
-         xa, ya, xb, yb = 0, 0, 0, ymax-ymin
-         row_img_list = []
-         for k_col, v_col in cols.items():
-             xmin_col, _, xmax_col, _, col_img = v_col
-             xa = xmin_col-19
-             xb = xmax_col-20
-             row_img_cropped = row_img.crop((xa, ya, xb, yb))
-             row_img_list.append(row_img_cropped)
-         cells_img[k_row+'.'+str(row_idx)] = row_img_list
-         row_idx += 1
-
-     return cells_img
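
For context, the sketch below shows roughly how the helpers in the deleted TDTSR.py were chained in the Streamlit app (detect tables → crop → recognise structure → sort rows/columns → cut cells). It is a minimal sketch, not part of this commit: the input file name "page.png", the 0.6 confidence thresholds, and the three-column Streamlit layout are illustrative assumptions.

# Hypothetical usage of the (now deleted) TDTSR.py helpers; file name and
# thresholds are assumptions for illustration only.
import streamlit as st
from PIL import Image
from TDTSR import (table_detector, plot_results_detection, plot_table_detection,
                   add_margin, table_struct_recog, plot_structure,
                   sort_table_featuresv2, individual_table_featuresv2,
                   object_to_cellsv2)

c1, c2, c3 = st.columns(3)
page = Image.open("page.png").convert("RGB")   # assumed input page image

# 1. Detect tables on the page, draw the detections, and crop them out.
model, page, probas, bboxes = table_detector(page, THRESHOLD_PROBA=0.6)
plot_results_detection(c1, model, page, probas, bboxes)
tables = plot_table_detection(c2, model, page, probas, bboxes)

for table in tables:
    padded = add_margin(table)   # pad so row/column edges are not clipped
    # 2. Recognise rows and columns inside the cropped table.
    model, padded, probas, bboxes = table_struct_recog(padded, THRESHOLD_PROBA=0.6)
    rows, cols = plot_structure(c3, model, padded, probas, bboxes)
    rows, cols = sort_table_featuresv2(rows, cols)
    rows, cols = individual_table_featuresv2(padded, rows, cols)
    # 3. Slice each row into one image per column ('table row N.M' keys).
    cells = object_to_cellsv2(rows, cols)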