"""Streamlit app: classify uploaded documents with LayoutLMv3.

The user uploads a PDF or an image, every page is classified into one of
``labels``, and the user can mark a prediction as wrong and supply the right
label.  Each correction is logged to the Weights & Biases project
``hydra-classifier`` so it can feed a later retraining loop.

Streamlit re-executes this whole script on every widget interaction, so
anything expensive (model loading) is cached and anything with external side
effects (W&B logging) is guarded against being repeated.
"""
import os

import streamlit as st
import torch
import wandb
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor

# Local directory holding the fine-tuned processor and model weights.
MODEL_DIR = "model/layoutlmv3/"

wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    # The app keeps running so classification still works; only the feedback
    # logging depends on W&B credentials.
    st.error(
        "Couldn't find WanDB API key. Please set it up as an environment variable",
        icon="🚨",
    )
else:
    wandb.login(key=wandb_api_key)

labels = [
    'budget',
    'email',
    'form',
    'handwritten',
    'invoice',
    'language',
    'letter',
    'memo',
    'news article',
    'questionnaire',
    'resume',
    'scientific publication',
    'specification',
]
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in id2label.items()}


@st.cache_resource
def _load_layoutlmv3():
    """Load the processor and classifier once per server process.

    Without caching, both would be re-read from disk on every Streamlit
    rerun, i.e. on every click in the UI.
    """
    loaded_processor = LayoutLMv3Processor.from_pretrained(MODEL_DIR)
    loaded_model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_DIR)
    return loaded_processor, loaded_model


processor, model = _load_layoutlmv3()

st.title("Document Classification with LayoutLMv3")

uploaded_file = st.file_uploader(
    "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
)

if uploaded_file:
    if uploaded_file.type == "application/pdf":
        images = convert_from_bytes(uploaded_file.getvalue())
    else:
        # Normalise to RGB so PNGs with alpha/palette or greyscale JPEGs are
        # handed to the processor in the same mode as rendered PDF pages.
        images = [Image.open(uploaded_file).convert("RGB")]

    # Feedback already sent to W&B during this browser session.  Needed
    # because the script reruns on every interaction and would otherwise
    # log the same correction again each time.
    if "logged_feedback" not in st.session_state:
        st.session_state["logged_feedback"] = set()
    already_logged = st.session_state["logged_feedback"]

    # (dedupe_key, record) pairs collected during this rerun.
    pending_feedback = []

    for i, image in enumerate(images):
        st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)

        print(f'Encoding image with index {i}')
        encoding = processor(
            image,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

        print(f'Predicting image with index {i}')
        # Inference only: no need to track gradients.
        with torch.no_grad():
            outputs = model(**encoding)
        prediction = outputs.logits.argmax(-1)[0].item()

        st.write(f"Prediction: {id2label[prediction]}")

        feedback = st.radio(
            "Is the classification correct?",
            ("Yes", "No"),
            key=f'prediction-{i}',
        )
        if feedback != "No":
            continue

        # index=None leaves the box empty (returns None) until the user
        # actually picks a label, so the first entry of `labels` is never
        # logged as a correction by accident.
        correct_label = st.selectbox(
            "Please select the correct label:",
            labels,
            index=None,
            key=f'selectbox-{i}',
        )
        if correct_label is None:
            continue

        print(f'Correct label for image {i}: {correct_label}')

        dedupe_key = (uploaded_file.name, uploaded_file.size, i, correct_label)
        if dedupe_key in already_logged:
            continue

        pending_feedback.append((
            dedupe_key,
            {
                # Log the file *name*; the UploadedFile object itself is not
                # a value W&B can serialise.
                'filepath': uploaded_file.name,
                'filetype': uploaded_file.type,
                'page_index': i,
                'predicted_label': id2label[prediction],
                'predicted_label_id': prediction,
                'correct_label': correct_label,
                'correct_label_id': label2id[correct_label],
            },
        ))

    if pending_feedback:
        # Open a run only when there is something new to record; the context
        # manager finishes the run even if logging raises.
        with wandb.init(project='hydra-classifier', name='feedback-loop') as run:
            for dedupe_key, record in pending_feedback:
                run.log(record)
                already_logged.add(dedupe_key)