Spaces:
Sleeping
Sleeping
import os | |
import wandb | |
import streamlit as st | |
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification | |
from pdf2image import convert_from_bytes | |
from PIL import Image | |
# Authenticate with Weights & Biases so user feedback can be logged later on.
# The key is read from the WANDB_API_KEY environment variable rather than
# being hard-coded in the source.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    # Without a key the later wandb.init()/run.log() calls presumably cannot
    # authenticate, so surface the misconfiguration in the UI.
    st.error(
        "Couldn't find W&B API key. Please set the WANDB_API_KEY environment variable.",
        icon="🚨",
    )
else:
    wandb.login(key=wandb_api_key)
# Document classes, in class-index order: the label at position i is shown
# for logit index i (assumed to match the order the model was trained with).
labels = [
    'budget', 'email', 'form', 'handwritten', 'invoice', 'language',
    'letter', 'memo', 'news article', 'questionnaire', 'resume',
    'scientific publication', 'specification',
]

# Bidirectional lookup tables between class index and human-readable label.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
# Load the fine-tuned LayoutLMv3 processor and classifier from the local
# "model/layoutlmv3/" directory. Streamlit re-executes this whole script on
# every widget interaction, so the pair is cached with st.cache_resource to
# avoid reloading the weights from disk on each rerun.
@st.cache_resource
def _load_layoutlmv3(model_dir="model/layoutlmv3/"):
    """Return the (processor, model) pair loaded from *model_dir*."""
    return (
        LayoutLMv3Processor.from_pretrained(model_dir),
        LayoutLMv3ForSequenceClassification.from_pretrained(model_dir),
    )


processor, model = _load_layoutlmv3()
# Disable autograd for inference: no gradients are needed to classify.
import torch

# UI: upload a PDF or image, classify every page, and let the user correct
# wrong predictions. Corrections are logged to W&B only when explicitly
# submitted.
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
)
if uploaded_file:
    if uploaded_file.type == "application/pdf":
        # One PIL image per PDF page.
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        # Force 3-channel RGB: PNG uploads may be RGBA, paletted or grayscale,
        # which the LayoutLMv3 image processor is not set up to normalise.
        images = [Image.open(uploaded_file).convert("RGB")]

    # Feedback rows to send to W&B on this script run (only filled in when
    # a "Submit feedback" button was clicked).
    feedback_records = []
    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
        print(f'Encoding image with index {i}')
        encoding = processor(
            image,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        print(f'Predicting image with index {i}')
        with torch.no_grad():
            outputs = model(**encoding)
        # Highest-scoring class for the single image in the batch.
        prediction = outputs.logits.argmax(-1)[0].item()
        st.write(f"Prediction: {id2label[prediction]}")
        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            # Streamlit reruns on every widget change, so without an explicit
            # submit the selectbox's default label would be logged immediately
            # and again on every rerun.
            if st.button("Submit feedback", key=f'submit-{i}'):
                print(f'Correct label for image {i}: {correct_label}')
                feedback_records.append({
                    # Log the file name; the UploadedFile object itself is
                    # not a meaningful loggable value.
                    'filepath': uploaded_file.name,
                    'filetype': uploaded_file.type,
                    'predicted_label': id2label[prediction],
                    'predicted_label_id': prediction,
                    'correct_label': correct_label,
                    'correct_label_id': label2id[correct_label]
                })

    # Open a W&B run only when there is something to log, and always close it.
    if feedback_records:
        run = wandb.init(project='hydra-classifier', name='feedback-loop')
        try:
            for record in feedback_records:
                run.log(record)
        finally:
            run.finish()