mp-02 committed
Commit f5ce91b · verified · 1 Parent(s): 5e2e7ba

Delete sroie_inference.py

Files changed (1)
  1. sroie_inference.py +0 -106
sroie_inference.py DELETED
@@ -1,106 +0,0 @@
- import torch
- import cv2
- import numpy as np
- from PIL import Image, ImageDraw, ImageFont
- from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
- from utils import OCR, unnormalize_box
-
-
- tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
- processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
- model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")
-
- id2label = model.config.id2label
- label2id = model.config.label2id
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model.to(device)
-
-
- def blur(image, boxes):
-     image = np.array(image)
-     for box in boxes:
-
-         blur_x = int(box[0])
-         blur_y = int(box[1])
-         blur_width = int(box[2]-box[0])
-         blur_height = int(box[3]-box[1])
-
-         roi = image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width]
-         blur_image = cv2.GaussianBlur(roi, (201, 201), 0)
-         image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width] = blur_image
-
-     return Image.fromarray(image, 'RGB')
-
-
- def prediction(image):
-     boxes, words = OCR(image)
-     encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
-     offset_mapping = encoding.pop('offset_mapping')
-
-     for k, v in encoding.items():
-         encoding[k] = v.to(device)
-
-     outputs = model(**encoding)
-
-     predictions = outputs.logits.argmax(-1).squeeze().tolist()
-     token_boxes = encoding.bbox.squeeze().tolist()
-     probabilities = torch.softmax(outputs.logits, dim=-1)
-     confidence_scores = probabilities.max(-1).values.squeeze().tolist()
-
-     inp_ids = encoding.input_ids.squeeze().tolist()
-     inp_words = [tokenizer.decode(i) for i in inp_ids]
-
-     width, height = image.size
-     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
-
-     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
-     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
-     true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
-     true_words = []
-
-     for id, i in enumerate(inp_words):
-         if not is_subword[id]:
-             true_words.append(i)
-         else:
-             true_words[-1] = true_words[-1]+i
-
-     true_predictions = true_predictions[1:-1]
-     true_boxes = true_boxes[1:-1]
-     true_words = true_words[1:-1]
-     true_confidence_scores = true_confidence_scores[1:-1]
-
-     #for i, j in enumerate(true_confidence_scores):
-     #    if j < 0.8: #####################################
-     #        true_predictions[i] = "O"
-
-     d = {}
-     for id, i in enumerate(true_predictions):
-         # remove the label prefixes
-         if i != "O":
-             i = i[2:]
-         if i not in d.keys():
-             d[i] = true_words[id]
-         else:
-             d[i] = d[i] + ", " + true_words[id]
-     d = {k: v.strip() for (k, v) in d.items()}
-
-     if "O" in d: d.pop("O")
-     if "TOTAL" in d: d.pop("TOTAL")
-
-     blur_boxes = []
-     for prediction, box in zip(true_predictions, true_boxes):
-         if prediction != 'O' and prediction != 'TOTAL':
-             blur_boxes.append(box)
-
-     image = (blur(image, blur_boxes))
-
-     #draw = ImageDraw.Draw(image, "RGBA")
-     #font = ImageFont.load_default()
-
-     #for prediction, box in zip(true_predictions, true_boxes):
-     #    draw.rectangle(box)
-     #    draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")
-
-     return d, image
-
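
Note: the deleted script imports OCR and unnormalize_box from a local utils module that is not part of this diff. Below is a minimal sketch of what those helpers are assumed to do (pytesseract word-level OCR and the usual LayoutLMv3 0-1000 box normalization); the names match the imports, but the bodies are assumptions rather than the repository's actual utils code.

# Hypothetical utils.py sketch (assumed, not the actual repository code)
import pytesseract
from pytesseract import Output


def unnormalize_box(bbox, width, height):
    # Map a 0-1000 normalized LayoutLMv3 box back to pixel coordinates.
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]


def OCR(image):
    # Word-level OCR returning (boxes, words), with boxes normalized to 0-1000
    # as prediction() above expects.
    width, height = image.size
    data = pytesseract.image_to_data(image, output_type=Output.DICT)
    boxes, words = [], []
    for text, x, y, w, h in zip(data["text"], data["left"], data["top"],
                                data["width"], data["height"]):
        if text.strip():
            words.append(text)
            boxes.append([int(1000 * x / width), int(1000 * y / height),
                          int(1000 * (x + w) / width), int(1000 * (y + h) / height)])
    return boxes, words


# Example call (hypothetical image path):
# entities, anonymized = prediction(Image.open("receipt.jpg").convert("RGB"))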