mp-02 committed on
Commit
1c11d24
·
verified ·
1 Parent(s): 415bb0b

Delete sroie_inference.py

Browse files
Files changed (1) hide show
  1. sroie_inference.py +0 -114
sroie_inference.py DELETED
@@ -1,114 +0,0 @@
1
- import torch
2
- import cv2
3
- import numpy as np
4
- from PIL import Image, ImageDraw, ImageFont
5
- from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
- from utils import OCR, unnormalize_box
7
-
8
-
9
# Tag set of the token classifier: "O" for non-entity tokens plus B-/I- tags
# for the COMPANY, DATE, ADDRESS and TOTAL fields. List position is the class
# index used to decode the model's argmax output.
labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]

# Class index -> tag, in list order.
id2label = dict(enumerate(labels))

# Tag -> class index (inverse of ``id2label``).
label2id = {label: index for index, label in enumerate(labels)}
12
-
13
# Single checkpoint id shared by tokenizer, processor and model so the three
# cannot drift apart when the checkpoint is changed.
_CHECKPOINT = "Theivaprakasham/layoutlmv3-finetuned-sroie"

# apply_ocr=False: words and boxes are supplied explicitly by the caller
# (see the processor(image, words, boxes=...) call in prediction()).
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(_CHECKPOINT, apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained(_CHECKPOINT, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(_CHECKPOINT)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
19
-
20
-
21
def blur(image, boxes):
    """Gaussian-blur rectangular regions of an image file and save the result.

    Args:
        image: Path of the image file; it is loaded with ``cv2.imread``.
        boxes: Iterable of ``(x0, y0, x1, y1)`` boxes in pixel coordinates.
            The origin and the width/height are truncated to ``int``.

    Returns:
        The path the blurred image was written to,
        ``"images/example_with_blur.jpg"``, so the caller can open exactly
        the file that was produced.

    Raises:
        FileNotFoundError: if ``cv2.imread`` could not load ``image``.
        OSError: if ``cv2.imwrite`` reports that the output was not written.
    """
    output_path = "images/example_with_blur.jpg"

    img = cv2.imread(image)
    if img is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"could not read image: {image!r}")

    for box in boxes:
        x = int(box[0])
        y = int(box[1])
        right = x + int(box[2] - box[0])
        bottom = y + int(box[3] - box[1])

        # Negative indices would wrap around in a NumPy slice, so clamp them
        # to the frame; indices past the far edge are clipped by slicing itself.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = max(right, 0), max(bottom, 0)

        roi = img[y0:y1, x0:x1]
        if roi.size == 0:
            # Zero-area or fully out-of-frame box: nothing to blur, and
            # OpenCV rejects empty input.
            continue

        img[y0:y1, x0:x1] = cv2.GaussianBlur(roi, (201, 201), 0)

    if not cv2.imwrite(output_path, img):
        # cv2.imwrite returns False (e.g. missing directory) instead of raising.
        raise OSError(f"could not write blurred image to {output_path!r}")
    return output_path
35
-
36
-
37
def _group_entity_text(entity_labels, entity_words):
    """Join the words of each entity type into one string keyed by its ``B-`` tag.

    Words are first joined per exact tag with ``", "`` and stripped; the text
    of every ``I-X`` tag is then appended to ``B-X`` (or becomes ``B-X`` when
    no ``B-X`` word was seen). ``"O"`` and ``"B-TOTAL"`` are dropped.
    """
    words_by_label = {}
    for label, word in zip(entity_labels, entity_words):
        words_by_label.setdefault(label, []).append(word)
    fields = {label: ", ".join(parts).strip() for label, parts in words_by_label.items()}

    # Iterate over a snapshot of the keys: the dict is modified while merging.
    for label in [key for key in fields if key.startswith("I-")]:
        head = "B-" + label[2:]
        tail = fields.pop(label)
        fields[head] = f"{fields[head]}, {tail}" if head in fields else tail

    fields.pop("O", None)
    fields.pop("B-TOTAL", None)
    return fields


def prediction(image_path: str):
    """Tag the words of an image with the token classifier and redact entities.

    The image is OCR'd with ``OCR``, encoded with ``processor`` and classified
    with ``model``. Tokens tagged ``"O"`` are discarded. Boxes of every
    remaining entity except TOTAL are blurred via ``blur``, and all remaining
    entity boxes (TOTAL included) are outlined and labelled on the result.

    Args:
        image_path: Path of the image file to process.

    Returns:
        A ``(fields, image)`` tuple. ``fields`` maps ``"B-<FIELD>"`` tags to
        the comma-joined recognised text of that field (TOTAL is omitted).
        ``image`` is the PIL image opened from the path returned by ``blur``,
        with the entity annotations drawn on it.
    """
    boxes, words = OCR(image_path)
    image = Image.open(image_path).convert('RGB')
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding.pop('offset_mapping')

    for key, tensor in encoding.items():
        encoding[key] = tensor.to(device)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        outputs = model(**encoding)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()

    inp_ids = encoding.input_ids.squeeze().tolist()
    inp_words = [tokenizer.decode(token_id) for token_id in inp_ids]

    width, height = image.size
    # A token whose character offset does not start at 0 is treated as the
    # continuation of the previous word.
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]

    # Glue continuation pieces back onto the word they belong to.
    true_words = []
    for idx, piece in enumerate(inp_words):
        if not is_subword[idx]:
            true_words.append(piece)
        else:
            true_words[-1] = true_words[-1] + piece

    # Drop the first and last entries (presumably the sequence start/end
    # special tokens -- confirm against the tokenizer).
    true_predictions = true_predictions[1:-1]
    true_boxes = true_boxes[1:-1]
    true_words = true_words[1:-1]

    # Keep entity tokens only. The non-entity tag in ``labels`` is "O".
    preds = []
    l_words = []
    bboxes = []
    for label, word, box in zip(true_predictions, true_words, true_boxes):
        if label != "O":
            preds.append(label)
            l_words.append(word)
            bboxes.append(box)

    d = _group_entity_text(preds, l_words)

    # TOTAL stays readable; every other entity box is blurred.
    blur_boxes = [box for label, box in zip(preds, bboxes) if label[2:] != 'TOTAL']

    image = Image.open(blur(image_path, blur_boxes))

    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()
    for label, box in zip(preds, bboxes):
        draw.rectangle(box)
        draw.text((box[0]+10, box[1]-10), text=label, font=font, fill="black", font_size="8")

    return d, image
114
-