# Page header captured together with the file (not part of the program):
# Shak33l-UiRev - "Update app.py" - commit db7e40d (verified) - 5.15 kB
import streamlit as st
from PIL import Image
import torch
from transformers import (
DonutProcessor,
VisionEncoderDecoderModel,
LayoutLMv3Processor,
LayoutLMv3ForSequenceClassification,
BrosProcessor,
BrosForTokenClassification,
LlavaProcessor,
LlavaForConditionalGeneration
)
def load_model(model_name):
    """Load the pretrained model and processor for ``model_name``.

    Every call goes through ``from_pretrained``; nothing is memoized here.

    Args:
        model_name: One of "Donut", "LayoutLMv3", "BROS" or "LLaVA-1.5".

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    checkpoints = {
        "Donut": "naver-clova-ix/donut-base",
        "LayoutLMv3": "microsoft/layoutlmv3-base",
        # NOTE(review): confirm this repository id exists on the Hub before
        # relying on the BROS option.
        "BROS": "microsoft/bros-base",
        "LLaVA-1.5": "llava-hf/llava-1.5-7b-hf",
    }

    # Validate first: an unknown name used to skip every branch and then fail
    # on the final ``return`` with an UnboundLocalError.
    if model_name not in checkpoints:
        supported = ", ".join(checkpoints)
        raise ValueError(
            f"Unsupported model {model_name!r}; expected one of: {supported}"
        )

    checkpoint = checkpoints[model_name]
    if model_name == "Donut":
        processor_cls, model_cls = DonutProcessor, VisionEncoderDecoderModel
    elif model_name == "LayoutLMv3":
        processor_cls = LayoutLMv3Processor
        model_cls = LayoutLMv3ForSequenceClassification
    elif model_name == "BROS":
        processor_cls, model_cls = BrosProcessor, BrosForTokenClassification
    else:  # "LLaVA-1.5" is the only remaining validated name.
        processor_cls, model_cls = LlavaProcessor, LlavaForConditionalGeneration

    # Same order as before: processor first, then the model weights.
    processor = processor_cls.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint)
    return model, processor
def analyze_document(image, model_name, model, processor):
    """Run ``model`` on ``image`` and return a displayable result.

    Args:
        image: The document image. An object exposing a PIL-style ``mode``
            other than "RGB" is converted with ``image.convert("RGB")``
            before it reaches the processor.
        model_name: "Donut" or "LayoutLMv3". Any other name is reported as
            not implemented.
        model: The model matching ``model_name``.
        processor: The processor matching ``model_name``.

    Returns:
        For "Donut", the decoded output string. For "LayoutLMv3", the logits
        as nested Python lists so they can be rendered as JSON. ``None`` when
        anything fails; the error is shown with ``st.error`` first.
    """
    try:
        # Uploads such as RGBA or palette PNGs are normalised to three
        # channels. Objects without a ``mode`` attribute pass through as-is.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            if model_name == "Donut":
                inputs = processor(image, return_tensors="pt")
                outputs = model.generate(**inputs)
                result = processor.decode(outputs[0], skip_special_tokens=True)
            elif model_name == "LayoutLMv3":
                inputs = processor(image, return_tensors="pt")
                outputs = model(**inputs)
                # Plain lists are JSON-friendly; a raw tensor is not.
                result = outputs.logits.tolist()
            else:
                # Previously these models fell through and tripped over an
                # unbound ``result``; say what is actually wrong instead.
                raise NotImplementedError(
                    f"Analysis is not implemented for model {model_name!r} yet"
                )
        return result
    except Exception as e:
        st.error(f"Error analyzing document: {str(e)}")
        return None
# Page-level configuration (wide layout), then the heading and intro copy.
st.set_page_config(layout="wide", page_title="Document Analysis Comparison")

st.title("Document Understanding Model Comparison")

# Intro text shown under the title; the empty first and last entries keep the
# leading and trailing newlines of the rendered markdown.
intro_lines = [
    "",
    "Compare different models for document analysis and understanding.",
    "Upload an image and select a model to analyze it.",
    "",
]
st.markdown("\n".join(intro_lines))
# Two equal-width columns: document upload on the left, model selection on
# the right.
col1, col2 = st.columns([1, 1])

with col1:
    # Only raster formats are offered: the upload is opened with PIL's
    # Image.open below, which cannot read PDF files.
    uploaded_file = st.file_uploader(
        "Choose a document image",
        type=['png', 'jpg', 'jpeg'],
    )

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)
        except OSError as exc:
            # PIL reports unreadable or corrupt images with OSError
            # subclasses; show the problem instead of crashing the script.
            st.error(f"Could not read the uploaded image: {exc}")
            # Treat this as "nothing uploaded" so the later
            # `uploaded_file is not None` check skips the analysis step.
            uploaded_file = None
        else:
            # Display uploaded image
            st.image(image, caption='Uploaded Document', use_column_width=True)
with col2:
    # Catalogue of the selectable models and the blurb shown for each one.
    model_info = {
        "Donut": {
            "description": "Best for structured OCR and document format understanding",
            "memory": "6-8GB",
            "strengths": ["Structured OCR", "Memory efficient", "Good with fixed formats"],
        },
        "LayoutLMv3": {
            "description": "Strong layout understanding with reasoning capabilities",
            "memory": "12-15GB",
            "strengths": ["Layout understanding", "Reasoning", "Pre-trained knowledge"],
        },
        "BROS": {
            "description": "Memory efficient with fast inference",
            "memory": "4-6GB",
            "strengths": ["Fast inference", "Memory efficient", "Easy fine-tuning"],
        },
        "LLaVA-1.5": {
            "description": "Comprehensive OCR with strong reasoning",
            "memory": "25-40GB",
            "strengths": ["Strong reasoning", "Zero-shot capable", "Visual understanding"],
        },
    }

    # Drop-down listing the catalogue keys in declaration order.
    selected_model = st.selectbox("Select Model", list(model_info))

    # Summary of whichever model is currently selected.
    details = model_info[selected_model]
    st.write("### Model Details")
    st.write(f"**Description:** {details['description']}")
    st.write(f"**Memory Required:** {details['memory']}")
    st.write("**Strengths:**")
    for bullet in details['strengths']:
        st.write(f"- {bullet}")
# Analysis section: offered only once a document has been uploaded.
if uploaded_file is not None and selected_model:
    if st.button("Analyze Document"):
        with st.spinner('Loading model and analyzing document...'):
            try:
                # Load model and processor
                model, processor = load_model(selected_model)

                # Analyze document
                results = analyze_document(image, selected_model, model, processor)

                # analyze_document returns None after reporting its own
                # error, so there is nothing further to render in that case.
                if results is not None:
                    st.write("### Analysis Results")
                    if isinstance(results, str):
                        # st.json treats a string as serialized JSON; the
                        # Donut branch returns decoded plain text instead.
                        st.text(results)
                    else:
                        st.json(results)
            except Exception as e:
                # Top-level UI boundary: surface any failure (for example a
                # model download error) instead of a raw traceback.
                st.error(f"Error during analysis: {str(e)}")
# Usage notes and limitations, rendered below a horizontal rule. The empty
# first and last entries keep the markdown's leading and trailing newlines.
notes_markdown = "\n".join([
    "",
    "---",
    "### Notes:",
    "- Different models may perform better for different types of documents",
    "- Processing time and memory requirements vary by model",
    "- Results may vary based on document quality and format",
    "",
])
st.markdown(notes_markdown)

# Footer: another rule, then the version line, each as its own markdown call.
for footer_line in ("---", "v1.0 - Created with Streamlit"):
    st.markdown(footer_line)