File size: 3,372 Bytes
11deda5
 
056ccc3
5639711
c1d4001
056ccc3
 
2bc05a3
11deda5
 
 
 
 
 
 
d73399c
9652b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c6f2d5
 
 
 
 
 
 
056ccc3
 
 
fd98f6f
056ccc3
 
 
ff1aca1
 
 
 
 
6c6f2d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11deda5
6c6f2d5
fd98f6f
548ee28
fd98f6f
 
056ccc3
fd98f6f
180d132
63eb0c6
6c6f2d5
056ccc3
61dba08
056ccc3
fd98f6f
61dba08
 
fd98f6f
11deda5
fd98f6f
392dd2d
d280e22
 
056ccc3
8f5a987
11deda5
ff1aca1
 
b90d0b6
 
 
 
 
 
 
 
ff1aca1
11deda5
6c0983b
6c6f2d5
ff1aca1
11deda5
6c6f2d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import wandb
import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from pdf2image import convert_from_bytes
from PIL import Image

# Authenticate with Weights & Biases using the key read from the environment.
# When the key is missing or empty, no login is attempted and an error banner
# is shown in the app instead.
wandb_api_key = os.getenv("WANDB_API_KEY")
if wandb_api_key:
    wandb.login(key=wandb_api_key)
else:
    st.error(
        "Couldn't find the W&B API key. "
        "Please set the WANDB_API_KEY environment variable.",
        icon="🚨",
    )

# Document classes offered by the app. A label's position in this list is the
# integer id used by the two lookup tables below.
# NOTE(review): presumably this order matches the ids the model was trained
# with -- confirm against the checkpoint config.
labels = [
    'budget', 'email', 'form', 'handwritten', 'invoice',
    'language', 'letter', 'memo', 'news article', 'questionnaire',
    'resume', 'scientific publication', 'specification',
]

# id -> label name, and the inverse label name -> id.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in enumerate(labels)}

# Directory holding the LayoutLMv3 checkpoint used for both the model weights
# and the processor.
_MODEL_DIR = "model/layoutlmv3/"

# Load the model and processor only when the session does not hold them yet,
# then keep them in the session state under the 'model' / 'processor' keys.
if 'model' not in st.session_state:
    st.session_state['model'] = (
        LayoutLMv3ForSequenceClassification.from_pretrained(_MODEL_DIR)
    )
if 'processor' not in st.session_state:
    st.session_state['processor'] = LayoutLMv3Processor.from_pretrained(_MODEL_DIR)

# Module-level aliases used by the rest of the script.
model = st.session_state['model']
processor = st.session_state['processor']

# Page heading.
st.title("Document Classification with LayoutLMv3")

# Single-file upload widget restricted to the extensions listed below.
_ACCEPTED_EXTENSIONS = ["pdf", "jpg", "png"]
uploaded_file = st.file_uploader(
    "Upload Document",
    type=_ACCEPTED_EXTENSIONS,
    accept_multiple_files=False,
)

# Table of user corrections gathered during the current script execution.
# Streamlit re-executes the script on each interaction, so a fresh table is
# built every time and only holds rows submitted in this execution.
feedback_table = wandb.Table(columns=[
    'image', 'filetype', 'predicted_label', 'predicted_label_id',
    'correct_label', 'correct_label_id'
])


def _get_wandb_run():
    """Return the W&B run stored in the session, starting a new one if needed.

    A new run is created both when the 'wandb_run' key is absent and when it
    is present but holds ``None`` (which is how a finished run is marked
    below), so callers never receive ``None``.
    """
    run = st.session_state.get('wandb_run')
    if run is None:
        run = wandb.init(project='hydra-classifier', name='feedback-loop')
        st.session_state.wandb_run = run
    return run


@st.cache_data
def classify_image(image):
    """Classify a single page image and return the predicted label id.

    Args:
        image: The page image passed to the LayoutLMv3 processor
            (a PIL image in this script).

    Returns:
        int: Index of the highest logit for the first (only) batch item;
        a key of ``id2label``.
    """
    # Function-scope import: torch is only needed to disable gradient tracking.
    import torch

    print('Encoding image')
    encoding = processor(
        image,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    print('Predicting image')
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)
    prediction = outputs.logits.argmax(-1)[0].item()
    return prediction


if uploaded_file:
    if uploaded_file.type == "application/pdf":
        # One image per PDF page.
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        # Normalise to 3-channel RGB so PNGs with alpha/palette/greyscale
        # modes are handled the same way as JPEGs.
        images = [Image.open(uploaded_file).convert("RGB")]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)

        prediction = classify_image(image)

        st.write(f"Prediction: {id2label[prediction]}")

        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )

        # Only ask for a correction when the user rejects the prediction.
        if feedback != "No":
            continue

        correct_label = st.selectbox(
            "Please select the correct label:", labels,
            key=f'selectbox-{i}'
        )
        print(f'Correct label for image {i}: {correct_label}')

        # Record the correction only once the user confirms it.
        if st.button(f"Submit Feedback for Image {i}", key=f'submit-{i}'):
            feedback_table.add_data(
                wandb.Image(image),
                uploaded_file.type,
                id2label[prediction],
                prediction,
                correct_label,
                label2id[correct_label],
            )
            st.success(f"Feedback for Image {i} submitted!")

    print(feedback_table)
    # Log only when this execution actually collected feedback; otherwise
    # every rerun would upload an empty table and close the run.
    if feedback_table.data:
        run = _get_wandb_run()
        run.log({'feedback_table': feedback_table})
        run.finish()
        # Mark the run as finished so the next submission starts a new one.
        st.session_state.wandb_run = None