NiamaLynn committed on
Commit 9140d7b · 1 Parent(s): f5880a9

Delete files/functions.py

Files changed (1)
  1. files/functions.py +0 -784
files/functions.py DELETED
@@ -1,784 +0,0 @@
import os
import gradio as gr
import re
import string
import torch

from operator import itemgetter
import collections

import pypdf
from pypdf import PdfReader
from pypdf.errors import PdfReadError

import pdf2image
from pdf2image import convert_from_path
import langdetect
from langdetect import detect_langs

import pandas as pd
import numpy as np
import random
import tempfile
import itertools

from datasets import Dataset  # Hugging Face Dataset, used below to build the OCR dataset

from matplotlib import font_manager
from PIL import Image, ImageDraw, ImageFont
import cv2


# Tesseract
print(os.popen('cat /etc/debian_version').read())
print(os.popen('cat /etc/issue').read())
print(os.popen('apt search tesseract').read())
import pytesseract

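The shell commands above only print the distribution information and the apt search results; a quick way to confirm that pytesseract can actually reach the Tesseract binary is to query its version directly. A minimal standalone sketch (it assumes the tesseract binary is installed on the system):

import pytesseract

try:
    # get_tesseract_version() calls the tesseract binary and parses its version string
    print("Tesseract version:", pytesseract.get_tesseract_version())
except Exception as e:
    print(f"Tesseract binary not reachable: {e} (install it, e.g. apt-get install tesseract-ocr)")
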
## Key parameters

# category colors
label2color = {
    'Caption': 'brown',
    'Footnote': 'orange',
    'Formula': 'gray',
    'List-item': 'yellow',
    'Page-footer': 'red',
    'Page-header': 'red',
    'Picture': 'violet',
    'Section-header': 'orange',
    'Table': 'green',
    'Text': 'blue',
    'Title': 'pink'
}

# bounding boxes of the start and end tokens of a sequence
cls_box = [0, 0, 0, 0]
sep_box = cls_box

# model
from transformers import AutoTokenizer, AutoModelForTokenClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_id = "NiamaLynn/lilt-roberta-DocLayNet-base_lines_ml256-v1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(model_id)
model.to(device)

# get labels
id2label = model.config.id2label
label2id = model.config.label2id
num_labels = len(id2label)

# (tokenization) the maximum length of a feature (sequence), inferred from the model id
if "256" in model_id:
    max_length = 256
elif "512" in model_id:
    max_length = 512
else:
    raise ValueError("Could not infer the max_length of chunks from the model id!")

# (tokenization) overlap
doc_stride = 128  # the authorized overlap between two parts of the context when splitting is needed

# max PDF page images that will be displayed
max_imgboxes = 2
examples_dir = 'files/'
image_wo_content = examples_dir + "wo_content.png"  # image without content
pdf_blank = examples_dir + "blank.pdf"  # blank PDF
image_blank = examples_dir + "blank.png"  # blank image

## get langdetect2Tesseract dictionary
t = "files/languages_tesseract.csv"
l = "files/languages_iso.csv"

df_t = pd.read_csv(t)
df_l = pd.read_csv(l)

langs_t = df_t["Language"].to_list()
langs_t = [lang_t.lower().strip().translate(str.maketrans('', '', string.punctuation)) for lang_t in langs_t]
langs_l = df_l["Language"].to_list()
langs_l = [lang_l.lower().strip().translate(str.maketrans('', '', string.punctuation)) for lang_l in langs_l]
langscode_t = df_t["LangCode"].to_list()
langscode_l = df_l["LangCode"].to_list()

Tesseract2langdetect, langdetect2Tesseract = dict(), dict()
for lang_t, langcode_t in zip(langs_t, langscode_t):
    try:
        if lang_t == "Chinese - Simplified".lower().strip().translate(str.maketrans('', '', string.punctuation)): lang_t = "chinese"
        index = langs_l.index(lang_t)
        langcode_l = langscode_l[index]
        Tesseract2langdetect[langcode_t] = langcode_l
    except ValueError:  # language not found in the ISO list
        continue

langdetect2Tesseract = {v: k for k, v in Tesseract2langdetect.items()}

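This mapping is what lets the OCR step feed langdetect's output back into Tesseract's `-l` option. A minimal standalone sketch of that round trip, using an illustrative two-entry mapping instead of the CSV-built dictionaries above (only langdetect is assumed installed; no OCR is run):

from langdetect import detect_langs

# illustrative subset of the langdetect -> Tesseract mapping built above
langdetect2tesseract_demo = {"en": "eng", "fr": "fra"}

text = "Ceci est un document de test écrit en français."
try:
    detected = detect_langs(text)  # e.g. [fr:0.999...]
    tesseract_langs = [langdetect2tesseract_demo.get(d.lang, "eng") for d in detected]
    langs_string = "+".join(tesseract_langs) + "+osd"
except Exception:
    langs_string = "eng+osd"

custom_config = f"--oem 3 --psm 3 -l {langs_string}"
print(custom_config)  # typically: --oem 3 --psm 3 -l fra+osd
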
## General

# get text and bounding boxes from an image
# https://stackoverflow.com/questions/61347755/how-can-i-get-line-coordinates-that-readed-by-tesseract
# https://medium.com/geekculture/tesseract-ocr-understanding-the-contents-of-documents-beyond-their-text-a98704b7c655
def get_data(results, factor, conf_min=0):

    data = {}
    for i in range(len(results['line_num'])):
        level = results['level'][i]
        block_num = results['block_num'][i]
        par_num = results['par_num'][i]
        line_num = results['line_num'][i]
        top, left = results['top'][i], results['left'][i]
        width, height = results['width'][i], results['height'][i]
        conf = results['conf'][i]
        text = results['text'][i]

        if not (text == '' or text.isspace()):
            if conf >= conf_min:
                tup = (text, left, top, width, height)
                data.setdefault(block_num, {}).setdefault(par_num, {}).setdefault(line_num, []).append(tup)

    # get paragraphs dictionary with a list of lines
    par_data = {}
    par_idx = 1
    for _, b in data.items():
        for _, p in b.items():
            line_data = {}
            line_idx = 1
            for _, l in p.items():
                line_data[line_idx] = l
                line_idx += 1
            par_data[par_idx] = line_data
            par_idx += 1

    # get lines of texts, grouped by paragraph
    lines = list()
    row_indexes = list()
    row_index = 0
    for _, par in par_data.items():
        count_lines = 0
        for _, line in par.items():
            if count_lines == 0:
                row_indexes.append(row_index)
            line_text = ' '.join([item[0] for item in line])
            lines.append(line_text)
            count_lines += 1
            row_index += 1
        row_index += 1

    # get paragraphs boxes (par_boxes)
    # get lines boxes (line_boxes)
    par_boxes = list()
    par_idx = 1
    line_boxes = list()
    line_idx = 1
    for _, par in par_data.items():
        xmins, ymins, xmaxs, ymaxs = [], [], [], []
        for _, line in par.items():
            xmin, ymin = line[0][1], line[0][2]
            xmax, ymax = (line[-1][1] + line[-1][3]), (line[-1][2] + line[-1][4])
            line_boxes.append([int(xmin/factor), int(ymin/factor), int(xmax/factor), int(ymax/factor)])
            xmins.append(xmin)
            ymins.append(ymin)
            xmaxs.append(xmax)
            ymaxs.append(ymax)
            line_idx += 1
        xmin, ymin, xmax, ymax = min(xmins), min(ymins), max(xmaxs), max(ymaxs)
        par_boxes.append([int(xmin/factor), int(ymin/factor), int(xmax/factor), int(ymax/factor)])
        par_idx += 1

    return lines, row_indexes, par_boxes, line_boxes


# resize the image (OpenCV/NumPy array) so that its width does not exceed 1024 px;
# the resized copy is saved to a temporary PNG and the scale factor is returned
def set_image_dpi_resize(image):
    length_x, width_y = image.shape[1], image.shape[0]
    factor = min(1, float(1024.0 / length_x))
    size = (int(factor * length_x), int(factor * width_y))
    image_resize = cv2.resize(image, size, interpolation=cv2.INTER_LANCZOS4)
    _, temp_filename = tempfile.mkstemp(suffix='.png')
    cv2.imwrite(temp_filename, image_resize, [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
    return factor, temp_filename

# it is important that each bounding box is in (upper left, lower right) format
# source: https://github.com/NielsRogge/Transformers-Tutorials/issues/129
def upperleft_to_lowerright(bbox):
    x0, y0, x1, y1 = tuple(bbox)
    if bbox[2] < bbox[0]:
        x0 = bbox[2]
        x1 = bbox[0]
    if bbox[3] < bbox[1]:
        y0 = bbox[3]
        y1 = bbox[1]
    return [x0, y0, x1, y1]

# convert bounding boxes from (left, top, width, height) format to (left, top, left+width, top+height) format
def convert_box(bbox):
    x, y, w, h = tuple(bbox)  # the row comes in (left, top, width, height) format
    return [x, y, x+w, y+h]  # we turn it into (left, top, left+width, top+height) to get the actual box

# the LiLT model expects boxes normalized to a 1000x1000 pixels space
def normalize_box(bbox, width, height):
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]

# the LiLT model expects boxes normalized to a 1000x1000 pixels space
def denormalize_box(bbox, width, height):
    return [
        int(width * (bbox[0] / 1000)),
        int(height * (bbox[1] / 1000)),
        int(width * (bbox[2] / 1000)),
        int(height * (bbox[3] / 1000)),
    ]

# get back original size
def original_box(box, original_width, original_height, coco_width, coco_height):
    return [
        int(original_width * (box[0] / coco_width)),
        int(original_height * (box[1] / coco_height)),
        int(original_width * (box[2] / coco_width)),
        int(original_height * (box[3] / coco_height)),
    ]

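A small worked example of the normalization round trip (a standalone sketch; the arithmetic mirrors normalize_box and denormalize_box above and shows that the integer casts lose a few pixels on the way back):

# standalone illustration on a 1700x2200 px page
width, height = 1700, 2200
box = [85, 120, 510, 160]                                    # (x0, y0, x1, y1) in page pixels

normalized = [int(1000 * box[0] / width), int(1000 * box[1] / height),
              int(1000 * box[2] / width), int(1000 * box[3] / height)]
print(normalized)                                            # [50, 54, 300, 72]

restored = [int(width * normalized[0] / 1000), int(height * normalized[1] / 1000),
            int(width * normalized[2] / 1000), int(height * normalized[3] / 1000)]
print(restored)                                              # [85, 118, 510, 158] -- a few pixels off due to int()
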
def get_blocks(bboxes_block, categories, texts):

    # get list of unique block boxes
    bbox_block_dict, bboxes_block_list, bbox_block_prec = dict(), list(), list()
    for count_block, bbox_block in enumerate(bboxes_block):
        if bbox_block != bbox_block_prec:
            bbox_block_indexes = [i for i, bbox in enumerate(bboxes_block) if bbox == bbox_block]
            bbox_block_dict[count_block] = bbox_block_indexes
            bboxes_block_list.append(bbox_block)
            bbox_block_prec = bbox_block

    # get list of categories and texts by unique block boxes
    category_block_list, text_block_list = list(), list()
    for bbox_block in bboxes_block_list:
        count_block = bboxes_block.index(bbox_block)
        bbox_block_indexes = bbox_block_dict[count_block]
        category_block = np.array(categories, dtype=object)[bbox_block_indexes].tolist()[0]
        category_block_list.append(category_block)
        text_block = np.array(texts, dtype=object)[bbox_block_indexes].tolist()
        text_block = [text.replace("\n", "").strip() for text in text_block]
        if id2label[category_block] in ("Text", "Caption", "Footnote"):
            text_block = ' '.join(text_block)
        else:
            text_block = '\n'.join(text_block)
        text_block_list.append(text_block)

    return bboxes_block_list, category_block_list, text_block_list

# function to sort bounding boxes
def get_sorted_boxes(bboxes):

    # sort by y from page top to bottom
    sorted_bboxes = sorted(bboxes, key=itemgetter(1), reverse=False)
    y_list = [bbox[1] for bbox in sorted_bboxes]

    # sort by x from page left to right when boxes have the same y
    if len(list(set(y_list))) != len(y_list):
        y_list_duplicates_indexes = dict()
        y_list_duplicates = [item for item, count in collections.Counter(y_list).items() if count > 1]
        for item in y_list_duplicates:
            y_list_duplicates_indexes[item] = [i for i, e in enumerate(y_list) if e == item]
            bbox_list_y_duplicates = sorted(np.array(sorted_bboxes, dtype=object)[y_list_duplicates_indexes[item]].tolist(), key=itemgetter(0), reverse=False)
            np_array_bboxes = np.array(sorted_bboxes)
            np_array_bboxes[y_list_duplicates_indexes[item]] = np.array(bbox_list_y_duplicates)
            sorted_bboxes = np_array_bboxes.tolist()

    return sorted_bboxes

# sort data from y = 0 to the end of the page (and then from x = 0 to the end of the page when necessary)
def sort_data(bboxes, categories, texts):

    sorted_bboxes = get_sorted_boxes(bboxes)
    sorted_bboxes_indexes = [bboxes.index(bbox) for bbox in sorted_bboxes]
    sorted_categories = np.array(categories, dtype=object)[sorted_bboxes_indexes].tolist()
    sorted_texts = np.array(texts, dtype=object)[sorted_bboxes_indexes].tolist()

    return sorted_bboxes, sorted_categories, sorted_texts

# sort data from y = 0 to the end of the page (and then from x = 0 to the end of the page when necessary)
def sort_data_wo_labels(bboxes, texts):

    sorted_bboxes = get_sorted_boxes(bboxes)
    sorted_bboxes_indexes = [bboxes.index(bbox) for bbox in sorted_bboxes]
    sorted_texts = np.array(texts, dtype=object)[sorted_bboxes_indexes].tolist()

    return sorted_bboxes, sorted_texts

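In effect, the three functions above impose a reading order: top to bottom, and left to right among boxes that share the same y. A simplified standalone sketch of that ordering (a plain tuple sort, shown only to illustrate the intent; the real functions also reorder the matching texts and labels):

# toy boxes in (x0, y0, x1, y1) format: two share y0 = 10, one sits lower on the page
boxes = [[300, 10, 400, 30], [20, 10, 120, 30], [50, 200, 500, 230]]
reading_order = sorted(boxes, key=lambda b: (b[1], b[0]))  # sort by y first, then x
print(reading_order)
# [[20, 10, 120, 30], [300, 10, 400, 30], [50, 200, 500, 230]]
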
## PDF processing

# os, pypdf (PdfReader, PdfReadError) and pdf2image (convert_from_path) are already imported at the top of this file

def pdf_to_images(uploaded_pdf):

    if uploaded_pdf is None:
        path_to_file = pdf_blank
        filename = os.path.basename(path_to_file)
        msg = "Invalid PDF file."
        images = [cv2.imread(image_blank)]
    else:
        # path to the uploaded PDF
        path_to_file = uploaded_pdf.name
        filename = os.path.basename(path_to_file)

        try:
            PdfReader(path_to_file)
        except PdfReadError:
            path_to_file = pdf_blank
            filename = os.path.basename(path_to_file)
            msg = "Invalid PDF file."
            images = [cv2.imread(image_blank)]
        else:
            try:
                images = convert_from_path(path_to_file, last_page=max_imgboxes)
                num_imgs = len(images)
                msg = f'The PDF "{filename}" was converted into {num_imgs} images.'
            except Exception as e:
                msg = f'Error with the PDF "{filename}": it was not converted into images. Error: {e}'
                images = [cv2.imread(image_wo_content)]

    return filename, msg, images

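A minimal standalone sketch of the same validate-then-convert flow, outside of the Gradio upload object (pypdf, pdf2image and the poppler utilities are assumed installed; "sample.pdf" is a hypothetical path):

from pypdf import PdfReader
from pypdf.errors import PdfReadError
from pdf2image import convert_from_path

path = "sample.pdf"  # hypothetical input file
try:
    PdfReader(path)                               # cheap validity check before rasterizing
except PdfReadError:
    print(f"{path} is not a valid PDF.")
else:
    pages = convert_from_path(path, last_page=2)  # first two pages only, as PIL images
    print(f"Converted {len(pages)} page(s), first page size: {pages[0].size}")
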
def extraction_data_from_image(images):

    num_imgs = len(images)

    if num_imgs > 0:

        custom_config = r'--oem 3 --psm 3 -l eng'
        results, lines, row_indexes, par_boxes, line_boxes = dict(), dict(), dict(), dict(), dict()
        images_ids_list, lines_list, par_boxes_list, line_boxes_list, images_list, page_no_list, num_pages_list = list(), list(), list(), list(), list(), list(), list()

        try:
            for i, image in enumerate(images):
                # image preprocessing
                img = np.array(image)  # make sure we have a NumPy array (pdf2image returns PIL images)
                factor, path_to_img = set_image_dpi_resize(img)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                ret, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)

                # OCR PyTesseract | get langs of page
                txt = pytesseract.image_to_string(img, config=custom_config)
                txt = txt.strip().lower()
                txt = re.sub(r" +", " ", txt)
                txt = re.sub(r"(\n\s*)+\n+", "\n", txt)
                try:
                    langs = detect_langs(txt)
                    langs = [langdetect2Tesseract[langs[i].lang] for i in range(len(langs))]
                    langs_string = '+'.join(langs)
                except Exception:
                    langs_string = "eng"
                langs_string += '+osd'
                custom_config = f'--oem 3 --psm 3 -l {langs_string}'

                # OCR PyTesseract | get data
                results[i] = pytesseract.image_to_data(img, config=custom_config, output_type=pytesseract.Output.DICT)

                lines[i], row_indexes[i], par_boxes[i], line_boxes[i] = get_data(results[i], factor, conf_min=0)
                lines_list.append(lines[i])
                par_boxes_list.append(par_boxes[i])
                line_boxes_list.append(line_boxes[i])
                images_ids_list.append(i)
                images_list.append(images[i])
                page_no_list.append(i)
                num_pages_list.append(num_imgs)

        except Exception as e:
            print(f"There was an error within the extraction of PDF text by the OCR: {e}")
        else:
            dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts": lines_list, "bboxes_line": line_boxes_list})

        return dataset, lines, row_indexes, par_boxes, line_boxes

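get_data above consumes the dictionary produced by pytesseract's Output.DICT mode; a quick way to see the fields it relies on (level, block_num, par_num, line_num, left, top, width, height, conf, text) is to run image_to_data on a trivial image. A minimal sketch, assuming the Tesseract binary and its English data are installed:

import numpy as np
import pytesseract

# a small white image with no text: Tesseract still returns the per-word data structure
blank = np.full((200, 600), 255, dtype=np.uint8)
data = pytesseract.image_to_data(blank, config="--oem 3 --psm 3 -l eng",
                                 output_type=pytesseract.Output.DICT)
print(sorted(data.keys()))
# ['block_num', 'conf', 'height', 'left', 'level', 'line_num', 'page_num',
#  'par_num', 'text', 'top', 'width', 'word_num']
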
## Inference

def prepare_inference_features(example, cls_box=cls_box, sep_box=sep_box):

    images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list = list(), list(), list(), list(), list()

    # get batch
    batch_images_ids = example["images_ids"]
    batch_images = example["images"]
    batch_bboxes_line = example["bboxes_line"]
    batch_texts = example["texts"]
    batch_images_size = [image.size for image in batch_images]

    batch_width, batch_height = [image_size[0] for image_size in batch_images_size], [image_size[1] for image_size in batch_images_size]

    # add a dimension if not a batch but only one image
    if not isinstance(batch_images_ids, list):
        batch_images_ids = [batch_images_ids]
        batch_images = [batch_images]
        batch_bboxes_line = [batch_bboxes_line]
        batch_texts = [batch_texts]
        batch_width, batch_height = [batch_width], [batch_height]

    # process all images of the batch
    for num_batch, (image_id, boxes, texts, width, height) in enumerate(zip(batch_images_ids, batch_bboxes_line, batch_texts, batch_width, batch_height)):
        tokens_list = []
        bboxes_list = []

        # add a dimension if there is only one line
        if not isinstance(texts, list):
            texts, boxes = [texts], [boxes]

        # normalize the line boxes (after making sure they are in upper-left / lower-right format)
        normalize_bboxes_line = [normalize_box(upperleft_to_lowerright(box), width, height) for box in boxes]

        # sort boxes with texts
        # we want sorted lists from top to bottom of the image
        boxes, texts = sort_data_wo_labels(normalize_bboxes_line, texts)

        count = 0
        for box, text in zip(boxes, texts):
            tokens = tokenizer.tokenize(text)
            num_tokens = len(tokens)  # get number of tokens
            tokens_list.extend(tokens)

            bboxes_list.extend([box] * num_tokens)  # number of boxes must be the same as the number of tokens

        # use of return_overflowing_tokens=True / stride=doc_stride
        # to get parts of image with overlap
        # source: https://huggingface.co/course/chapter6/3b?fw=tf#handling-long-contexts
        encodings = tokenizer(" ".join(texts),
                              truncation=True,
                              padding="max_length",
                              max_length=max_length,
                              stride=doc_stride,
                              return_overflowing_tokens=True,
                              return_offsets_mapping=True
                              )

        otsm = encodings.pop("overflow_to_sample_mapping")
        offset_mapping = encodings.pop("offset_mapping")

        # let's label those examples and get their boxes
        sequence_length_prev = 0
        for i, offsets in enumerate(offset_mapping):
            # truncate tokens, boxes and labels based on length of chunk - 2 (special tokens <s> and </s>)
            sequence_length = len(encodings.input_ids[i]) - 2
            if i == 0: start = 0
            else: start += sequence_length_prev - doc_stride
            end = start + sequence_length
            sequence_length_prev = sequence_length

            # get tokens, boxes and labels of this image chunk
            bb = [cls_box] + bboxes_list[start:end] + [sep_box]

            # as the last chunk can have a length < max_length
            # we must add [tokenizer.pad_token] (tokens), [sep_box] (boxes) and [-100] (labels)
            if len(bb) < max_length:
                bb = bb + [sep_box] * (max_length - len(bb))

            # append results
            input_ids_list.append(encodings["input_ids"][i])
            attention_mask_list.append(encodings["attention_mask"][i])
            bb_list.append(bb)
            images_ids_list.append(image_id)
            chunks_ids_list.append(i)

    return {
        "images_ids": images_ids_list,
        "chunk_ids": chunks_ids_list,
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "normalized_bboxes": bb_list,
    }

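The tokenizer call above is what splits a long page into overlapping chunks: with max_length=256 and stride=128, each chunk repeats the last 128 tokens of the previous one, which is why predictions_line_level later averages the overlapped region. A standalone sketch of that behaviour (it assumes network access to download a RoBERTa-style tokenizer; "roberta-base" stands in for the model above):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-base")  # stand-in tokenizer for the sketch
long_text = " ".join(["word"] * 600)                 # long enough to overflow one 256-token chunk

enc = tok(long_text,
          truncation=True,
          padding="max_length",
          max_length=256,
          stride=128,
          return_overflowing_tokens=True,
          return_offsets_mapping=True)

print(len(enc["input_ids"]))                   # several overlapping chunks instead of one truncated sequence
print([len(ids) for ids in enc["input_ids"]])  # every chunk is padded to exactly 256 tokens
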
from torch.utils.data import Dataset as TorchDataset  # aliased so it does not shadow datasets.Dataset used above

class CustomDataset(TorchDataset):
    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # get item
        example = self.dataset[idx]
        encoding = dict()
        encoding["images_ids"] = example["images_ids"]
        encoding["chunk_ids"] = example["chunk_ids"]
        encoding["input_ids"] = example["input_ids"]
        encoding["attention_mask"] = example["attention_mask"]
        encoding["bbox"] = example["normalized_bboxes"]

        return encoding

import torch.nn.functional as F

# get predictions at token level
def predictions_token_level(images, custom_encoded_dataset):

    num_imgs = len(images)
    if num_imgs > 0:

        chunk_ids, input_ids, bboxes, outputs, token_predictions = dict(), dict(), dict(), dict(), dict()
        images_ids_list = list()

        for i, encoding in enumerate(custom_encoded_dataset):

            # get custom encoded data
            image_id = encoding['images_ids']
            chunk_id = encoding['chunk_ids']
            input_id = torch.tensor(encoding['input_ids'])[None]
            attention_mask = torch.tensor(encoding['attention_mask'])[None]
            bbox = torch.tensor(encoding['bbox'])[None]

            # save data in dictionaries
            if image_id not in images_ids_list: images_ids_list.append(image_id)

            if image_id in chunk_ids: chunk_ids[image_id].append(chunk_id)
            else: chunk_ids[image_id] = [chunk_id]

            if image_id in input_ids: input_ids[image_id].append(input_id)
            else: input_ids[image_id] = [input_id]

            if image_id in bboxes: bboxes[image_id].append(bbox)
            else: bboxes[image_id] = [bbox]

            # get prediction with forward pass (tensors are moved to the same device as the model)
            with torch.no_grad():
                output = model(
                    input_ids=input_id.to(device),
                    attention_mask=attention_mask.to(device),
                    bbox=bbox.to(device)
                )

            # save probabilities of predictions in dictionary (moved back to CPU for the later aggregation)
            probs = F.softmax(output.logits.squeeze(), dim=-1).cpu()
            if image_id in outputs: outputs[image_id].append(probs)
            else: outputs[image_id] = [probs]

        return outputs, images_ids_list, chunk_ids, input_ids, bboxes

    else:
        print("An error occurred while getting predictions!")

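The forward pass returns one row of logits per token; softmax turns each row into a probability distribution over the 11 DocLayNet classes, and those per-token distributions are what the line-level aggregation below works on. A small standalone sketch of that shape/argmax logic with random logits:

import torch
import torch.nn.functional as F

num_tokens, num_labels = 256, 11                 # one chunk of 256 tokens, 11 DocLayNet classes
logits = torch.randn(1, num_tokens, num_labels)  # same shape as model(...).logits for one chunk

probs = F.softmax(logits.squeeze(), dim=-1)      # (256, 11), each row sums to 1
print(probs.shape, float(probs[0].sum()))        # torch.Size([256, 11]) 1.0 (up to float error)

token_preds = probs.argmax(dim=-1)               # most likely class id per token
print(token_preds.shape)                         # torch.Size([256])
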
from functools import reduce

# Get predictions (line level)
def predictions_line_level(dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes):

    ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
    bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()

    if len(images_ids_list) > 0:

        for i, image_id in enumerate(images_ids_list):

            # get image information
            images_list = dataset.filter(lambda example: example["images_ids"] == image_id)["images"]
            image = images_list[0]
            width, height = image.size

            # get data
            chunk_ids_list = chunk_ids[image_id]
            outputs_list = outputs[image_id]
            input_ids_list = input_ids[image_id]
            bboxes_list = bboxes[image_id]

            # create zeros tensors
            ten_probs = torch.zeros((outputs_list[0].shape[0] - 2)*len(outputs_list), outputs_list[0].shape[1])
            ten_input_ids = torch.ones(size=(1, (outputs_list[0].shape[0] - 2)*len(outputs_list)), dtype=int)
            ten_bboxes = torch.zeros(size=(1, (outputs_list[0].shape[0] - 2)*len(outputs_list), 4), dtype=int)

            if len(outputs_list) > 1:

                for num_output, (output, input_id, bbox) in enumerate(zip(outputs_list, input_ids_list, bboxes_list)):
                    start = num_output*(max_length - 2) - max(0, num_output)*doc_stride
                    end = start + (max_length - 2)

                    if num_output == 0:
                        ten_probs[start:end, :] += output[1:-1]
                        ten_input_ids[:, start:end] = input_id[:, 1:-1]
                        ten_bboxes[:, start:end, :] = bbox[:, 1:-1, :]
                    else:
                        ten_probs[start:start + doc_stride, :] += output[1:1 + doc_stride]
                        ten_probs[start:start + doc_stride, :] = ten_probs[start:start + doc_stride, :] * 0.5
                        ten_probs[start + doc_stride:end, :] += output[1 + doc_stride:-1]

                        ten_input_ids[:, start:start + doc_stride] = input_id[:, 1:1 + doc_stride]
                        ten_input_ids[:, start + doc_stride:end] = input_id[:, 1 + doc_stride:-1]

                        ten_bboxes[:, start:start + doc_stride, :] = bbox[:, 1:1 + doc_stride, :]
                        ten_bboxes[:, start + doc_stride:end, :] = bbox[:, 1 + doc_stride:-1, :]

            else:
                ten_probs += outputs_list[0][1:-1]
                ten_input_ids = input_ids_list[0][:, 1:-1]
                ten_bboxes = bboxes_list[0][:, 1:-1]

            ten_probs_list, ten_input_ids_list, ten_bboxes_list = ten_probs.tolist(), ten_input_ids.tolist()[0], ten_bboxes.tolist()[0]
            bboxes_list = list()
            input_ids_dict, probs_dict = dict(), dict()
            bbox_prev = [-100, -100, -100, -100]
            for probs, input_id, bbox in zip(ten_probs_list, ten_input_ids_list, ten_bboxes_list):
                bbox = denormalize_box(bbox, width, height)
                if bbox != bbox_prev and bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
                    bboxes_list.append(bbox)
                    input_ids_dict[str(bbox)] = [input_id]
                    probs_dict[str(bbox)] = [probs]
                elif bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
                    input_ids_dict[str(bbox)].append(input_id)
                    probs_dict[str(bbox)].append(probs)
                bbox_prev = bbox

            probs_bbox = dict()
            for i, bbox in enumerate(bboxes_list):
                probs = probs_dict[str(bbox)]
                probs = np.array(probs).T.tolist()

                probs_label = list()
                for probs_list in probs:
                    prob_label = reduce(lambda x, y: x*y, probs_list)
                    prob_label = prob_label**(1./(len(probs_list)))  # normalization (geometric mean)
                    probs_label.append(prob_label)
                max_value = max(probs_label)
                max_index = probs_label.index(max_value)
                probs_bbox[str(bbox)] = max_index

            bboxes_list_dict[image_id] = bboxes_list
            input_ids_dict_dict[image_id] = input_ids_dict
            probs_dict_dict[image_id] = probs_bbox

            df[image_id] = pd.DataFrame()
            df[image_id]["bboxes"] = bboxes_list
            df[image_id]["texts"] = [tokenizer.decode(input_ids_dict[str(bbox)]) for bbox in bboxes_list]
            df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in bboxes_list]

        return probs_bbox, bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df

    else:
        print("An error occurred while getting predictions!")

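Inside that function, the per-token probabilities collected for one line are merged into a single label by taking, for each class, the geometric mean of the token probabilities and keeping the class with the highest value (the reduce / **(1/n) lines above). A standalone sketch of that aggregation on made-up numbers:

from functools import reduce
import numpy as np

# three tokens belonging to the same line, probabilities over three made-up classes
token_probs = [[0.7, 0.2, 0.1],
               [0.6, 0.3, 0.1],
               [0.8, 0.1, 0.1]]

per_class = np.array(token_probs).T.tolist()   # group the probabilities by class
geo_means = [reduce(lambda x, y: x * y, p) ** (1.0 / len(p)) for p in per_class]
print([round(g, 3) for g in geo_means])        # [0.695, 0.182, 0.1]
print(geo_means.index(max(geo_means)))         # 0 -> the line gets class 0
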
667
-
668
-
669
- def get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
670
-
671
- labeled_images = list()
672
-
673
- for i, image_id in enumerate(images_ids_list):
674
-
675
- # get image
676
- images_list = dataset.filter(lambda example: example["images_ids"] == image_id)["images"]
677
- image = images_list[0]
678
- width, height = image.shape[1], image.shape[0]
679
-
680
- # get predicted boxes and labels
681
- bboxes_list = bboxes_list_dict[image_id]
682
- probs_bbox = probs_dict_dict[image_id]
683
-
684
- img_with_boxes = image.copy()
685
-
686
- for bbox in bboxes_list:
687
- predicted_label = id2label[probs_bbox[str(bbox)]]
688
- bbox = tuple(map(int, bbox)) # Convert to integers
689
- cv2.rectangle(img_with_boxes, bbox[0:2], bbox[2:4], label2color[predicted_label], 2)
690
- cv2.putText(img_with_boxes, predicted_label, (bbox[0] + 10, bbox[1] - 30), cv2.FONT_HERSHEY_SIMPLEX, 1, label2color[predicted_label], 2)
691
-
692
- labeled_images.append(img_with_boxes)
693
-
694
- return labeled_images
695
-
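The drawing loop above is plain OpenCV; a minimal standalone sketch of the same rectangle-plus-label pattern on a synthetic page (OpenCV and NumPy assumed installed; writing to disk keeps it headless-friendly):

import cv2
import numpy as np

page = np.full((400, 600, 3), 255, dtype=np.uint8)             # blank white "page"
box, color, label = (50, 80, 450, 140), (255, 0, 0), "Title"   # BGR blue, as OpenCV expects

cv2.rectangle(page, box[0:2], box[2:4], color, 2)
cv2.putText(page, label, (box[0] + 10, box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
cv2.imwrite("labeled_page_demo.png", page)                     # save instead of cv2.imshow
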
696
-
697
- # get data of encoded chunk
698
- def get_encoded_chunk_inference(index_chunk=None):
699
-
700
- # get datasets
701
- example = dataset
702
- encoded_example = encoded_dataset
703
-
704
- # get randomly a document in dataset
705
- if index_chunk == None: index_chunk = random.randint(0, len(encoded_example)-1)
706
- encoded_example = encoded_example[index_chunk]
707
- encoded_image_ids = encoded_example["images_ids"]
708
-
709
- # get the image
710
- example = example.filter(lambda example: example["images_ids"] == encoded_image_ids)[0]
711
- image = example["images"] # original image
712
- width, height = image.size
713
- page_no = example["page_no"]
714
- num_pages = example["num_pages"]
715
-
716
- # get boxes, texts, categories
717
- bboxes, input_ids = encoded_example["normalized_bboxes"][1:-1], encoded_example["input_ids"][1:-1]
718
- bboxes = [denormalize_box(bbox, width, height) for bbox in bboxes]
719
- num_tokens = len(input_ids) + 2
720
-
721
- # get unique bboxes and corresponding labels
722
- bboxes_list, input_ids_list = list(), list()
723
- input_ids_dict = dict()
724
- bbox_prev = [-100, -100, -100, -100]
725
- for i, (bbox, input_id) in enumerate(zip(bboxes, input_ids)):
726
- if bbox != bbox_prev:
727
- bboxes_list.append(bbox)
728
- input_ids_dict[str(bbox)] = [input_id]
729
- else:
730
- input_ids_dict[str(bbox)].append(input_id)
731
-
732
- # start_indexes_list.append(i)
733
- bbox_prev = bbox
734
-
735
- # do not keep "</s><pad><pad>..."
736
- if input_ids_dict[str(bboxes_list[-1])][0] == (tokenizer.convert_tokens_to_ids('</s>')):
737
- del input_ids_dict[str(bboxes_list[-1])]
738
- bboxes_list = bboxes_list[:-1]
739
-
740
- # get texts by line
741
- input_ids_list = input_ids_dict.values()
742
- texts_list = [tokenizer.decode(input_ids) for input_ids in input_ids_list]
743
-
744
- # display DataFrame
745
- df = pd.DataFrame({"texts": texts_list, "input_ids": input_ids_list, "bboxes": bboxes_list})
746
-
747
- return image, df, num_tokens, page_no, num_pages
748
-

# display chunk of PDF image and its data
def display_chunk_lines_inference(index_chunk=None):

    # get image and image data
    image, df, num_tokens, page_no, num_pages = get_encoded_chunk_inference(index_chunk=index_chunk)

    # get data from dataframe
    input_ids = df["input_ids"]
    texts = df["texts"]
    bboxes = df["bboxes"]

    print(f'Chunk ({num_tokens} tokens) of the PDF (page: {page_no + 1} / {num_pages})\n')

    # display image with bounding boxes
    print(">> PDF image with bounding boxes of lines\n")
    img_with_boxes = np.array(image)  # make sure OpenCV gets a NumPy array (handles PIL images too)

    for box, text in zip(bboxes, texts):
        color = (0, 0, 255)  # red in BGR
        box = tuple(map(int, box))  # convert to integers
        cv2.rectangle(img_with_boxes, box[0:2], box[2:4], color, 2)

    # resize image to half of its original size
    width, height = img_with_boxes.shape[1], img_with_boxes.shape[0]
    img_with_boxes = cv2.resize(img_with_boxes, (int(0.5*width), int(0.5*height)))

    # display image using OpenCV (requires a display; in a headless environment, save it with cv2.imwrite instead)
    cv2.imshow("PDF image with bounding boxes of lines", img_with_boxes)
    cv2.waitKey(0)

    # display the dataframe of annotated lines
    print("\n>> Dataframe of annotated lines\n")
    cols = ["texts", "bboxes"]
    df = df[cols]
    display(df)  # `display` is provided by IPython / notebook environments