Spaces:
Sleeping
Sleeping
File size: 3,372 Bytes
11deda5 056ccc3 5639711 c1d4001 056ccc3 2bc05a3 11deda5 d73399c 9652b01 6c6f2d5 056ccc3 fd98f6f 056ccc3 ff1aca1 6c6f2d5 11deda5 6c6f2d5 fd98f6f 548ee28 fd98f6f 056ccc3 fd98f6f 180d132 63eb0c6 6c6f2d5 056ccc3 61dba08 056ccc3 fd98f6f 61dba08 fd98f6f 11deda5 fd98f6f 392dd2d d280e22 056ccc3 8f5a987 11deda5 ff1aca1 b90d0b6 ff1aca1 11deda5 6c0983b 6c6f2d5 ff1aca1 11deda5 6c6f2d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import os
import wandb
import streamlit as st
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from pdf2image import convert_from_bytes
from PIL import Image
# W&B credentials are read from the environment (e.g. a Space secret).
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    st.error(
        "Couldn't find WanDB API key. Please set it up as an environment variable",
        icon="🚨",
    )
    # The script calls wandb.init unconditionally further down, which cannot
    # succeed without credentials; halt this script run cleanly here instead
    # of surfacing a traceback later.
    st.stop()
wandb.login(key=wandb_api_key)
# Document classes offered to the user. A label's position in this list is
# the integer class id used to decode model predictions.
labels = [
    "budget", "email", "form", "handwritten", "invoice", "language",
    "letter", "memo", "news article", "questionnaire", "resume",
    "scientific publication", "specification",
]
# Lookups in both directions between class id and human-readable label.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
@st.cache_resource
def _load_layoutlmv3(model_dir):
    """Load the LayoutLMv3 sequence classifier and its processor from *model_dir*.

    ``st.cache_resource`` keeps a single loaded copy per server process that
    every browser session shares. Storing it in ``st.session_state`` would
    load the weights again for each new session.

    Returns:
        tuple: ``(model, processor)``.
    """
    loaded_model = LayoutLMv3ForSequenceClassification.from_pretrained(model_dir)
    loaded_processor = LayoutLMv3Processor.from_pretrained(model_dir)
    return loaded_model, loaded_processor


model, processor = _load_layoutlmv3("model/layoutlmv3/")
# Keep the per-session handles populated for any code that reads them there.
st.session_state.model = model
st.session_state.processor = processor
# Page heading plus a single-file upload widget; ``uploaded_file`` stays
# falsy until the user uploads a document.
st.title("Document Classification with LayoutLMv3")
uploaded_file = st.file_uploader(
    "Upload Document",
    type=["pdf", "jpg", "png"],
    accept_multiple_files=False,
)
# A new W&B table is built on every script run. Feedback rows are added to it
# and logged during the same run in which the user clicks "Submit".
feedback_table = wandb.Table(columns=[
    'image', 'filetype', 'predicted_label', 'predicted_label_id',
    'correct_label', 'correct_label_id',
])
# Start one W&B run per browser session. Streamlit's per-session store is
# ``st.session_state``; ``st.session_data`` is not a Streamlit attribute and
# raises AttributeError.
if 'wandb_run' not in st.session_state:
    st.session_state.wandb_run = wandb.init(project='hydra-classifier', name='feedback-loop')
@st.cache_data
def classify_image(image, _index=None):
    """Classify one page image and return the predicted class id.

    Args:
        image: PIL image of a single uploaded page.
        _index: Optional page index, used only in the log messages. The
            leading underscore makes ``st.cache_data`` skip it when hashing,
            so the cache key is still the image alone.

    Returns:
        int: Index into ``id2label`` of the highest-scoring class.
    """
    import torch  # local import: only needed for the no-grad context below

    # Previously this read the module-level loop variable ``i``, which made
    # the function raise NameError whenever it was called outside that loop.
    where = "image" if _index is None else f"image with index {_index}"
    print(f'Encoding {where}')
    encoding = processor(
        image,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    print(f'Predicting {where}')
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**encoding)
    return outputs.logits.argmax(-1)[0].item()
def _get_wandb_run():
    """Return this session's active W&B run, starting a new one if needed.

    Each submission finishes the run and stores ``None`` as the handle. This
    helper starts a fresh run on demand instead of calling ``.log`` on ``None``.
    """
    run = st.session_state.get('wandb_run')
    if run is None:
        run = wandb.init(project='hydra-classifier', name='feedback-loop')
        st.session_state.wandb_run = run
    return run


if uploaded_file:
    if uploaded_file.type == "application/pdf":
        # One PIL image per PDF page.
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        images = [Image.open(uploaded_file)]
    # PNGs may be RGBA/palette and JPEGs may be CMYK. Convert everything to
    # 3-channel RGB before it reaches the processor, as the Hugging Face
    # LayoutLMv3 examples do.
    images = [img.convert("RGB") for img in images]

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
        prediction = classify_image(image)
        predicted_label = id2label[prediction]
        st.write(f"Prediction: {predicted_label}")

        feedback = st.radio(
            "Is the classification correct?", ("Yes", "No"),
            key=f'prediction-{i}'
        )
        if feedback == "No":
            correct_label = st.selectbox(
                "Please select the correct label:", labels,
                key=f'selectbox-{i}'
            )
            print(f'Correct label for image {i}: {correct_label}')
        else:
            # Confirmed prediction. Bind correct_label here too, so the submit
            # path never reads an unbound name or a value left over from a
            # previous page.
            correct_label = predicted_label

        # Add a button to confirm feedback and log it
        if st.button(f"Submit Feedback for Image {i}", key=f'submit-{i}'):
            feedback_table.add_data(
                wandb.Image(image),
                uploaded_file.type,
                predicted_label,
                prediction,
                correct_label,
                label2id[correct_label],
            )
            print(feedback_table)
            run = _get_wandb_run()
            run.log({'feedback_table': feedback_table})
            run.finish()
            # The handle is cleared; _get_wandb_run starts a new run next time.
            st.session_state.wandb_run = None
            # Report success only after the row has actually been logged.
            st.success(f"Feedback for Image {i} submitted!")
|