File size: 5,154 Bytes
4045262
db7e40d
 
 
 
 
 
 
 
 
 
 
 
4045262
db7e40d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import streamlit as st
from PIL import Image
import torch
from transformers import (
    DonutProcessor, 
    VisionEncoderDecoderModel,
    LayoutLMv3Processor, 
    LayoutLMv3ForSequenceClassification,
    BrosProcessor,
    BrosForTokenClassification,
    LlavaProcessor,
    LlavaForConditionalGeneration
)

def load_model(model_name):
    """Load the selected model and its paired processor.

    Parameters
    ----------
    model_name : str
        One of "Donut", "LayoutLMv3", "BROS", "LLaVA-1.5".

    Returns
    -------
    tuple
        ``(model, processor)`` for the chosen architecture.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported models.
    """
    if model_name == "Donut":
        processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
        model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
    elif model_name == "LayoutLMv3":
        processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
        model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
    elif model_name == "BROS":
        # NOTE(review): BROS checkpoints live under the naver-clova-ix
        # namespace on the Hugging Face Hub; "microsoft/bros-base" (the
        # original id) does not exist there — confirm the exact id.
        processor = BrosProcessor.from_pretrained("naver-clova-ix/bros-base-uncased")
        model = BrosForTokenClassification.from_pretrained("naver-clova-ix/bros-base-uncased")
    elif model_name == "LLaVA-1.5":
        processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
        model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    else:
        # The original code fell through here and crashed with
        # UnboundLocalError at the return; fail with a clear error instead.
        raise ValueError(f"Unsupported model: {model_name!r}")

    return model, processor

def analyze_document(image, model_name, model, processor):
    """Analyze a document image with the selected model.

    Parameters
    ----------
    image : PIL.Image.Image
        The uploaded document image.
    model_name : str
        Name of the selected model (keys of the UI's ``model_info``).
    model, processor
        Objects returned by :func:`load_model`.

    Returns
    -------
    The decoded string (Donut), raw logits (LayoutLMv3), or ``None`` on
    error or when the model's pipeline is not implemented yet.
    """
    try:
        # Process image according to model requirements.
        if model_name == "Donut":
            inputs = processor(image, return_tensors="pt")
            outputs = model.generate(**inputs)
            result = processor.decode(outputs[0], skip_special_tokens=True)
        elif model_name == "LayoutLMv3":
            inputs = processor(image, return_tensors="pt")
            outputs = model(**inputs)
            result = outputs.logits
        else:
            # The original code left `result` unbound for BROS / LLaVA-1.5,
            # raising an UnboundLocalError that the broad except then masked.
            # Surface the missing pipeline explicitly instead.
            st.warning(f"Analysis for {model_name} is not implemented yet.")
            result = None

        return result
    except Exception as e:
        st.error(f"Error analyzing document: {str(e)}")
        return None

# Configure the Streamlit page; must run before any other UI call.
st.set_page_config(page_title="Document Analysis Comparison", layout="wide")

# Page header: title plus a short usage blurb.
st.title("Document Understanding Model Comparison")
st.markdown(
    """
Compare different models for document analysis and understanding.
Upload an image and select a model to analyze it.
"""
)

# Side-by-side layout: document preview (left) and model picker (right).
col1, col2 = st.columns([1, 1])

with col1:
    # File uploader. Only raster image formats are accepted: the original
    # list also offered 'pdf', but PIL.Image.open cannot read PDFs and
    # crashed with UnidentifiedImageError on such uploads.
    uploaded_file = st.file_uploader("Choose a document image", type=['png', 'jpg', 'jpeg'])

    if uploaded_file is not None:
        # Preview the uploaded document in the left column.
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Document', use_column_width=True)

with col2:
    # Static metadata for each candidate model: what it is good at and a
    # rough GPU-memory budget, shown beside the selector.
    model_info = {
        "Donut": {
            "description": "Best for structured OCR and document format understanding",
            "memory": "6-8GB",
            "strengths": ["Structured OCR", "Memory efficient", "Good with fixed formats"],
        },
        "LayoutLMv3": {
            "description": "Strong layout understanding with reasoning capabilities",
            "memory": "12-15GB",
            "strengths": ["Layout understanding", "Reasoning", "Pre-trained knowledge"],
        },
        "BROS": {
            "description": "Memory efficient with fast inference",
            "memory": "4-6GB",
            "strengths": ["Fast inference", "Memory efficient", "Easy fine-tuning"],
        },
        "LLaVA-1.5": {
            "description": "Comprehensive OCR with strong reasoning",
            "memory": "25-40GB",
            "strengths": ["Strong reasoning", "Zero-shot capable", "Visual understanding"],
        },
    }

    selected_model = st.selectbox("Select Model", list(model_info.keys()))

    # Show the selected model's details; look the entry up once.
    info = model_info[selected_model]
    st.write("### Model Details")
    st.write(f"**Description:** {info['description']}")
    st.write(f"**Memory Required:** {info['memory']}")
    st.write("**Strengths:**")
    for strength in info['strengths']:
        st.write(f"- {strength}")

# Analysis section: only reachable once a file is uploaded and a model chosen.
if uploaded_file is not None and selected_model:
    if st.button("Analyze Document"):
        with st.spinner('Loading model and analyzing document...'):
            try:
                # Load model and processor
                model, processor = load_model(selected_model)

                # Analyze document
                results = analyze_document(image, selected_model, model, processor)

                # Display results by type. The original passed everything to
                # st.json, which fails on None and on tensor results
                # (LayoutLMv3 returns logits, which are not JSON-serializable).
                st.write("### Analysis Results")
                if results is None:
                    st.warning("No results were produced for this document.")
                elif isinstance(results, (dict, list)):
                    st.json(results)
                else:
                    st.write(results)

            except Exception as e:
                st.error(f"Error during analysis: {str(e)}")

# Closing notes: caveats about model behavior, then a version footer.
st.markdown(
    """
---
### Notes:
- Different models may perform better for different types of documents
- Processing time and memory requirements vary by model
- Results may vary based on document quality and format
"""
)

st.markdown("---")
st.markdown("v1.0 - Created with Streamlit")