File size: 9,963 Bytes
816825a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import streamlit as st
import os
import tempfile
import logging
from dotenv import load_dotenv
import uuid

# UI Components moved to src/ui_components.py for easier debugging and maintenance

from src.ui_components import (
    setup_page_config, load_custom_css, render_header, 
    render_getting_started, render_system_info, 
    render_processing_spinner
)
from src.rag_pipeline import RAGPipeline

# Load environment variables from a .env file (python-dotenv) at import time.
load_dotenv()

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def initialize_session_state():
    """Seed ``st.session_state`` with every key the app reads.

    Keys that already exist are left untouched, so calling this on each
    Streamlit rerun is safe.
    """
    # Factories rather than shared literals: each missing key gets a fresh
    # list, and the UUID is only generated when 'session_id' is absent.
    default_factories = (
        ('session_id', lambda: str(uuid.uuid4())),
        ('rag_pipeline', lambda: None),
        ('messages', list),
        ('rag_sources', list),
        ('document_loaded', lambda: False),
        ('document_stats', lambda: None),
    )

    for state_key, make_default in default_factories:
        if state_key in st.session_state:
            continue
        setattr(st.session_state, state_key, make_default())

def process_uploaded_document(uploaded_file):
    """Run one uploaded text file through the RAG pipeline.

    The upload is decoded as UTF-8, written to a temporary ``.txt`` file and
    handed to ``st.session_state.rag_pipeline`` (created on first use). On
    success ``document_loaded`` and ``document_stats`` are updated in session
    state. The temporary file is always removed, including when decoding or
    processing raises.

    Args:
        uploaded_file: Object exposing ``.name`` and ``.getvalue()`` returning
            bytes (as produced by ``st.file_uploader``).

    Returns:
        The value returned by ``RAGPipeline.process_document`` on the normal
        path, or ``False`` if any exception was raised along the way.
    """
    tmp_file_path = None
    try:
        st.info(f"Starting to process: {uploaded_file.name}")

        # Decode first so a non-UTF-8 upload fails before any file is created.
        content = uploaded_file.getvalue().decode('utf-8')

        # delete=False: the pipeline re-opens the file by path after this
        # handle is closed; removal is handled in the ``finally`` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.txt', mode='w', encoding='utf-8') as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(content)

        st.info(f"File saved temporarily at: {tmp_file_path}")
        st.info(f"File content length: {len(content)} characters")

        # Initialize RAG pipeline if not already done
        if st.session_state.rag_pipeline is None:
            st.info("Initializing RAG pipeline...")
            st.session_state.rag_pipeline = RAGPipeline()

        pipeline = st.session_state.rag_pipeline

        st.info("Processing document through RAG pipeline...")
        success = pipeline.process_document(tmp_file_path)

        if success:
            st.info("Document processed successfully, getting statistics...")
            # NOTE(review): the file is run through the document processor a
            # second time here purely to compute statistics.
            chunks = pipeline.document_processor.process_document(tmp_file_path)
            stats = pipeline.document_processor.get_document_stats(chunks)

            st.session_state.document_loaded = True
            st.session_state.document_stats = stats

            st.info(f"Document processed successfully: {stats['total_chunks']} chunks")
        else:
            st.error("Failed to process document")

        return success

    except Exception as e:
        st.error(f"Error processing uploaded document: {e}")
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error processing uploaded document: {e}")
        return False

    finally:
        # Always remove the temp file; previously it leaked whenever an
        # exception was raised before the unlink call was reached.
        if tmp_file_path is not None:
            try:
                os.unlink(tmp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")
    
def handle_user_query(user_question):
    """Answer a question with the RAG pipeline and record the exchange.

    Args:
        user_question: The question text entered by the user.

    Returns:
        A ``(answer, source_docs)`` tuple. When no pipeline or document is
        available, or when the query raises, ``source_docs`` is an empty list
        and ``answer`` is an explanatory message.
    """
    try:
        if not st.session_state.rag_pipeline or not st.session_state.document_loaded:
            return "Please upload a document first before asking questions.", []

        # Record the user turn, unless the caller has already appended this
        # exact message as the latest entry; otherwise the question would be
        # stored (and later rendered) twice.
        user_message = {"role": "user", "content": user_question}
        history = st.session_state.messages
        if not history or history[-1] != user_message:
            history.append(user_message)

        # Get response from RAG pipeline
        with render_processing_spinner("Thinking..."):
            answer, source_docs = st.session_state.rag_pipeline.query(user_question)

        # Add assistant response to messages
        st.session_state.messages.append({
            "role": "assistant",
            "content": answer,
            "sources": source_docs
        })

        logger.info(f"Query processed: '{user_question[:50]}...'")
        return answer, source_docs

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error handling user query: {e}")
        error_message = f"Error processing query: {str(e)}"
        st.session_state.messages.append({"role": "assistant", "content": error_message, "sources": []})
        return error_message, []
    
def clear_all_documents():
    """Drop every loaded document and reset the upload widget.

    Clears the pipeline's vector store (when a pipeline exists), resets the
    document-related session-state keys, bumps ``uploader_key`` so the file
    uploader is recreated empty, and triggers a rerun. Chat messages are not
    touched.
    """
    # Clear the vector store BEFORE dropping the pipeline reference. The
    # original reset rag_pipeline to None first, so this branch never ran.
    pipeline = st.session_state.get('rag_pipeline')
    if pipeline and pipeline.vector_store_manager:
        try:
            pipeline.vector_store_manager.clear_vector_store()
        except Exception:
            # Best effort: a failing store must not block the state reset.
            logger.exception("Failed to clear vector store")

    st.session_state.rag_sources = []
    st.session_state.document_loaded = False
    st.session_state.document_stats = None
    st.session_state.rag_pipeline = None
    st.session_state.uploaded_files = []

    # Increment uploader key to reset file uploader
    if 'uploader_key' not in st.session_state:
        st.session_state.uploader_key = 0
    st.session_state.uploader_key += 1
    st.rerun()

def process_uploaded_files():
    """Process every file in ``st.session_state.uploaded_files`` not seen before.

    Each new file is decoded as UTF-8 as a readability check, its name is
    recorded in ``rag_sources``, and it is passed to
    ``process_uploaded_document``. The pending list is emptied afterwards.

    ``document_loaded`` is deliberately not set here:
    ``process_uploaded_document`` sets it once processing has succeeded, so a
    failed upload no longer opens the chat against an unprocessed document.
    """
    if 'uploaded_files' not in st.session_state or not st.session_state.uploaded_files:
        return

    for uploaded_file in st.session_state.uploaded_files:
        if uploaded_file.name in st.session_state.rag_sources:
            continue

        try:
            # Simple readability check before handing over to the pipeline.
            content = uploaded_file.getvalue().decode('utf-8')
            st.success(f"βœ… {uploaded_file.name} uploaded successfully! Content length: {len(content)} characters")

            # Recorded before processing: the name check above is the only
            # thing that stops a file still held by the uploader from being
            # re-submitted on every rerun.
            st.session_state.rag_sources.append(uploaded_file.name)

            with st.spinner(f"Processing {uploaded_file.name} with RAG..."):
                success = process_uploaded_document(uploaded_file)
                if success:
                    st.success(f"βœ… {uploaded_file.name} RAG processing completed!")
                else:
                    st.error(f"❌ RAG processing failed for {uploaded_file.name}")

        except Exception as e:
            st.error(f"❌ Error reading {uploaded_file.name}: {e}")
            logger.exception(f"Error reading {uploaded_file.name}: {e}")

    # Clear the uploaded files from session state to prevent reprocessing
    st.session_state.uploaded_files = []

def main():
    """Render the Streamlit page: header, upload, document list, chat and system info.

    Widgets are created in the order they appear on the page, so the
    statement order in this function is significant.
    """
    # Setup page configuration and styling
    setup_page_config()
    load_custom_css()
    
    # Initialize session state
    initialize_session_state()
    
    # Render main header
    render_header()

    # Add getting started section (only while no document is marked as loaded)
    if not st.session_state.document_loaded:
        render_getting_started()

    # Clear buttons
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Clear Chat", type="primary"):
            # Empties the chat history in place, then reruns the script.
            st.session_state.messages.clear()
            st.rerun()
    with col2:
        if st.button("Clear All Documents", type="secondary"):
            clear_all_documents()

    # Initialize uploader key
    if 'uploader_key' not in st.session_state:
        st.session_state.uploader_key = 0
    
    # File upload input 
    # The widget key embeds uploader_key, so changing that counter yields a
    # differently-keyed uploader widget.
    uploaded_files = st.file_uploader(
        "πŸ“„ Upload a text document (.txt only, max 200MB)", 
        type=["txt"],
        accept_multiple_files=True,
       key=f"rag_docs_{st.session_state.uploader_key}"
    )
    
    # Store uploaded files in session state and process them
    if uploaded_files:
        st.session_state.uploaded_files = uploaded_files
        st.info(f"Files uploaded: {[f.name for f in uploaded_files]}")
        process_uploaded_files()
    
    # Show documents in DB with individual remove buttons
    with st.expander(f"πŸ“š Documents in DB ({len(st.session_state.rag_sources)})"):
        if st.session_state.rag_sources:
            for i, doc in enumerate(st.session_state.rag_sources):
                col1, col2 = st.columns([3, 1])
                with col1:
                    st.write(f"β€’ {doc}")
                with col2:
                    # Index and name are both in the key so each button is unique.
                    if st.button("πŸ—‘οΈ", key=f"remove_doc_{i}_{doc}"):
                        # Remove the document
                        # NOTE(review): this only drops the name from rag_sources;
                        # nothing here calls into the pipeline for this document.
                        st.session_state.rag_sources.pop(i)
                        # Reset document_loaded if no documents left
                        if len(st.session_state.rag_sources) == 0:
                            st.session_state.document_loaded = False
                            st.session_state.document_stats = None
                            st.session_state.rag_pipeline = None
                        st.rerun()
        else:
            st.write("No documents in database")
    

    
    # Display chat history
    # Only "role" and "content" are rendered; a stored "sources" entry is not shown here.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    
    # Chat input
    if prompt := st.chat_input("Your message"):
        # Record and echo the user's prompt.
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        
        with st.chat_message("assistant"):
            # NOTE(review): message_placeholder and full_response are not read
            # anywhere else in this function.
            message_placeholder = st.empty()
            full_response = ""
            
            # RAG response
            answer, source_docs = handle_user_query(prompt)
            st.write(answer)
            
            # Show source documents if available
            if source_docs and isinstance(source_docs, list) and len(source_docs) > 0:
                with st.expander("πŸ“„ View Source Documents"):
                    for i, doc in enumerate(source_docs[:3]):  # Show top 3 sources
                        st.markdown(f"**Source {i+1}:**")
                        # Preview is capped at 300 characters, with an ellipsis when truncated.
                        st.markdown(f'{doc.page_content[:300]}{"..." if len(doc.page_content) > 300 else ""}')
                        st.divider()
    
    # System information
    if st.session_state.rag_pipeline:
        system_info = st.session_state.rag_pipeline.get_system_info()
        render_system_info(system_info)


# Run the app only when this file is executed as the main module, not on import.
if __name__ == "__main__":
    main()