import streamlit as st
from PIL import Image
import torch
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel,
    LayoutLMv3Processor,
    LayoutLMv3ForSequenceClassification,
    BrosProcessor,
    BrosForTokenClassification,
    LlavaProcessor,
    LlavaForConditionalGeneration
)
def load_model(model_name):
    """Load the selected model and processor."""
    if model_name == "Donut":
        processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
        model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
    elif model_name == "LayoutLMv3":
        processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
        model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
    elif model_name == "BROS":
        # BROS checkpoints are published by NAVER CLOVA OCR on the Hugging Face Hub
        processor = BrosProcessor.from_pretrained("naver-clova-ocr/bros-base-uncased")
        model = BrosForTokenClassification.from_pretrained("naver-clova-ocr/bros-base-uncased")
    elif model_name == "LLaVA-1.5":
        processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
        model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    else:
        # Guard against an unknown selection so model/processor are never unbound
        raise ValueError(f"Unsupported model: {model_name}")
    return model, processor
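
# The larger checkpoints (especially LLaVA-1.5-7B) will not fit comfortably in CPU
# memory, so a small device-placement helper can be useful. This is only a sketch
# under the assumption that a CUDA GPU may or may not be available; the helper name
# `place_model` is illustrative and not part of the original app.
def place_model(model):
    """Move a loaded model to the GPU in half precision when one is available."""
    if torch.cuda.is_available():
        # Half precision roughly halves the memory footprint of the weights
        return model.half().to("cuda")
    return model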
def analyze_document(image, model_name, model, processor):
    """Analyze the document image with the selected model."""
    try:
        # Process the image according to the model's requirements
        if model_name == "Donut":
            inputs = processor(image, return_tensors="pt")
            outputs = model.generate(**inputs)
            result = processor.decode(outputs[0], skip_special_tokens=True)
        elif model_name == "LayoutLMv3":
            # The LayoutLMv3 processor runs OCR internally (requires pytesseract)
            inputs = processor(image, return_tensors="pt")
            outputs = model(**inputs)
            # Convert the logits tensor to a plain list so it can be shown with st.json
            result = outputs.logits.tolist()
        else:
            # TODO: add similar processing for BROS and LLaVA-1.5
            result = f"Processing for {model_name} is not implemented yet."
        return result
    except Exception as e:
        st.error(f"Error analyzing document: {str(e)}")
        return None
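
# Sketch of how the LLaVA-1.5 branch above could be filled in: LLaVA needs a text
# prompt alongside the image, and its output is decoded text. The prompt template
# follows the llava-1.5-7b-hf model card; the helper name `analyze_with_llava` and
# the default question string are illustrative, not part of the original app.
def analyze_with_llava(image, model, processor, question="Describe this document."):
    """Run prompted generation with LLaVA-1.5 and return the decoded answer."""
    prompt = f"USER: <image>\n{question} ASSISTANT:"
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=256)
    # generate() echoes the prompt tokens first, so keep only the newly generated part
    answer_ids = outputs[0][inputs["input_ids"].shape[1]:]
    return processor.decode(answer_ids, skip_special_tokens=True)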
# Set page config
st.set_page_config(page_title="Document Analysis Comparison", layout="wide")
# Title and description
st.title("Document Understanding Model Comparison")
st.markdown("""
Compare different models for document analysis and understanding.
Upload an image and select a model to analyze it.
""")
# Create two columns for layout
col1, col2 = st.columns([1, 1])
with col1:
    # File uploader (PIL cannot open PDFs, so only image formats are accepted)
    uploaded_file = st.file_uploader("Choose a document image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is not None:
        # Display the uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Document', use_column_width=True)
with col2:
    # Model selection
    model_info = {
        "Donut": {
            "description": "Best for structured OCR and document format understanding",
            "memory": "6-8GB",
            "strengths": ["Structured OCR", "Memory efficient", "Good with fixed formats"]
        },
        "LayoutLMv3": {
            "description": "Strong layout understanding with reasoning capabilities",
            "memory": "12-15GB",
            "strengths": ["Layout understanding", "Reasoning", "Pre-trained knowledge"]
        },
        "BROS": {
            "description": "Memory efficient with fast inference",
            "memory": "4-6GB",
            "strengths": ["Fast inference", "Memory efficient", "Easy fine-tuning"]
        },
        "LLaVA-1.5": {
            "description": "Comprehensive OCR with strong reasoning",
            "memory": "25-40GB",
            "strengths": ["Strong reasoning", "Zero-shot capable", "Visual understanding"]
        }
    }

    selected_model = st.selectbox(
        "Select Model",
        list(model_info.keys())
    )

    # Display model information
    st.write("### Model Details")
    st.write(f"**Description:** {model_info[selected_model]['description']}")
    st.write(f"**Memory Required:** {model_info[selected_model]['memory']}")
    st.write("**Strengths:**")
    for strength in model_info[selected_model]['strengths']:
        st.write(f"- {strength}")
# Analysis section
if uploaded_file is not None and selected_model:
    if st.button("Analyze Document"):
        with st.spinner('Loading model and analyzing document...'):
            try:
                # Load model and processor
                model, processor = load_model(selected_model)
                # Analyze document
                results = analyze_document(image, selected_model, model, processor)
                # Display results (analyze_document returns None on failure)
                if results is not None:
                    st.write("### Analysis Results")
                    st.json(results)
            except Exception as e:
                st.error(f"Error during analysis: {str(e)}")
# Add information about usage and limitations
st.markdown("""
---
### Notes:
- Different models may perform better for different types of documents
- Processing time and memory requirements vary by model
- Results may vary based on document quality and format
""")
# Add a footer with version information
st.markdown("---")
st.markdown("v1.0 - Created with Streamlit") |