import streamlit as st
import re
import PyPDF2
import matplotlib.pyplot as plt
import io
from wordcloud import WordCloud
from PIL import Image

from rouge import Rouge
from datasets import load_dataset
from extractive_summarization import summarize_with_textrank, summarize_with_lsa
from abstractive_summarization import summarize_with_bart_cnn, summarize_with_bart_ft, summarize_with_led, summarize_with_t5
#from keyword_extraction import extract_keywords
from keyphrase_extraction import extract_sentences_with_obligations
from hybrid_summarization import summarize_hybrid

#-------------------------------------------------------------------#
# Set page to wide mode. Placed before the (cached) dataset load below because
# Streamlit (at least in older releases) requires this to be the first
# Streamlit command, and a cached call may render a spinner.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_tos_dataset():
    """Load the EE21/ToS-Summaries dataset once and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction;
    caching keeps each rerun from loading the dataset again.
    """
    return load_dataset("EE21/ToS-Summaries")


# Load in ToS-Summaries dataset
dataset = _load_tos_dataset()

# Display labels for the ToS documents, one per row of the train split
tos_titles = [f"Document {i}" for i in range(len(dataset['train']))]

# Function to handle file upload and return its content
def load_pdf(file):
    """Extract the text of every page of a PDF.

    Args:
        file: A path or binary file-like object accepted by ``PyPDF2.PdfReader``
            (in this app, the object returned by ``st.file_uploader``).

    Returns:
        The text of all pages, in order, separated by newlines. A page whose
        ``extract_text()`` returns a falsy value contributes an empty string.
    """
    reader = PyPDF2.PdfReader(file)
    # Separate pages with a newline so the last word of one page does not fuse
    # with the first word of the next; join() also avoids quadratic `+=`.
    return "\n".join(page.extract_text() or "" for page in reader.pages)

# Summarizer dispatch table: radio-button label -> function that takes the
# document text and returns its summary. The radio button is built from these
# keys (in insertion order), so the offered labels and the dispatch can't drift.
_SUMMARIZERS = {
    "Hybrid (RAKE + BART Fine-tuned)": summarize_hybrid,
    "Abstractive (LongT5)": summarize_with_t5,
    "Abstractive (LED)": summarize_with_led,
    "Abstractive (BART Fine-tuned)": summarize_with_bart_ft,
    "Abstractive (BART-large-CNN)": summarize_with_bart_cnn,
    "Extractive (TextRank)": summarize_with_textrank,
    "Extractive (Latent Semantic Analysis)": summarize_with_lsa,
    "Keyphrase Extraction (RAKE)": extract_sentences_with_obligations,
}

# Inline CSS shared by the three ROUGE score badges.
_SCORE_STYLE = (
    "text-align: center; color: black; border: 1px solid #cccccc; "
    "padding: 5px; border-radius: 4px;"
)

# (label shown in the UI, key in the score dict returned by Rouge.get_scores)
_ROUGE_METRICS = (("ROUGE-1", "rouge-1"), ("ROUGE-2", "rouge-2"), ("ROUGE-L", "rouge-l"))


def _resolve_input(uploaded_file, user_input, tos_selection_index):
    """Choose the text to summarize from exactly one input source.

    Priority: uploaded PDF, then typed text, then the selected dataset document.

    Returns:
        ``(text, doc_index)`` where ``doc_index`` is the dataset row index when
        the text came from the ToS dataset and ``None`` otherwise; or
        ``(None, None)`` after a warning has been shown to the user.
    """
    # Only PDF + typed text is a conflict: the selectbox always has a value
    # (index 0 by default), so it must not take part in this check.
    if uploaded_file and user_input:
        st.warning("Please provide either text input or a PDF file, not both.")
        return None, None
    if uploaded_file:
        # Extract text from PDF
        text = load_pdf(uploaded_file)
        st.write("PDF uploaded successfully.")
        return text, None
    if user_input:
        return user_input, None
    if tos_selection_index is not None:
        return dataset['train'][tos_selection_index]['plain_text'] or "", tos_selection_index
    st.warning("Please upload a PDF, enter some text, or select a document to summarize.")
    return None, None


def _render_rouge(summary, doc_index):
    """Show ROUGE F-scores of ``summary`` against dataset row ``doc_index``'s reference.

    Does nothing when the summary did not come from a dataset document, or when
    that document has no non-empty reference summary.
    """
    if doc_index is None:
        return
    reference_summary = dataset['train'][doc_index].get('summary')
    # Rouge.get_scores raises ValueError on empty input, so skip empty references.
    if not isinstance(reference_summary, str) or not reference_summary.strip():
        return

    scores = Rouge().get_scores(summary, reference_summary)[0]

    # Display ROUGE scores as styled text, one badge per column
    for column, (label, key) in zip(st.columns(3), _ROUGE_METRICS):
        with column:
            st.markdown(
                f"<p style='{_SCORE_STYLE}'>{label}: {scores[key]['f']:.4f}</p>",
                unsafe_allow_html=True,
            )


def _render_summary(summary, doc_index):
    """Write the summary, its word cloud, and (for dataset documents) ROUGE scores."""
    st.write(summary)

    # WordCloud.generate and Rouge.get_scores both need non-empty text.
    if not isinstance(summary, str) or not summary.strip():
        return

    # Generate and display word cloud
    try:
        wordcloud = WordCloud(width=800, height=400, background_color='white', max_words=20).generate(summary)
    except ValueError:
        # WordCloud raises ValueError when it finds no words to plot.
        wordcloud = None
    if wordcloud is not None:
        # Convert to a PNG in memory for st.image
        buf = io.BytesIO()
        wordcloud.to_image().save(buf, format='PNG')
        st.image(buf.getvalue(), caption='Word Cloud of Summary', use_column_width=True)

    _render_rouge(summary, doc_index)


# Main app
def main():
    """Render the QuickToS page: summarizer choice, input sources, and output."""
    st.title("QuickToS - Terms of Service Summarizer")

    # Layout: 3 columns
    col1, col2, col3 = st.columns([1, 3, 2], gap="large")

    # Left column: Radio buttons for summarizer choice
    with col1:
        radio_selection = st.radio("Choose type of summarizer", list(_SUMMARIZERS))

    # Middle column: Text input, file uploader and dataset document picker
    with col2:
        user_input = st.text_area("Enter a text")
        uploaded_file = st.file_uploader("Upload a PDF", type="pdf")

        # Dropdown for selecting the document
        tos_selection_index = st.selectbox(
            "Select a Terms of Service Document (only for testing purposes)",
            range(len(tos_titles)),
            format_func=lambda x: tos_titles[x],
        )

        if st.button("Summarize"):
            file_content, doc_index = _resolve_input(uploaded_file, user_input, tos_selection_index)
            if file_content is None:
                return
            # e.g. whitespace-only text or a PDF with no extractable text
            if not file_content.strip():
                st.warning("No text found to summarize. If you uploaded a PDF, it may not contain extractable text.")
                return

            st.session_state.summary = _SUMMARIZERS[radio_selection](file_content)
            # Remember which dataset document (if any) produced this summary so
            # ROUGE is computed against that document's reference, not whatever
            # the selectbox shows on a later rerun.
            st.session_state.summary_doc_index = doc_index

    # Right column: Displaying text after pressing 'Summarize'
    with col3:
        st.write("Summary")
        if 'summary' in st.session_state:
            _render_summary(st.session_state.summary, st.session_state.get("summary_doc_index"))

# Script entry point: render the app when this file is executed as the main
# module (e.g. via `streamlit run`), but not when it is merely imported.
if __name__ == "__main__":
    main()