João Pedro
can't use other files in spaces
94847d6
raw
history blame
2.33 kB
import os
import wandb
import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from pdf2image import convert_from_bytes
from PIL import Image
# Authenticate with Weights & Biases. The key is read from the environment
# so it never has to be stored alongside the app's source.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    st.error(
        "Couldn't find WanDB API key. Please set it up as an environment variable",
        icon="🚨",
    )
    # Halt this script run here: everything below logs to W&B, and calling
    # wandb.init() without credentials would fail later with a less helpful
    # error than the one shown above.
    st.stop()

wandb.login(key=wandb_api_key)
@st.cache_resource
def _load_classifier(model_dir: str = "model/layoutlmv3/"):
    """Load the LayoutLMv3 processor and sequence classifier from *model_dir*.

    Streamlit re-executes the whole script on every widget interaction, so
    the load is wrapped in ``st.cache_resource``: the checkpoint is read once
    per server process and the same objects are reused by later reruns.

    Returns:
        A ``(processor, model)`` tuple.
    """
    return (
        LayoutLMv3Processor.from_pretrained(model_dir),
        LayoutLMv3ForSequenceClassification.from_pretrained(model_dir),
    )


processor, model = _load_classifier()

# Mappings between class ids and class names, taken from the model config.
id2label = model.config.id2label
label2id = model.config.label2id
# Page heading.
st.title("Document Classification with LayoutLMv3")

# Upload widget: a single file, limited to the pdf/jpg/png extensions.
uploaded_file = st.file_uploader(
    label="Upload Document",
    type=["pdf", "jpg", "png"],
    accept_multiple_files=False,
)
# Classify every page/image of the uploaded document and, when the user marks
# a prediction as wrong, log the correction to Weights & Biases.
if uploaded_file:
    # Options offered when the user corrects a prediction. They are taken from
    # label2id so every choice is a valid key for the id lookup done when
    # logging.
    labels = list(label2id)

    # Using the run as a context manager ensures it is finished even if
    # decoding, inference or logging raises part-way through.
    with wandb.init(project='hydra-classifier', name='feedback-loop') as run:
        # PDFs are rasterised into a list of page images; any other upload is
        # opened as a single image.
        if uploaded_file.type == "application/pdf":
            images = convert_from_bytes(uploaded_file.getvalue())
        else:
            images = [Image.open(uploaded_file)]

        for i, image in enumerate(images):
            st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)

            print(f'Encoding image with index {i}')
            # PNG/JPG uploads can be RGBA, palette or greyscale; normalise to
            # three-channel RGB before handing the image to the processor.
            encoding = processor(
                image.convert("RGB"),
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )

            print(f'Predicting image with index {i}')
            outputs = model(**encoding)
            # Highest-scoring class id for the single item in the batch.
            prediction = outputs.logits.argmax(-1)[0].item()
            predicted_label = id2label[prediction]
            st.write(f"Prediction: {predicted_label}")

            # Widget keys include the index so each page gets its own state.
            feedback = st.radio(
                "Is the classification correct?", ("Yes", "No"),
                key=f'prediction-{i}'
            )
            if feedback != "No":
                continue

            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            print(f'Correct label for image {i}: {correct_label}')

            run.log({
                # Log the file's name; the UploadedFile object itself is not
                # a value W&B can record.
                'filepath': uploaded_file.name,
                'filetype': uploaded_file.type,
                'predicted_label': predicted_label,
                'predicted_label_id': prediction,
                'correct_label': correct_label,
                'correct_label_id': label2id[correct_label]
            })