João Pedro
Don't set the run to None; that breaks things.
221f9cc
raw
history blame
3.34 kB
import os
import wandb
import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from pdf2image import convert_from_bytes
from PIL import Image
# Authenticate with Weights & Biases, which receives the user feedback logged
# further down. The API key is read from the WANDB_API_KEY environment variable.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    # Only an error banner is shown; the script does not stop here.
    st.error(
        "Couldn't find W&B API key. Please set it up as the WANDB_API_KEY environment variable",
        icon="🚨",
    )
else:
    wandb.login(key=wandb_api_key)
# Document classes the classifier can output. A label's position in this list
# is the class id the model predicts (id2label is built from enumerate below
# and indexed with the model's arg-max id).
labels = [
    "budget",
    "email",
    "form",
    "handwritten",
    "invoice",
    "language",
    "letter",
    "memo",
    "news article",
    "questionnaire",
    "resume",
    "scientific publication",
    "specification",
]

# Bidirectional lookups between class ids and human-readable label names.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
# Directory containing the fine-tuned LayoutLMv3 weights and processor config.
MODEL_DIR = "model/layoutlmv3/"


@st.cache_resource
def _load_model_and_processor(model_dir):
    """Load the LayoutLMv3 classifier and its processor from *model_dir*.

    ``st.cache_resource`` shares the returned objects across every session of
    this server process. Weights are therefore loaded once, instead of once per
    browser session as happened when they were stored only in
    ``st.session_state``.
    """
    loaded_model = LayoutLMv3ForSequenceClassification.from_pretrained(model_dir)
    loaded_processor = LayoutLMv3Processor.from_pretrained(model_dir)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor(MODEL_DIR)
# Keep exposing the objects through session state, as the app always has.
st.session_state.model = model
st.session_state.processor = processor
# Page header, followed by a widget accepting exactly one PDF or image file.
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    "Upload Document",
    accept_multiple_files=False,
    type=["pdf", "jpg", "png"],
)
# Table of feedback rows to send to W&B. It is rebuilt on every script rerun,
# so it only contains rows added during the current rerun.
feedback_columns = [
    "image",
    "filetype",
    "predicted_label",
    "predicted_label_id",
    "correct_label",
    "correct_label_id",
]
feedback_table = wandb.Table(columns=feedback_columns)

# Start a W&B run only when this session does not already hold one.
if "wandb_run" not in st.session_state:
    st.session_state["wandb_run"] = wandb.init(
        project="hydra-classifier",
        name="feedback-loop",
    )
@st.cache_data
def _predict_label_id(image_digest, _image):
    """Run the LayoutLMv3 classifier on *_image* and return the arg-max class id.

    Streamlit caches this function keyed on ``image_digest`` only. The leading
    underscore on ``_image`` tells ``st.cache_data`` not to hash the PIL image,
    so the digest must uniquely identify the image content.
    """
    import torch

    encoding = processor(
        _image,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**encoding)
    return outputs.logits.argmax(-1)[0].item()


def classify_image(_image):
    """Return the predicted class id (an index into ``labels``) for *_image*.

    The previous version cached on ``_image`` alone. Because underscore-prefixed
    parameters are excluded from the cache key, every image received the first
    image's prediction. This wrapper derives a content digest so that distinct
    images get distinct cache entries.
    """
    import hashlib

    hasher = hashlib.sha256()
    # Include mode and size so identical raw bytes laid out differently
    # do not share a cache entry.
    hasher.update(f"{_image.mode}:{_image.size}".encode())
    hasher.update(_image.tobytes())
    return _predict_label_id(hasher.hexdigest(), _image)
if uploaded_file:
    # PDFs are rasterised page by page; every page is classified separately.
    if uploaded_file.type == "application/pdf":
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        # Normalise to RGB: PNG uploads can be RGBA or palette images, while
        # the model pipeline expects three-channel input.
        images = [Image.open(uploaded_file).convert("RGB")]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
        prediction = classify_image(image)
        predicted_label = id2label[prediction]
        st.write(f"Prediction: {predicted_label}")

        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            print(f'Correct label for image {i}: {correct_label}')
        else:
            # "Yes" confirms the prediction. Binding correct_label here avoids a
            # NameError, or reusing an earlier page's label, on submit.
            correct_label = predicted_label

        # Add a button to confirm feedback and log it
        if st.button(f"Submit Feedback for Image {i}", key=f'submit-{i}'):
            feedback_table.add_data(
                wandb.Image(image),
                uploaded_file.type,
                predicted_label,
                prediction,
                correct_label,
                label2id[correct_label],
            )
            run = st.session_state.wandb_run
            run.log({'feedback_table': feedback_table})
            run.finish()
            # The run is finished. Remove it so the
            # `'wandb_run' not in st.session_state` check initialises a fresh
            # run on the next rerun. Setting it to None would keep the key,
            # skip re-initialisation and break the next submission.
            del st.session_state['wandb_run']
            # Report success only after the feedback was actually logged.
            st.success(f"Feedback for Image {i} submitted!")