Spaces:
Sleeping
Sleeping
João Pedro
committed on
Commit
·
6c6f2d5
1
Parent(s):
b90d0b6
try to avoid streamlit re-running everything on ui changes
Browse files
app.py
CHANGED
@@ -32,8 +32,13 @@ labels = [
|
|
32 |
id2label = {i: label for i, label in enumerate(labels)}
|
33 |
label2id = {v: k for k, v in id2label.items()}
|
34 |
|
35 |
-
|
36 |
-
model = LayoutLMv3ForSequenceClassification.from_pretrained("model/layoutlmv3/")
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
st.title("Document Classification with LayoutLMv3")
|
39 |
|
@@ -41,15 +46,32 @@ uploaded_file = st.file_uploader(
|
|
41 |
"Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
|
42 |
)
|
43 |
|
44 |
-
|
45 |
feedback_table = wandb.Table(columns=[
|
46 |
'image', 'filetype', 'predicted_label', 'predicted_label_id',
|
47 |
'correct_label', 'correct_label_id'
|
48 |
])
|
49 |
|
50 |
-
if
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
|
|
53 |
if uploaded_file.type == "application/pdf":
|
54 |
images = convert_from_bytes(uploaded_file.getvalue())
|
55 |
else:
|
@@ -58,16 +80,7 @@ if uploaded_file:
|
|
58 |
for i, image in enumerate(images):
|
59 |
st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
|
60 |
|
61 |
-
|
62 |
-
encoding = processor(
|
63 |
-
image,
|
64 |
-
return_tensors="pt",
|
65 |
-
truncation=True,
|
66 |
-
max_length=512,
|
67 |
-
)
|
68 |
-
print(f'Predicting image with index {i}')
|
69 |
-
outputs = model(**encoding)
|
70 |
-
prediction = outputs.logits.argmax(-1)[0].item()
|
71 |
|
72 |
st.write(f"Prediction: {id2label[prediction]}")
|
73 |
|
@@ -96,5 +109,7 @@ if uploaded_file:
|
|
96 |
st.success(f"Feedback for Image {i} submitted!")
|
97 |
|
98 |
print(feedback_table)
|
|
|
99 |
run.log({'feedback_table': feedback_table})
|
100 |
run.finish()
|
|
|
|
32 |
id2label = {i: label for i, label in enumerate(labels)}
|
33 |
label2id = {v: k for k, v in id2label.items()}
|
34 |
|
35 |
+
if 'model' not in st.session_state:
|
36 |
+
st.session_state.model = LayoutLMv3ForSequenceClassification.from_pretrained("model/layoutlmv3/")
|
37 |
+
if 'processor' not in st.session_state:
|
38 |
+
st.session_state.processor = LayoutLMv3Processor.from_pretrained("model/layoutlmv3/")
|
39 |
+
|
40 |
+
model = st.session_state.model
|
41 |
+
processor = st.session_state.processor
|
42 |
|
43 |
st.title("Document Classification with LayoutLMv3")
|
44 |
|
|
|
46 |
"Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
|
47 |
)
|
48 |
|
|
|
49 |
feedback_table = wandb.Table(columns=[
|
50 |
'image', 'filetype', 'predicted_label', 'predicted_label_id',
|
51 |
'correct_label', 'correct_label_id'
|
52 |
])
|
53 |
|
54 |
+
if 'wandb_run' not in st.session_data:
|
55 |
+
st.session_data.wandb_run = wandb.init(project='hydra-classifier', name='feedback-loop')
|
56 |
+
|
57 |
+
|
58 |
+
@st.cache_data
|
59 |
+
def classify_image(image):
|
60 |
+
print(f'Encoding image with index {i}')
|
61 |
+
encoding = processor(
|
62 |
+
image,
|
63 |
+
return_tensors="pt",
|
64 |
+
truncation=True,
|
65 |
+
max_length=512,
|
66 |
+
)
|
67 |
+
|
68 |
+
print(f'Predicting image with index {i}')
|
69 |
+
outputs = model(**encoding)
|
70 |
+
prediction = outputs.logits.argmax(-1)[0].item()
|
71 |
+
return prediction
|
72 |
+
|
73 |
|
74 |
+
if uploaded_file:
|
75 |
if uploaded_file.type == "application/pdf":
|
76 |
images = convert_from_bytes(uploaded_file.getvalue())
|
77 |
else:
|
|
|
80 |
for i, image in enumerate(images):
|
81 |
st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
|
82 |
|
83 |
+
prediction = classify_image(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
st.write(f"Prediction: {id2label[prediction]}")
|
86 |
|
|
|
109 |
st.success(f"Feedback for Image {i} submitted!")
|
110 |
|
111 |
print(feedback_table)
|
112 |
+
run = st.session_data.wandb_run
|
113 |
run.log({'feedback_table': feedback_table})
|
114 |
run.finish()
|
115 |
+
st.session_data.wandb_run = None
|