# Revision c9ae4da by João Pedro: "fix some wording in some strings" (3.35 kB).
import os
import wandb
import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from pdf2image import convert_from_bytes
from PIL import Image
# Log in to Weights & Biases with the key taken from the WANDB_API_KEY
# environment variable; when it is missing, show an error in the UI instead.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    # NOTE(review): execution continues after this error, so the later
    # wandb.init() call still runs without an explicit login -- confirm
    # whether the app should halt here instead.
    st.error(
        "Couldn't find WanDB API key. Please set it up as an environment variable",
        icon="🚨",
    )
else:
    wandb.login(key=wandb_api_key)
# Document classes known to the classifier. A label's position in this list
# is its numeric class id.
labels = [
    'budget', 'email', 'form', 'handwritten', 'invoice', 'language', 'letter',
    'memo', 'news article', 'questionnaire', 'resume',
    'scientific publication', 'specification',
]

# Forward (id -> name) and reverse (name -> id) lookups derived from `labels`.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in enumerate(labels)}
# Load the LayoutLMv3 classifier and its processor from the local
# "model/layoutlmv3/" directory only when they are not already stored in the
# Streamlit session state, then expose them as module-level names.
if 'model' not in st.session_state:
    st.session_state['model'] = LayoutLMv3ForSequenceClassification.from_pretrained(
        "model/layoutlmv3/"
    )
if 'processor' not in st.session_state:
    st.session_state['processor'] = LayoutLMv3Processor.from_pretrained(
        "model/layoutlmv3/"
    )

model, processor = st.session_state['model'], st.session_state['processor']
# Page header and the single-file upload widget (PDF, JPG or PNG).
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    label="Upload Document",
    type=["pdf", "jpg", "png"],
    accept_multiple_files=False,
)

# W&B table holding one row per piece of user feedback. It is constructed
# every time this script executes.
feedback_table = wandb.Table(
    columns=[
        'image',
        'filetype',
        'predicted_label',
        'predicted_label_id',
        'correct_label',
        'correct_label_id',
    ]
)

# Start the W&B run only if this session does not already hold one.
if 'wandb_run' not in st.session_state:
    st.session_state['wandb_run'] = wandb.init(
        project='hydra-classifier', name='feedback-loop'
    )
@st.cache_data
def _classify_cached(_image, image_key):
    """Run the LayoutLMv3 model on ``_image`` and return the predicted class id.

    ``_image`` carries a leading underscore so Streamlit leaves it out of the
    cache key; ``image_key`` is the only hashed argument and therefore decides
    whether a cached prediction is reused.
    """
    print(f'Encoding image {image_key}')
    encoding = processor(
        _image,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    print(f'Predicting image {image_key}')
    outputs = model(**encoding)
    # Index of the highest logit for the first (only) item in the batch.
    return outputs.logits.argmax(-1)[0].item()


def classify_image(_image):
    """Return the predicted class id for a PIL image, cached per image content.

    The previous version was cached directly while its only parameter was
    underscore-prefixed, so every call shared one cache entry and every image
    received the first prediction ever computed. It also read the globals
    ``image`` and ``i`` instead of its own argument. The cache key is now
    derived from the image's mode, size and pixel bytes.

    Args:
        _image: image to classify; must provide ``mode``, ``size`` and
            ``tobytes()`` (as PIL images do).

    Returns:
        The integer class id with the highest logit.
    """
    import hashlib  # local import keeps this block self-contained

    digest = hashlib.sha256(_image.tobytes()).hexdigest()
    return _classify_cached(_image, f'{_image.mode}:{_image.size}:{digest}')


# Keep the cache-clearing hook that the directly-decorated function exposed.
classify_image.clear = _classify_cached.clear
# Main flow: classify every page/image of the uploaded document, collect user
# feedback on each prediction, and log the collected feedback to W&B.
if uploaded_file:
    # Feedback rows are accumulated in session state: `feedback_table` is
    # built anew on each script execution, so rows added to it directly would
    # not survive until the "Submit all feedback" click.
    if 'feedback_rows' not in st.session_state:
        st.session_state['feedback_rows'] = []
    pending_rows = st.session_state['feedback_rows']

    # A PDF is rendered to one image per page; any other upload is one image.
    if uploaded_file.type == "application/pdf":
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        images = [Image.open(uploaded_file)]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
        prediction = classify_image(image)
        predicted_label = id2label[prediction]
        st.write(f"Prediction: {predicted_label}")

        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            print(f'Correct label for image {i}: {correct_label}')
        else:
            # The user confirmed the prediction, so it is the correct label.
            # Without this branch `correct_label` would be undefined (or left
            # over from a previous image) when the button below is pressed.
            correct_label = predicted_label

        # Add a button to confirm feedback and queue it for logging
        if st.button(f"Add feedback for Image {i}", key=f'add-{i}'):
            pending_rows.append((
                image,
                uploaded_file.type,
                predicted_label,
                prediction,
                correct_label,
                label2id[correct_label],
            ))

    if st.button("Submit all feedback", key='submit'):
        if not pending_rows:
            st.warning("No feedback has been added yet.")
        else:
            for row_image, *row_fields in pending_rows:
                feedback_table.add_data(wandb.Image(row_image), *row_fields)
            run = st.session_state.wandb_run
            run.log({'feedback_table': feedback_table})
            run.finish()
            # Drop the finished run so that the `'wandb_run' not in
            # st.session_state` check starts a fresh one on the next
            # execution instead of reusing a closed run.
            del st.session_state['wandb_run']
            pending_rows.clear()
            st.success("Feedback submitted!")