# Streamlit app for comparing document-understanding models (Donut,
# LayoutLMv3, BROS, LLaVA-1.5) on a single uploaded document image.
#
# NOTE: this header is a `#` comment rather than a module docstring so that
# `st.set_page_config` clearly remains the first Streamlit call in the script.

import streamlit as st
from PIL import Image
import torch
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel,
    LayoutLMv3Processor,
    LayoutLMv3ForSequenceClassification,
    BrosProcessor,
    BrosForTokenClassification,
    LlavaProcessor,
    LlavaForConditionalGeneration,
)

# Maps the UI model name to (processor class, model class, checkpoint id).
_MODEL_REGISTRY = {
    "Donut": (
        DonutProcessor,
        VisionEncoderDecoderModel,
        "naver-clova-ix/donut-base",
    ),
    "LayoutLMv3": (
        LayoutLMv3Processor,
        LayoutLMv3ForSequenceClassification,
        "microsoft/layoutlmv3-base",
    ),
    # NOTE(review): the Transformers BROS documentation loads
    # "jinho8345/bros-base-uncased"; confirm that "microsoft/bros-base"
    # actually exists on the Hub before relying on this entry.
    "BROS": (
        BrosProcessor,
        BrosForTokenClassification,
        "microsoft/bros-base",
    ),
    "LLaVA-1.5": (
        LlavaProcessor,
        LlavaForConditionalGeneration,
        "llava-hf/llava-1.5-7b-hf",
    ),
}

# Models for which `analyze_document` has an inference path.
_ANALYZABLE_MODELS = ("Donut", "LayoutLMv3")


def load_model(model_name):
    """Load the selected model and processor.

    Args:
        model_name: One of "Donut", "LayoutLMv3", "BROS" or "LLaVA-1.5".

    Returns:
        A ``(model, processor)`` tuple loaded with ``from_pretrained``.

    Raises:
        ValueError: If ``model_name`` is not a known model name. (Previously
            this fell through to an ``UnboundLocalError`` on ``return``.)
    """
    try:
        processor_cls, model_cls, checkpoint = _MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name!r}") from None

    processor = processor_cls.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint)
    return model, processor


def analyze_document(image, model_name, model, processor):
    """Analyze a document image using the selected model.

    Args:
        image: The document image to feed to ``processor``.
        model_name: Name of the selected model; only "Donut" and
            "LayoutLMv3" have an inference path here.
        model: The model returned by ``load_model``.
        processor: The processor returned by ``load_model``.

    Returns:
        For "Donut", the decoded generated text (``str``); for "LayoutLMv3",
        the ``logits`` of the model output. ``None`` if the model is not
        supported or if inference fails; in both cases the problem is
        reported to the page with ``st.error``.
    """
    # Report unsupported models explicitly instead of failing later with an
    # unrelated "result referenced before assignment" error.
    if model_name not in _ANALYZABLE_MODELS:
        st.error(
            f"Error analyzing document: analysis is not implemented "
            f"for {model_name}"
        )
        return None

    try:
        # Process image according to model requirements
        inputs = processor(image, return_tensors="pt")
        if model_name == "Donut":
            outputs = model.generate(**inputs)
            result = processor.decode(outputs[0], skip_special_tokens=True)
        else:  # "LayoutLMv3"
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                outputs = model(**inputs)
            result = outputs.logits
        return result
    except Exception as e:
        # UI boundary: surface any inference failure on the page rather than
        # crashing the Streamlit script.
        st.error(f"Error analyzing document: {e}")
        return None


# Set page config
st.set_page_config(page_title="Document Analysis Comparison", layout="wide")

# Title and description
st.title("Document Understanding Model Comparison")
st.markdown("""
Compare different models for document analysis and understanding.
Upload an image and select a model to analyze it.
""")

# Create two columns for layout
col1, col2 = st.columns([1, 1])

# Stays None until an upload has been decoded successfully.
image = None

with col1:
    # File uploader. PDF is not offered: PIL's Image.open cannot read PDFs.
    uploaded_file = st.file_uploader(
        "Choose a document image",
        type=['png', 'jpg', 'jpeg'],
    )

    if uploaded_file is not None:
        try:
            # Normalise to RGB so RGBA / palette / grayscale uploads are
            # handed to the processors in a single, consistent mode.
            image = Image.open(uploaded_file).convert("RGB")
        except OSError as e:
            # PIL's UnidentifiedImageError is a subclass of OSError.
            st.error(f"Could not read the uploaded image: {e}")
        else:
            # Display uploaded image
            st.image(image, caption='Uploaded Document', use_column_width=True)

with col2:
    # Model selection
    model_info = {
        "Donut": {
            "description": "Best for structured OCR and document format understanding",
            "memory": "6-8GB",
            "strengths": ["Structured OCR", "Memory efficient", "Good with fixed formats"],
        },
        "LayoutLMv3": {
            "description": "Strong layout understanding with reasoning capabilities",
            "memory": "12-15GB",
            "strengths": ["Layout understanding", "Reasoning", "Pre-trained knowledge"],
        },
        "BROS": {
            "description": "Memory efficient with fast inference",
            "memory": "4-6GB",
            "strengths": ["Fast inference", "Memory efficient", "Easy fine-tuning"],
        },
        "LLaVA-1.5": {
            "description": "Comprehensive OCR with strong reasoning",
            "memory": "25-40GB",
            "strengths": ["Strong reasoning", "Zero-shot capable", "Visual understanding"],
        },
    }

    selected_model = st.selectbox(
        "Select Model",
        list(model_info),
    )

    # Display model information
    selected_info = model_info[selected_model]
    st.write("### Model Details")
    st.write(f"**Description:** {selected_info['description']}")
    st.write(f"**Memory Required:** {selected_info['memory']}")
    st.write("**Strengths:**")
    for strength in selected_info['strengths']:
        st.write(f"- {strength}")

# Analysis section
if image is not None and selected_model:
    if st.button("Analyze Document"):
        with st.spinner('Loading model and analyzing document...'):
            try:
                # Load model and processor
                model, processor = load_model(selected_model)

                # Analyze document
                results = analyze_document(image, selected_model, model, processor)

                # Display results only on success; analyze_document has
                # already reported the error when it returns None.
                if results is not None:
                    st.write("### Analysis Results")
                    if isinstance(results, str):
                        # st.json treats a str as already-serialised JSON, so
                        # wrap plain decoded text in a dict.
                        st.json({"text": results})
                    elif isinstance(results, torch.Tensor):
                        # Tensors are not JSON serialisable; show plain lists.
                        st.json({"logits": results.tolist()})
                    else:
                        st.json(results)

            except Exception as e:
                # UI boundary: show model download / load failures on the page.
                st.error(f"Error during analysis: {e}")

# Add information about usage and limitations
st.markdown("""
---
### Notes:
- Different models may perform better for different types of documents
- Processing time and memory requirements vary by model
- Results may vary based on document quality and format
""")

# Add a footer with version information
st.markdown("---")
st.markdown("v1.0 - Created with Streamlit")