NiamaLynn committed
Commit 27e9376
1 Parent(s): 695f273

Delete functions.py

Files changed (1)
  1. functions.py +0 -805
functions.py DELETED
@@ -1,805 +0,0 @@
- import os
- import gradio as gr
- import re
- import string
- import torch
-
- from operator import itemgetter
- import collections
-
- import pypdf
- from pypdf import PdfReader
- from pypdf.errors import PdfReadError
-
- import pdf2image
- from pdf2image import convert_from_path
- import langdetect
- from langdetect import detect_langs
-
- import pandas as pd
- import numpy as np
- import random
- import tempfile
- import itertools
-
- from matplotlib import font_manager
- from PIL import Image, ImageDraw, ImageFont
- import cv2
-
- # Tesseract
- print(os.popen('cat /etc/debian_version').read())
- print(os.popen('cat /etc/issue').read())
- print(os.popen('apt search tesseract').read())
- import pytesseract
-
- ## Key parameters
-
- # categories colors
- label2color = {
-     'Caption': 'brown',
-     'Footnote': 'orange',
-     'Formula': 'gray',
-     'List-item': 'yellow',
-     'Page-footer': 'red',
-     'Page-header': 'red',
-     'Picture': 'violet',
-     'Section-header': 'orange',
-     'Table': 'green',
-     'Text': 'blue',
-     'Title': 'pink'
- }
-
- # bounding boxes of the start and end tokens of a sequence
- cls_box = [0, 0, 0, 0]
- sep_box = cls_box
-
- # model
- from transformers import AutoTokenizer, AutoModelForTokenClassification
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- model_id = "NiamaLynn/lilt-roberta-DocLayNet-base_lines_ml256-v1"
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForTokenClassification.from_pretrained(model_id)
- model.to(device)
-
- # get labels
- id2label = model.config.id2label
- label2id = model.config.label2id
- num_labels = len(id2label)
-
- # (tokenization) the maximum length of a feature (sequence)
- if str(256) in model_id:
-     max_length = 256
- elif str(512) in model_id:
-     max_length = 512
- else:
-     print("Error with max_length of chunks!")
-
- # (tokenization) overlap
- doc_stride = 128  # the authorized overlap between two parts of the context when splitting is needed
-
- # max PDF page images that will be displayed
- max_imgboxes = 2
- examples_dir = 'files/'
- image_wo_content = examples_dir + "wo_content.png"  # image without content
- pdf_blank = examples_dir + "blank.pdf"  # blank PDF
- image_blank = examples_dir + "blank.png"  # blank image
-
- ## get langdetect2Tesseract dictionary
- t = "files/languages_tesseract.csv"
- l = "files/languages_iso.csv"
-
- df_t = pd.read_csv(t)
- df_l = pd.read_csv(l)
-
- langs_t = df_t["Language"].to_list()
- langs_t = [lang_t.lower().strip().translate(str.maketrans('', '', string.punctuation)) for lang_t in langs_t]
- langs_l = df_l["Language"].to_list()
- langs_l = [lang_l.lower().strip().translate(str.maketrans('', '', string.punctuation)) for lang_l in langs_l]
- langscode_t = df_t["LangCode"].to_list()
- langscode_l = df_l["LangCode"].to_list()
-
- Tesseract2langdetect, langdetect2Tesseract = dict(), dict()
- for lang_t, langcode_t in zip(langs_t, langscode_t):
-     try:
-         if lang_t == "Chinese - Simplified".lower().strip().translate(str.maketrans('', '', string.punctuation)): lang_t = "chinese"
-         index = langs_l.index(lang_t)
-         langcode_l = langscode_l[index]
-         Tesseract2langdetect[langcode_t] = langcode_l
-     except:
-         continue
-
- langdetect2Tesseract = {v: k for k, v in Tesseract2langdetect.items()}
-
- ## General
-
- # get text and bounding boxes from an image
- # https://stackoverflow.com/questions/61347755/how-can-i-get-line-coordinates-that-readed-by-tesseract
- # https://medium.com/geekculture/tesseract-ocr-understanding-the-contents-of-documents-beyond-their-text-a98704b7c655
- def get_data(results, factor, conf_min=0):
-
-     data = {}
-     for i in range(len(results['line_num'])):
-         level = results['level'][i]
-         block_num = results['block_num'][i]
-         par_num = results['par_num'][i]
-         line_num = results['line_num'][i]
-         top, left = results['top'][i], results['left'][i]
-         width, height = results['width'][i], results['height'][i]
-         conf = results['conf'][i]
-         text = results['text'][i]
-         if not (text == '' or text.isspace()):
-             if conf >= conf_min:
-                 tup = (text, left, top, width, height)
-                 if block_num in list(data.keys()):
-                     if par_num in list(data[block_num].keys()):
-                         if line_num in list(data[block_num][par_num].keys()):
-                             data[block_num][par_num][line_num].append(tup)
-                         else:
-                             data[block_num][par_num][line_num] = [tup]
-                     else:
-                         data[block_num][par_num] = {}
-                         data[block_num][par_num][line_num] = [tup]
-                 else:
-                     data[block_num] = {}
-                     data[block_num][par_num] = {}
-                     data[block_num][par_num][line_num] = [tup]
-
-     # get paragraphs dictionary with list of lines
-     par_data = {}
-     par_idx = 1
-     for _, b in data.items():
-         for _, p in b.items():
-             line_data = {}
-             line_idx = 1
-             for _, l in p.items():
-                 line_data[line_idx] = l
-                 line_idx += 1
-             par_data[par_idx] = line_data
-             par_idx += 1
-
-     # get lines of texts, grouped by paragraph
-     lines = list()
-     row_indexes = list()
-     row_index = 0
-     for _, par in par_data.items():
-         count_lines = 0
-         for _, line in par.items():
-             if count_lines == 0: row_indexes.append(row_index)
-             line_text = ' '.join([item[0] for item in line])
-             lines.append(line_text)
-             count_lines += 1
-             row_index += 1
-         # lines.append("\n")
-         row_index += 1
-     # lines = lines[:-1]
-
-     # get paragraphs boxes (par_boxes)
-     # get lines boxes (line_boxes)
-     par_boxes = list()
-     par_idx = 1
-     line_boxes = list()
-     line_idx = 1
-     for _, par in par_data.items():
-         xmins, ymins, xmaxs, ymaxs = list(), list(), list(), list()
-         for _, line in par.items():
-             xmin, ymin = line[0][1], line[0][2]
-             xmax, ymax = (line[-1][1] + line[-1][3]), (line[-1][2] + line[-1][4])
-             line_boxes.append([int(xmin/factor), int(ymin/factor), int(xmax/factor), int(ymax/factor)])
-             xmins.append(xmin)
-             ymins.append(ymin)
-             xmaxs.append(xmax)
-             ymaxs.append(ymax)
-             line_idx += 1
-         xmin, ymin, xmax, ymax = min(xmins), min(ymins), max(xmaxs), max(ymaxs)
-         par_boxes.append([int(xmin/factor), int(ymin/factor), int(xmax/factor), int(ymax/factor)])
-         par_idx += 1
-
-     return lines, row_indexes, par_boxes, line_boxes  # data, par_data
-
- # rescale image to get 300dpi
- def set_image_dpi_resize(image):
-     """
-     Rescaling image to 300dpi while resizing
-     :param image: An image
-     :return: A rescaled image
-     """
-     length_x, width_y = image.size
-     factor = min(1, float(1024.0 / length_x))
-     size = int(factor * length_x), int(factor * width_y)
-     image_resize = image.resize(size, Image.Resampling.LANCZOS)
-     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='1.png')
-     temp_filename = temp_file.name
-     image_resize.save(temp_filename, dpi=(300, 300))
-     return factor, temp_filename
-
- # it is important that each bounding box should be in (upper left, lower right) format
- # source: https://github.com/NielsRogge/Transformers-Tutorials/issues/129
- def upperleft_to_lowerright(bbox):
-     x0, y0, x1, y1 = tuple(bbox)
-     if bbox[2] < bbox[0]:
-         x0 = bbox[2]
-         x1 = bbox[0]
-     if bbox[3] < bbox[1]:
-         y0 = bbox[3]
-         y1 = bbox[1]
-     return [x0, y0, x1, y1]
-
- # convert bounding boxes from (left, top, width, height) format to (left, top, left+width, top+height) format
- def convert_box(bbox):
-     x, y, w, h = tuple(bbox)  # the row comes in (left, top, width, height) format
-     return [x, y, x+w, y+h]  # we turn it into (left, top, left+width, top+height) to get the actual box
-
- # the LiLT model gets 1000x1000 pixels images
- def normalize_box(bbox, width, height):
-     return [
-         int(1000 * (bbox[0] / width)),
-         int(1000 * (bbox[1] / height)),
-         int(1000 * (bbox[2] / width)),
-         int(1000 * (bbox[3] / height)),
-     ]
-
- # the LiLT model gets 1000x1000 pixels images
- def denormalize_box(bbox, width, height):
-     return [
-         int(width * (bbox[0] / 1000)),
-         int(height * (bbox[1] / 1000)),
-         int(width * (bbox[2] / 1000)),
-         int(height * (bbox[3] / 1000)),
-     ]
-
- # get back original size
- def original_box(box, original_width, original_height, coco_width, coco_height):
-     return [
-         int(original_width * (box[0] / coco_width)),
-         int(original_height * (box[1] / coco_height)),
-         int(original_width * (box[2] / coco_width)),
-         int(original_height * (box[3] / coco_height)),
-     ]
-
- def get_blocks(bboxes_block, categories, texts):
-
-     # get list of unique block boxes
-     bbox_block_dict, bboxes_block_list, bbox_block_prec = dict(), list(), list()
-     for count_block, bbox_block in enumerate(bboxes_block):
-         if bbox_block != bbox_block_prec:
-             bbox_block_indexes = [i for i, bbox in enumerate(bboxes_block) if bbox == bbox_block]
-             bbox_block_dict[count_block] = bbox_block_indexes
-             bboxes_block_list.append(bbox_block)
-             bbox_block_prec = bbox_block
-
-     # get list of categories and texts by unique block boxes
-     category_block_list, text_block_list = list(), list()
-     for bbox_block in bboxes_block_list:
-         count_block = bboxes_block.index(bbox_block)
-         bbox_block_indexes = bbox_block_dict[count_block]
-         category_block = np.array(categories, dtype=object)[bbox_block_indexes].tolist()[0]
-         category_block_list.append(category_block)
-         text_block = np.array(texts, dtype=object)[bbox_block_indexes].tolist()
-         text_block = [text.replace("\n", "").strip() for text in text_block]
-         if id2label[category_block] == "Text" or id2label[category_block] == "Caption" or id2label[category_block] == "Footnote":
-             text_block = ' '.join(text_block)
-         else:
-             text_block = '\n'.join(text_block)
-         text_block_list.append(text_block)
-
-     return bboxes_block_list, category_block_list, text_block_list
-
- # function to sort bounding boxes
- def get_sorted_boxes(bboxes):
-
-     # sort by y from page top to bottom
-     sorted_bboxes = sorted(bboxes, key=itemgetter(1), reverse=False)
-     y_list = [bbox[1] for bbox in sorted_bboxes]
-
-     # sort by x from page left to right when boxes have the same y
-     if len(list(set(y_list))) != len(y_list):
-         y_list_duplicates_indexes = dict()
-         y_list_duplicates = [item for item, count in collections.Counter(y_list).items() if count > 1]
-         for item in y_list_duplicates:
-             y_list_duplicates_indexes[item] = [i for i, e in enumerate(y_list) if e == item]
-             bbox_list_y_duplicates = sorted(np.array(sorted_bboxes, dtype=object)[y_list_duplicates_indexes[item]].tolist(), key=itemgetter(0), reverse=False)
-             np_array_bboxes = np.array(sorted_bboxes)
-             np_array_bboxes[y_list_duplicates_indexes[item]] = np.array(bbox_list_y_duplicates)
-             sorted_bboxes = np_array_bboxes.tolist()
-
-     return sorted_bboxes
-
- # sort data from y = 0 to end of page (and after, x = 0 to end of page when necessary)
- def sort_data(bboxes, categories, texts):
-
-     sorted_bboxes = get_sorted_boxes(bboxes)
-     sorted_bboxes_indexes = [bboxes.index(bbox) for bbox in sorted_bboxes]
-     sorted_categories = np.array(categories, dtype=object)[sorted_bboxes_indexes].tolist()
-     sorted_texts = np.array(texts, dtype=object)[sorted_bboxes_indexes].tolist()
-
-     return sorted_bboxes, sorted_categories, sorted_texts
-
- # sort data from y = 0 to end of page (and after, x = 0 to end of page when necessary)
- def sort_data_wo_labels(bboxes, texts):
-
-     sorted_bboxes = get_sorted_boxes(bboxes)
-     sorted_bboxes_indexes = [bboxes.index(bbox) for bbox in sorted_bboxes]
-     sorted_texts = np.array(texts, dtype=object)[sorted_bboxes_indexes].tolist()
-
-     return sorted_bboxes, sorted_texts
-
- ## PDF processing
-
- # get filename and images of PDF pages
- def pdf_to_images(uploaded_pdf):
-
-     # check if None object
-     if uploaded_pdf is None:
-         path_to_file = pdf_blank
-         filename = path_to_file.replace(examples_dir, "")
-         msg = "Invalid PDF file."
-         images = [Image.open(image_blank)]
-     else:
-         # path to the uploaded PDF
-         path_to_file = uploaded_pdf.name
-         filename = path_to_file.replace("/tmp/", "")
-
-         try:
-             PdfReader(path_to_file)
-         except PdfReadError:
-             path_to_file = pdf_blank
-             filename = path_to_file.replace(examples_dir, "")
-             msg = "Invalid PDF file."
-             images = [Image.open(image_blank)]
-         else:
-             try:
-                 images = convert_from_path(path_to_file, last_page=max_imgboxes)
-                 num_imgs = len(images)
-                 msg = f'The PDF "{filename}" was converted into {num_imgs} images.'
-             except:
-                 msg = f'Error with the PDF "{filename}": it was not converted into images.'
-                 images = [Image.open(image_wo_content)]
-
-     return filename, msg, images
-
- # extraction of image data (text and bounding boxes)
- def extraction_data_from_image(images):
-
-     num_imgs = len(images)
-
-     if num_imgs > 0:
-
-         # https://pyimagesearch.com/2021/11/15/tesseract-page-segmentation-modes-psms-explained-how-to-improve-your-ocr-accuracy/
-         custom_config = r'--oem 3 --psm 3 -l eng'  # default config PyTesseract: --oem 3 --psm 3 -l eng+deu+fra+jpn+por+spa+rus+hin+chi_sim
-         results, lines, row_indexes, par_boxes, line_boxes = dict(), dict(), dict(), dict(), dict()
-         images_ids_list, lines_list, par_boxes_list, line_boxes_list, images_list, page_no_list, num_pages_list = list(), list(), list(), list(), list(), list(), list()
-
-         try:
-             for i, image in enumerate(images):
-                 # image preprocessing
-                 # https://docs.opencv.org/3.0-beta/doc/py_tutorials/py_imgproc/py_thresholding/py_thresholding.html
-                 img = image.copy()
-                 factor, path_to_img = set_image_dpi_resize(img)  # rescaling to 300dpi
-                 img = Image.open(path_to_img)
-                 img = np.array(img, dtype='uint8')  # convert PIL to cv2
-                 img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # gray scale image
-                 ret, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
-
-                 # OCR PyTesseract | get langs of page
-                 txt = pytesseract.image_to_string(img, config=custom_config)
-                 txt = txt.strip().lower()
-                 txt = re.sub(r" +", " ", txt)  # multiple spaces
-                 txt = re.sub(r"(\n\s*)+\n+", "\n", txt)  # multiple lines
-                 # txt = os.popen(f'tesseract {img_filepath} - {custom_config}').read()
-                 try:
-                     langs = detect_langs(txt)
-                     langs = [langdetect2Tesseract[langs[i].lang] for i in range(len(langs))]
-                     langs_string = '+'.join(langs)
-                 except:
-                     langs_string = "eng"
-                 langs_string += '+osd'
-                 custom_config = f'--oem 3 --psm 3 -l {langs_string}'  # default config PyTesseract: --oem 3 --psm 3
-
-                 # OCR PyTesseract | get data
-                 results[i] = pytesseract.image_to_data(img, config=custom_config, output_type=pytesseract.Output.DICT)
-                 # results[i] = os.popen(f'tesseract {img_filepath} - {custom_config}').read()
-
-                 lines[i], row_indexes[i], par_boxes[i], line_boxes[i] = get_data(results[i], factor, conf_min=0)
-                 lines_list.append(lines[i])
-                 par_boxes_list.append(par_boxes[i])
-                 line_boxes_list.append(line_boxes[i])
-                 images_ids_list.append(i)
-                 images_list.append(images[i])
-                 page_no_list.append(i)
-                 num_pages_list.append(num_imgs)
-
-         except:
-             print("There was an error within the extraction of PDF text by the OCR!")
-         else:
-             from datasets import Dataset
-             dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts": lines_list, "bboxes_line": line_boxes_list})
-
-         # print(f"The text data was successfully extracted by the OCR!")
-
-         return dataset, lines, row_indexes, par_boxes, line_boxes
-
- ## Inference
-
- def prepare_inference_features(example, cls_box=cls_box, sep_box=sep_box):
-
-     images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list = list(), list(), list(), list(), list()
-
-     # get batch
-     batch_images_ids = example["images_ids"]
-     batch_images = example["images"]
-     batch_bboxes_line = example["bboxes_line"]
-     batch_texts = example["texts"]
-     batch_images_size = [image.size for image in batch_images]
-
-     batch_width, batch_height = [image_size[0] for image_size in batch_images_size], [image_size[1] for image_size in batch_images_size]
-
-     # add a dimension if not a batch but only one image
-     if not isinstance(batch_images_ids, list):
-         batch_images_ids = [batch_images_ids]
-         batch_images = [batch_images]
-         batch_bboxes_line = [batch_bboxes_line]
-         batch_texts = [batch_texts]
-         batch_width, batch_height = [batch_width], [batch_height]
-
-     # process all images of the batch
-     for num_batch, (image_id, boxes, texts, width, height) in enumerate(zip(batch_images_ids, batch_bboxes_line, batch_texts, batch_width, batch_height)):
-         tokens_list = []
-         bboxes_list = []
-
-         # add a dimension if only one image
-         if not isinstance(texts, list):
-             texts, boxes = [texts], [boxes]
-
-         # normalize the line bounding boxes (upper left, lower right, 0-1000 scale)
-         normalize_bboxes_line = [normalize_box(upperleft_to_lowerright(box), width, height) for box in boxes]
-
-         # sort boxes with texts
-         # we want sorted lists from top to bottom of the image
-         boxes, texts = sort_data_wo_labels(normalize_bboxes_line, texts)
-
-         count = 0
-         for box, text in zip(boxes, texts):
-             tokens = tokenizer.tokenize(text)
-             num_tokens = len(tokens)  # get number of tokens
-             tokens_list.extend(tokens)
-
-             bboxes_list.extend([box] * num_tokens)  # number of boxes must be the same as the number of tokens
-
-         # use of return_overflowing_tokens=True / stride=doc_stride
-         # to get parts of image with overlap
-         # source: https://huggingface.co/course/chapter6/3b?fw=tf#handling-long-contexts
-         encodings = tokenizer(" ".join(texts),
-                               truncation=True,
-                               padding="max_length",
-                               max_length=max_length,
-                               stride=doc_stride,
-                               return_overflowing_tokens=True,
-                               return_offsets_mapping=True
-                               )
-
-         otsm = encodings.pop("overflow_to_sample_mapping")
-         offset_mapping = encodings.pop("offset_mapping")
-
-         # let's label those examples and get their boxes
-         sequence_length_prev = 0
-         for i, offsets in enumerate(offset_mapping):
-             # truncate tokens, boxes and labels based on length of chunk - 2 (special tokens <s> and </s>)
-             sequence_length = len(encodings.input_ids[i]) - 2
-             if i == 0: start = 0
-             else: start += sequence_length_prev - doc_stride
-             end = start + sequence_length
-             sequence_length_prev = sequence_length
-
-             # get tokens, boxes and labels of this image chunk
-             bb = [cls_box] + bboxes_list[start:end] + [sep_box]
-
-             # as the last chunk can have a length < max_length
-             # we must add [tokenizer.pad_token] (tokens), [sep_box] (boxes) and [-100] (labels)
-             if len(bb) < max_length:
-                 bb = bb + [sep_box] * (max_length - len(bb))
-
-             # append results
-             input_ids_list.append(encodings["input_ids"][i])
-             attention_mask_list.append(encodings["attention_mask"][i])
-             bb_list.append(bb)
-             images_ids_list.append(image_id)
-             chunks_ids_list.append(i)
-
-     return {
-         "images_ids": images_ids_list,
-         "chunk_ids": chunks_ids_list,
-         "input_ids": input_ids_list,
-         "attention_mask": attention_mask_list,
-         "normalized_bboxes": bb_list,
-     }
-
- from torch.utils.data import Dataset
-
- class CustomDataset(Dataset):
-     def __init__(self, dataset, tokenizer):
-         self.dataset = dataset
-         self.tokenizer = tokenizer
-
-     def __len__(self):
-         return len(self.dataset)
-
-     def __getitem__(self, idx):
-         # get item
-         example = self.dataset[idx]
-         encoding = dict()
-         encoding["images_ids"] = example["images_ids"]
-         encoding["chunk_ids"] = example["chunk_ids"]
-         encoding["input_ids"] = example["input_ids"]
-         encoding["attention_mask"] = example["attention_mask"]
-         encoding["bbox"] = example["normalized_bboxes"]
-
-         return encoding
-
- import torch.nn.functional as F
-
- # get predictions at token level
- def predictions_token_level(images, custom_encoded_dataset):
-
-     num_imgs = len(images)
-     if num_imgs > 0:
-
-         chunk_ids, input_ids, bboxes, outputs, token_predictions = dict(), dict(), dict(), dict(), dict()
-         images_ids_list = list()
-
-         for i, encoding in enumerate(custom_encoded_dataset):
-
-             # get custom encoded data
-             image_id = encoding['images_ids']
-             chunk_id = encoding['chunk_ids']
-             input_id = torch.tensor(encoding['input_ids'])[None]
-             attention_mask = torch.tensor(encoding['attention_mask'])[None]
-             bbox = torch.tensor(encoding['bbox'])[None]
-
-             # save data in dictionaries
-             if image_id not in images_ids_list: images_ids_list.append(image_id)
-
-             if image_id in chunk_ids: chunk_ids[image_id].append(chunk_id)
-             else: chunk_ids[image_id] = [chunk_id]
-
-             if image_id in input_ids: input_ids[image_id].append(input_id)
-             else: input_ids[image_id] = [input_id]
-
-             if image_id in bboxes: bboxes[image_id].append(bbox)
-             else: bboxes[image_id] = [bbox]
-
-             # get prediction with forward pass
-             with torch.no_grad():
-                 output = model(
-                     input_ids=input_id,
-                     attention_mask=attention_mask,
-                     bbox=bbox
-                 )
-
-             # save probabilities of predictions in dictionary
-             if image_id in outputs: outputs[image_id].append(F.softmax(output.logits.squeeze(), dim=-1))
-             else: outputs[image_id] = [F.softmax(output.logits.squeeze(), dim=-1)]
-
-         return outputs, images_ids_list, chunk_ids, input_ids, bboxes
-
-     else:
-         print("An error occurred while getting predictions!")
-
- from functools import reduce
-
- # Get predictions (line level)
- def predictions_line_level(dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes):
-
-     ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
-     bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
-
-     if len(images_ids_list) > 0:
-
-         for i, image_id in enumerate(images_ids_list):
-
-             # get image information
-             images_list = dataset.filter(lambda example: example["images_ids"] == image_id)["images"]
-             image = images_list[0]
-             width, height = image.size
-
-             # get data
-             chunk_ids_list = chunk_ids[image_id]
-             outputs_list = outputs[image_id]
-             input_ids_list = input_ids[image_id]
-             bboxes_list = bboxes[image_id]
-
-             # create zeros tensors
-             ten_probs = torch.zeros((outputs_list[0].shape[0] - 2)*len(outputs_list), outputs_list[0].shape[1])
-             ten_input_ids = torch.ones(size=(1, (outputs_list[0].shape[0] - 2)*len(outputs_list)), dtype=int)
-             ten_bboxes = torch.zeros(size=(1, (outputs_list[0].shape[0] - 2)*len(outputs_list), 4), dtype=int)
-
-             if len(outputs_list) > 1:
-
-                 for num_output, (output, input_id, bbox) in enumerate(zip(outputs_list, input_ids_list, bboxes_list)):
-                     start = num_output*(max_length - 2) - max(0, num_output)*doc_stride
-                     end = start + (max_length - 2)
-
-                     if num_output == 0:
-                         ten_probs[start:end, :] += output[1:-1]
-                         ten_input_ids[:, start:end] = input_id[:, 1:-1]
-                         ten_bboxes[:, start:end, :] = bbox[:, 1:-1, :]
-                     else:
-                         ten_probs[start:start + doc_stride, :] += output[1:1 + doc_stride]
-                         ten_probs[start:start + doc_stride, :] = ten_probs[start:start + doc_stride, :] * 0.5
-                         ten_probs[start + doc_stride:end, :] += output[1 + doc_stride:-1]
-
-                         ten_input_ids[:, start:start + doc_stride] = input_id[:, 1:1 + doc_stride]
-                         ten_input_ids[:, start + doc_stride:end] = input_id[:, 1 + doc_stride:-1]
-
-                         ten_bboxes[:, start:start + doc_stride, :] = bbox[:, 1:1 + doc_stride, :]
-                         ten_bboxes[:, start + doc_stride:end, :] = bbox[:, 1 + doc_stride:-1, :]
-
-             else:
-                 ten_probs += outputs_list[0][1:-1]
-                 ten_input_ids = input_ids_list[0][:, 1:-1]
-                 ten_bboxes = bboxes_list[0][:, 1:-1]
-
-             ten_probs_list, ten_input_ids_list, ten_bboxes_list = ten_probs.tolist(), ten_input_ids.tolist()[0], ten_bboxes.tolist()[0]
-             bboxes_list = list()
-             input_ids_dict, probs_dict = dict(), dict()
-             bbox_prev = [-100, -100, -100, -100]
-             for probs, input_id, bbox in zip(ten_probs_list, ten_input_ids_list, ten_bboxes_list):
-                 bbox = denormalize_box(bbox, width, height)
-                 if bbox != bbox_prev and bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
-                     bboxes_list.append(bbox)
-                     input_ids_dict[str(bbox)] = [input_id]
-                     probs_dict[str(bbox)] = [probs]
-                 elif bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
-                     input_ids_dict[str(bbox)].append(input_id)
-                     probs_dict[str(bbox)].append(probs)
-                 bbox_prev = bbox
-
-             probs_bbox = dict()
-             for i, bbox in enumerate(bboxes_list):
-                 probs = probs_dict[str(bbox)]
-                 probs = np.array(probs).T.tolist()
-
-                 probs_label = list()
-                 for probs_list in probs:
-                     prob_label = reduce(lambda x, y: x*y, probs_list)
-                     prob_label = prob_label**(1./(len(probs_list)))  # normalization
-                     probs_label.append(prob_label)
-                 max_value = max(probs_label)
-                 max_index = probs_label.index(max_value)
-                 probs_bbox[str(bbox)] = max_index
-
-             bboxes_list_dict[image_id] = bboxes_list
-             input_ids_dict_dict[image_id] = input_ids_dict
-             probs_dict_dict[image_id] = probs_bbox
-
-             df[image_id] = pd.DataFrame()
-             df[image_id]["bboxes"] = bboxes_list
-             df[image_id]["texts"] = [tokenizer.decode(input_ids_dict[str(bbox)]) for bbox in bboxes_list]
-             df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in bboxes_list]
-
-         return probs_bbox, bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df
-
-     else:
-         print("An error occurred while getting predictions!")
-
- # Get labeled images with lines bounding boxes
- def get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
-
-     labeled_images = list()
-
-     for i, image_id in enumerate(images_ids_list):
-
-         # get image
-         images_list = dataset.filter(lambda example: example["images_ids"] == image_id)["images"]
-         image = images_list[0]
-         width, height = image.size
-
-         # get predicted boxes and labels
-         bboxes_list = bboxes_list_dict[image_id]
-         probs_bbox = probs_dict_dict[image_id]
-
-         draw = ImageDraw.Draw(image)
-         # https://stackoverflow.com/questions/66274858/choosing-a-pil-imagefont-by-font-name-rather-than-filename-and-cross-platform-f
-         font = font_manager.FontProperties(family='sans-serif', weight='bold')
-         font_file = font_manager.findfont(font)
-         font_size = 30
-         font = ImageFont.truetype(font_file, font_size)
-
-         for bbox in bboxes_list:
-             predicted_label = id2label[probs_bbox[str(bbox)]]
-             draw.rectangle(bbox, outline=label2color[predicted_label])
-             draw.text((bbox[0] + 10, bbox[1] - font_size), text=predicted_label, fill=label2color[predicted_label], font=font)
-
-         labeled_images.append(image)
-
-     return labeled_images
-
- # get data of encoded chunk
- def get_encoded_chunk_inference(index_chunk=None):
-
-     # get datasets
-     example = dataset
-     encoded_example = encoded_dataset
-
-     # get randomly a document in dataset
-     if index_chunk is None: index_chunk = random.randint(0, len(encoded_example)-1)
-     encoded_example = encoded_example[index_chunk]
-     encoded_image_ids = encoded_example["images_ids"]
-
-     # get the image
-     example = example.filter(lambda example: example["images_ids"] == encoded_image_ids)[0]
-     image = example["images"]  # original image
-     width, height = image.size
-     page_no = example["page_no"]
-     num_pages = example["num_pages"]
-
-     # get boxes, texts, categories
-     bboxes, input_ids = encoded_example["normalized_bboxes"][1:-1], encoded_example["input_ids"][1:-1]
-     bboxes = [denormalize_box(bbox, width, height) for bbox in bboxes]
-     num_tokens = len(input_ids) + 2
-
-     # get unique bboxes and corresponding labels
-     bboxes_list, input_ids_list = list(), list()
-     input_ids_dict = dict()
-     bbox_prev = [-100, -100, -100, -100]
-     for i, (bbox, input_id) in enumerate(zip(bboxes, input_ids)):
-         if bbox != bbox_prev:
-             bboxes_list.append(bbox)
-             input_ids_dict[str(bbox)] = [input_id]
-         else:
-             input_ids_dict[str(bbox)].append(input_id)
-
-         # start_indexes_list.append(i)
-         bbox_prev = bbox
-
-     # do not keep "</s><pad><pad>..."
-     if input_ids_dict[str(bboxes_list[-1])][0] == (tokenizer.convert_tokens_to_ids('</s>')):
-         del input_ids_dict[str(bboxes_list[-1])]
-         bboxes_list = bboxes_list[:-1]
-
-     # get texts by line
-     input_ids_list = input_ids_dict.values()
-     texts_list = [tokenizer.decode(input_ids) for input_ids in input_ids_list]
-
-     # display DataFrame
-     df = pd.DataFrame({"texts": texts_list, "input_ids": input_ids_list, "bboxes": bboxes_list})
-
-     return image, df, num_tokens, page_no, num_pages
-
- # display chunk of PDF image and its data
- def display_chunk_lines_inference(index_chunk=None):
-
-     # get image and image data
-     image, df, num_tokens, page_no, num_pages = get_encoded_chunk_inference(index_chunk=index_chunk)
-
-     # get data from dataframe
-     input_ids = df["input_ids"]
-     texts = df["texts"]
-     bboxes = df["bboxes"]
-
-     print(f'Chunk ({num_tokens} tokens) of the PDF (page: {page_no+1} / {num_pages})\n')
-
-     # display image with bounding boxes
-     print(">> PDF image with bounding boxes of lines\n")
-     draw = ImageDraw.Draw(image)
-
-     labels = list()
-     for box, text in zip(bboxes, texts):
-         color = "red"
-         draw.rectangle(box, outline=color)
-
-     # resize image to original
-     width, height = image.size
-     image = image.resize((int(0.5*width), int(0.5*height)))
-
-     # convert to cv and display
-     img = np.array(image, dtype='uint8')  # PIL to cv2
-     cv2_imshow(img)
-     cv2.waitKey(0)
-
-     # display image dataframe
-     print("\n>> Dataframe of annotated lines\n")
-     cols = ["texts", "bboxes"]
-     df = df[cols]
-     display(df)
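
For reference, a minimal sketch of how the helpers in this deleted file appear designed to be chained, from PDF upload to labeled page images. The `UploadedPDF` stub and the `dataset.map` arguments are illustrative assumptions, not taken from this commit; only the function names and signatures come from the file above.

# hypothetical driver, assuming the deleted functions.py above is importable and files/blank.pdf exists
from functions import (pdf_to_images, extraction_data_from_image, prepare_inference_features,
                       CustomDataset, predictions_token_level, predictions_line_level,
                       get_labeled_images, tokenizer)

class UploadedPDF:
    # minimal stand-in for the Gradio file object expected by pdf_to_images (only .name is read)
    name = "files/blank.pdf"

# PDF -> page images -> OCR lines and boxes
filename, msg, images = pdf_to_images(UploadedPDF())
dataset, lines, row_indexes, par_boxes, line_boxes = extraction_data_from_image(images)

# tokenize into overlapping chunks and wrap for inference (mapping arguments are assumptions)
encoded_dataset = dataset.map(prepare_inference_features, batched=True,
                              remove_columns=dataset.column_names)
custom_encoded_dataset = CustomDataset(encoded_dataset, tokenizer)

# token-level probabilities -> line-level labels -> annotated page images
outputs, images_ids_list, chunk_ids, input_ids, bboxes = predictions_token_level(images, custom_encoded_dataset)
_, bboxes_list_dict, _, probs_dict_dict, df = predictions_line_level(dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes)
labeled_images = get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_dict)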