File size: 2,010 Bytes
056ccc3
5639711
c1d4001
056ccc3
 
d73399c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
056ccc3
d73399c
 
a3ab611
d73399c
 
 
056ccc3
 
 
fd98f6f
056ccc3
 
 
fd98f6f
 
548ee28
fd98f6f
 
056ccc3
fd98f6f
180d132
63eb0c6
8f5a987
99700b8
 
 
 
 
 
8f5a987
fd98f6f
dafbc40
056ccc3
61dba08
056ccc3
fd98f6f
61dba08
 
fd98f6f
 
392dd2d
d280e22
 
056ccc3
8f5a987
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import streamlit as st
import torch
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification

# Document categories the classifier can predict. A category's position in
# this list is its integer class id, so the order must stay stable.
labels = [
    'budget', 'email', 'form', 'handwritten', 'invoice', 'language',
    'letter', 'memo', 'news article', 'questionnaire', 'resume',
    'scientific publication', 'specification',
]

# Class id -> category name.
id2label = dict(enumerate(labels))

# Category name -> class id (the inverse of id2label).
label2id = {name: class_id for class_id, name in id2label.items()}

@st.cache_resource
def _load_processor_and_model():
    """Load the LayoutLMv3 processor and sequence classifier exactly once.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Without caching, each rerun would reload the checkpoint;
    ``st.cache_resource`` keeps the returned objects alive and hands the
    same instances back on later reruns.

    Returns:
        A ``(processor, model)`` tuple built from the
        ``microsoft/layoutlmv3-base`` checkpoint, with the model configured
        for ``len(labels)`` classes and the module's label mappings.
    """
    checkpoint = "microsoft/layoutlmv3-base"
    loaded_processor = LayoutLMv3Processor.from_pretrained(checkpoint)
    # NOTE(review): this is the base checkpoint, so the classification head
    # is presumably freshly initialised rather than fine-tuned -- confirm
    # whether a fine-tuned checkpoint should be loaded here instead.
    loaded_model = LayoutLMv3ForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
    )
    return loaded_processor, loaded_model


# Module-level handles used by the classification loop further down.
processor, model = _load_processor_and_model()

# Page heading shown at the top of the app.
st.title("Document Classification with LayoutLMv3")

# Single-file upload widget restricted to PDFs and JPG/PNG images. The rest
# of the script only runs when this value is truthy.
uploaded_file = st.file_uploader(
    label="Upload Document",
    accept_multiple_files=False,
    type=["pdf", "jpg", "png"],
)

if uploaded_file:
    # Normalise the upload into a list of page images: a PDF yields one image
    # per page, any other accepted type is a single image.
    if uploaded_file.type == "application/pdf":
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        # PNG/JPG uploads may be palette, RGBA, greyscale or CMYK. Convert to
        # RGB so the processor always receives a three-channel image.
        images = [Image.open(uploaded_file).convert("RGB")]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)

        print(f'Encoding image with index {i}')
        # Turn the page image into model inputs, capped at 512 tokens.
        encoding = processor(
            image,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        print(f'Predicting image with index {i}')
        # Inference only: no gradients are needed, so skip autograd
        # bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**encoding)
        # One image per forward pass, so take the arg-max class of row 0.
        prediction = outputs.logits.argmax(-1)[0].item()

        st.write(f"Prediction: {id2label[prediction]}")

        # Widget keys include the page index so every page of a multi-page
        # PDF gets its own independent feedback controls.
        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            # The correction is only logged to stdout; it is not persisted.
            print(f'Correct label for image {i}: {correct_label}')