Spaces:
Sleeping
Sleeping
File size: 2,603 Bytes
11deda5 056ccc3 5639711 c1d4001 056ccc3 2bc05a3 11deda5 d73399c 9652b01 11deda5 056ccc3 fd98f6f 056ccc3 fd98f6f 94847d6 11deda5 fd98f6f 548ee28 fd98f6f 056ccc3 fd98f6f 180d132 63eb0c6 8f5a987 99700b8 8f5a987 fd98f6f dafbc40 056ccc3 61dba08 056ccc3 fd98f6f 61dba08 fd98f6f 11deda5 fd98f6f 392dd2d d280e22 056ccc3 8f5a987 11deda5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os

import streamlit as st
import torch
import wandb
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
# Authenticate with Weights & Biases using the key from the environment.
# When the key is missing, show an error in the UI instead of logging in;
# `wandb_api_key` stays falsy so later code can tell that W&B is unavailable.
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    st.error(
        "Couldn't find the W&B API key. Please set the WANDB_API_KEY environment variable.",
        icon="🚨",
    )
else:
    wandb.login(key=wandb_api_key)
# Class names for the classifier head. A label's position in this list is the
# integer id used to decode model predictions (assumed to match the label
# order the checkpoint was fine-tuned with -- TODO confirm against the model
# config's id2label).
labels = [
    "budget", "email", "form", "handwritten", "invoice", "language",
    "letter", "memo", "news article", "questionnaire", "resume",
    "scientific publication", "specification",
]

# Bidirectional lookups derived purely from list order.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in enumerate(labels)}
@st.cache_resource
def _load_model_and_processor(model_dir: str = "model/layoutlmv3/"):
    """Load the LayoutLMv3 processor and sequence classifier from *model_dir*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects across reruns and sessions
    so the weights are read from disk once instead of on every click.

    Returns:
        A ``(processor, model)`` tuple.
    """
    loaded_processor = LayoutLMv3Processor.from_pretrained(model_dir)
    loaded_model = LayoutLMv3ForSequenceClassification.from_pretrained(model_dir)
    return loaded_processor, loaded_model


processor, model = _load_model_and_processor()
# Upload a document, classify each page/image, and collect corrections.
# Streamlit re-runs this script top to bottom on every widget interaction, so
# side effects (starting a W&B run, logging) are gated on an explicit button
# click rather than executed on every rerun.
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
)
if uploaded_file is not None:
    if uploaded_file.type == "application/pdf":
        # One PIL image per PDF page.
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        # Force 3-channel RGB so RGBA/palette/grayscale uploads don't reach
        # the processor with an unexpected channel count.
        images = [Image.open(uploaded_file).convert("RGB")]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)

        print(f'Encoding image with index {i}')
        encoding = processor(
            image,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

        print(f'Predicting image with index {i}')
        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            outputs = model(**encoding)
        prediction = outputs.logits.argmax(-1)[0].item()
        st.write(f"Prediction: {id2label[prediction]}")

        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )
        if feedback != "No":
            continue

        correct_label = st.selectbox(
            "Please select the correct label:", labels,
            key=f'selectbox-{i}'
        )
        # st.button is True only on the rerun triggered by its click, so each
        # correction is logged once, after the user has actually picked a
        # label -- not immediately with the selectbox's default first option
        # and again on every subsequent rerun.
        if not st.button("Submit correction", key=f'submit-{i}'):
            continue

        print(f'Correct label for image {i}: {correct_label}')
        if not wandb_api_key:
            # Without a key wandb.init() would fail; keep the app usable.
            st.warning("No W&B API key configured; the correction was not recorded.")
            continue

        # The context manager finishes the run even if logging raises.
        with wandb.init(project='hydra-classifier', name='feedback-loop') as run:
            run.log({
                # The UploadedFile object itself isn't serialisable; log its name.
                'filepath': uploaded_file.name,
                'filetype': uploaded_file.type,
                'predicted_label': id2label[prediction],
                'predicted_label_id': prediction,
                'correct_label': correct_label,
                'correct_label_id': label2id[correct_label]
            })
        st.success("Correction recorded. Thank you!")
|