File size: 3,873 Bytes
52c57e0
 
 
 
 
 
f1eda89
 
 
52c57e0
f1eda89
 
52c57e0
 
 
 
4093517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52c57e0
d023795
e1b2c07
52c57e0
 
 
 
 
 
 
 
 
 
 
b248a87
 
52c57e0
 
 
 
 
 
 
 
 
b248a87
52c57e0
 
 
 
 
 
 
 
 
 
 
b248a87
52c57e0
a8ea53b
358c7b8
4093517
 
0581571
4093517
 
 
52c57e0
2a1d11c
52c57e0
2a1d11c
52c57e0
4093517
52c57e0
 
 
 
80ff823
52c57e0
ad0f1c1
52c57e0
5e2e7ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box

# Hub checkpoint shared by the tokenizer, the processor and the model.
_CHECKPOINT = "mp-02/layoutlmv3-finetuned-cord-sroie"

# apply_ocr=False: words and boxes are passed in explicitly by `prediction`.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(_CHECKPOINT, apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained(_CHECKPOINT, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(_CHECKPOINT)

# Label <-> id maps are read from the loaded model's config.
id2label, label2id = model.config.id2label, model.config.label2id

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

import json

def token2json(words, labels):
    """Group BIO-tagged words into entities and serialise them as JSON.

    ``words`` and ``labels`` are walked in lockstep. A ``B-<type>`` label
    always opens a new entity. An ``I-<type>`` label extends the open entity
    when its type matches, otherwise it opens a new one. Any other label
    closes the open entity and contributes nothing itself.

    Returns:
        A JSON string (2-space indent, non-ASCII kept as-is) holding a list of
        ``{"type": ..., "text": ...}`` objects in order of appearance; words of
        one entity are joined with single spaces.
    """
    entities = []
    active = None  # entity still being extended, if any

    for word, tag in zip(words, labels):
        opens = tag.startswith("B-")
        inside = (not opens) and tag.startswith("I-")
        kind = tag[2:]

        # Only a matching I- tag keeps the current entity open.
        if inside and active is not None and active["type"] == kind:
            active["text"] = active["text"] + " " + word
            continue

        # Every other case ends whatever entity was in progress ...
        if active is not None:
            entities.append(active)
        # ... and B-/I- tags immediately start the next one.
        active = {"type": kind, "text": word} if (opens or inside) else None

    if active is not None:
        entities.append(active)

    return json.dumps(entities, ensure_ascii=False, indent=2)


def prediction(image, confidence_threshold=0.6):
    """Run OCR and token classification on ``image`` and annotate the result.

    Args:
        image: PIL image to analyse. It is drawn on in place.
        confidence_threshold: word-level predictions whose softmax confidence
            is below this value are relabelled ``"O"`` (default ``0.6``).

    Returns:
        A ``(image, entities_json)`` tuple: the same image object with one
        rectangle and one ``"label, confidence"`` caption per word, and the
        JSON string built by ``token2json`` from the words and their labels.
    """
    boxes, words = OCR(image)
    encoding = processor(
        image,
        words,
        boxes=boxes,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
    )
    # Taken out before the forward pass; only used below to tell the first
    # piece of a word from its continuation pieces.
    offset_mapping = encoding.pop("offset_mapping")
    inputs = {name: tensor.to(device) for name, tensor in encoding.items()}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is encoded, so index batch element 0 explicitly rather
    # than relying on a bare squeeze().
    token_logits = logits[0]
    predictions = token_logits.argmax(-1).tolist()
    confidence_scores = torch.softmax(token_logits, dim=-1).max(-1).values.tolist()
    token_boxes = inputs["bbox"][0].tolist()
    input_ids = inputs["input_ids"][0].tolist()
    # A non-zero start offset marks a token that continues the previous word.
    is_subword = (offset_mapping[0][:, 0] != 0).tolist()

    width, height = image.size

    true_predictions = []
    true_boxes = []
    true_confidence_scores = []
    true_words = []
    for token_id, label_id, box, score, subword in zip(
        input_ids, predictions, token_boxes, confidence_scores, is_subword
    ):
        piece = tokenizer.decode(token_id)
        if subword:
            # Continuation piece: glue it onto the word it belongs to; the
            # word keeps the label, box and score of its first piece.
            true_words[-1] += piece
            continue
        true_words.append(piece)
        true_predictions.append(id2label[label_id])
        true_boxes.append(unnormalize_box(box, width, height))
        true_confidence_scores.append(score)

    # Drop the first and last entries (the sequence's leading and trailing
    # special tokens).
    true_predictions = true_predictions[1:-1]
    true_boxes = true_boxes[1:-1]
    true_confidence_scores = true_confidence_scores[1:-1]
    # NOTE(review): per-token decode appears to keep the tokenizer's
    # word-boundary space; strip it so entity text is not padded with extra
    # blanks when token2json joins words with " " -- confirm against output.
    true_words = [word.strip() for word in true_words[1:-1]]

    # Low-confidence words are treated as "no entity".
    true_predictions = [
        "O" if score < confidence_threshold else label
        for label, score in zip(true_predictions, true_confidence_scores)
    ]

    d = token2json(true_words, true_predictions)

    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()

    # NOTE(review): font_size is passed as the string "15" alongside an
    # explicit font; left unchanged -- confirm it has the intended effect.
    for label, box, confidence in zip(true_predictions, true_boxes, true_confidence_scores):
        draw.rectangle(box)
        draw.text((box[0]+10, box[1]-10), text=label + ", " + str(confidence), font=font, fill="black", font_size="15")

    return image, d