Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
import tempfile | |
import logging | |
from dotenv import load_dotenv | |
import uuid | |
# UI Components moved to src/ui_components.py for easier debugging and maintenance | |
from src.ui_components import ( | |
setup_page_config, load_custom_css, render_header, | |
render_getting_started, render_system_info, | |
render_processing_spinner | |
) | |
from src.rag_pipeline import RAGPipeline | |
load_dotenv() | |
# Configure logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
def initialize_session_state(): | |
if 'session_id' not in st.session_state: | |
st.session_state.session_id = str(uuid.uuid4()) | |
if 'rag_pipeline' not in st.session_state: | |
st.session_state.rag_pipeline = None | |
if 'messages' not in st.session_state: | |
st.session_state.messages = [] | |
if 'rag_sources' not in st.session_state: | |
st.session_state.rag_sources = [] | |
if 'document_loaded' not in st.session_state: | |
st.session_state.document_loaded = False | |
if 'document_stats' not in st.session_state: | |
st.session_state.document_stats = None | |
def process_uploaded_document(uploaded_file): | |
try: | |
st.info(f"Starting to process: {uploaded_file.name}") | |
# Save uploaded file temporarily | |
with tempfile.NamedTemporaryFile(delete=False, suffix='.txt', mode='w', encoding='utf-8') as tmp_file: | |
content = uploaded_file.getvalue().decode('utf-8') | |
tmp_file.write(content) | |
tmp_file_path = tmp_file.name | |
st.info(f"File saved temporarily at: {tmp_file_path}") | |
st.info(f"File content length: {len(content)} characters") | |
# Initialize RAG pipeline if not already done | |
if st.session_state.rag_pipeline is None: | |
st.info("Initializing RAG pipeline...") | |
st.session_state.rag_pipeline = RAGPipeline() | |
# Process document | |
st.info("Processing document through RAG pipeline...") | |
success = st.session_state.rag_pipeline.process_document(tmp_file_path) | |
if success: | |
st.info("Document processed successfully, getting statistics...") | |
# Get document statistics | |
chunks = st.session_state.rag_pipeline.document_processor.process_document(tmp_file_path) | |
stats = st.session_state.rag_pipeline.document_processor.get_document_stats(chunks) | |
# Update session state | |
st.session_state.document_loaded = True | |
st.session_state.document_stats = stats | |
st.info(f"Document processed successfully: {stats['total_chunks']} chunks") | |
else: | |
st.error("Failed to process document") | |
# Clean up temporary file | |
os.unlink(tmp_file_path) | |
return success | |
except Exception as e: | |
st.error(f"Error processing uploaded document: {e}") | |
logger.error(f"Error processing uploaded document: {e}") | |
return False | |
def handle_user_query(user_question): | |
try: | |
if not st.session_state.rag_pipeline or not st.session_state.document_loaded: | |
return "Please upload a document first before asking questions.", [] | |
# Add user question to messages | |
st.session_state.messages.append({"role": "user", "content": user_question}) | |
# Get response from RAG pipeline | |
with render_processing_spinner("Thinking..."): | |
answer, source_docs = st.session_state.rag_pipeline.query(user_question) | |
# Add assistant response to messages | |
st.session_state.messages.append({ | |
"role": "assistant", | |
"content": answer, | |
"sources": source_docs | |
}) | |
logger.info(f"Query processed: '{user_question[:50]}...'") | |
return answer, source_docs | |
except Exception as e: | |
logger.error(f"Error handling user query: {e}") | |
error_message = f"Error processing query: {str(e)}" | |
st.session_state.messages.append({"role": "assistant", "content": error_message, "sources": []}) | |
return error_message, [] | |
def clear_all_documents(): | |
st.session_state.rag_sources = [] | |
st.session_state.document_loaded = False | |
st.session_state.document_stats = None | |
st.session_state.rag_pipeline = None | |
st.session_state.uploaded_files = [] | |
# Clear the vector store as well | |
if st.session_state.rag_pipeline and st.session_state.rag_pipeline.vector_store_manager: | |
st.session_state.rag_pipeline.vector_store_manager.clear_vector_store() | |
# Increment uploader key to reset file uploader | |
if 'uploader_key' not in st.session_state: | |
st.session_state.uploader_key = 0 | |
st.session_state.uploader_key += 1 | |
st.rerun() | |
def process_uploaded_files(): | |
if 'uploaded_files' in st.session_state and st.session_state.uploaded_files: | |
for uploaded_file in st.session_state.uploaded_files: | |
if uploaded_file.name not in st.session_state.rag_sources: | |
# Simple test - just read the file content first | |
try: | |
content = uploaded_file.getvalue().decode('utf-8') | |
st.success(f"β {uploaded_file.name} uploaded successfully! Content length: {len(content)} characters") | |
st.session_state.rag_sources.append(uploaded_file.name) | |
# Set document_loaded to True when we have files | |
st.session_state.document_loaded = True | |
# Now try to process with RAG pipeline | |
with st.spinner(f"Processing {uploaded_file.name} with RAG..."): | |
success = process_uploaded_document(uploaded_file) | |
if success: | |
st.success(f"β {uploaded_file.name} RAG processing completed!") | |
else: | |
st.error(f"β RAG processing failed for {uploaded_file.name}") | |
except Exception as e: | |
st.error(f"β Error reading {uploaded_file.name}: {e}") | |
# Clear the uploaded files from session state to prevent reprocessing | |
st.session_state.uploaded_files = [] | |
def main(): | |
# Setup page configuration and styling | |
setup_page_config() | |
load_custom_css() | |
# Initialize session state | |
initialize_session_state() | |
# Render main header | |
render_header() | |
# Add getting started section | |
if not st.session_state.document_loaded: | |
render_getting_started() | |
# Clear buttons | |
col1, col2 = st.columns(2) | |
with col1: | |
if st.button("Clear Chat", type="primary"): | |
st.session_state.messages.clear() | |
st.rerun() | |
with col2: | |
if st.button("Clear All Documents", type="secondary"): | |
clear_all_documents() | |
# Initialize uploader key | |
if 'uploader_key' not in st.session_state: | |
st.session_state.uploader_key = 0 | |
# File upload input | |
uploaded_files = st.file_uploader( | |
"π Upload a text document (.txt only, max 200MB)", | |
type=["txt"], | |
accept_multiple_files=True, | |
key=f"rag_docs_{st.session_state.uploader_key}" | |
) | |
# Store uploaded files in session state and process them | |
if uploaded_files: | |
st.session_state.uploaded_files = uploaded_files | |
st.info(f"Files uploaded: {[f.name for f in uploaded_files]}") | |
process_uploaded_files() | |
# Show documents in DB with individual remove buttons | |
with st.expander(f"π Documents in DB ({len(st.session_state.rag_sources)})"): | |
if st.session_state.rag_sources: | |
for i, doc in enumerate(st.session_state.rag_sources): | |
col1, col2 = st.columns([3, 1]) | |
with col1: | |
st.write(f"β’ {doc}") | |
with col2: | |
if st.button("ποΈ", key=f"remove_doc_{i}_{doc}"): | |
# Remove the document | |
st.session_state.rag_sources.pop(i) | |
# Reset document_loaded if no documents left | |
if len(st.session_state.rag_sources) == 0: | |
st.session_state.document_loaded = False | |
st.session_state.document_stats = None | |
st.session_state.rag_pipeline = None | |
st.rerun() | |
else: | |
st.write("No documents in database") | |
# Display chat history | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# Chat input | |
if prompt := st.chat_input("Your message"): | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
with st.chat_message("user"): | |
st.markdown(prompt) | |
with st.chat_message("assistant"): | |
message_placeholder = st.empty() | |
full_response = "" | |
# RAG response | |
answer, source_docs = handle_user_query(prompt) | |
st.write(answer) | |
# Show source documents if available | |
if source_docs and isinstance(source_docs, list) and len(source_docs) > 0: | |
with st.expander("π View Source Documents"): | |
for i, doc in enumerate(source_docs[:3]): # Show top 3 sources | |
st.markdown(f"**Source {i+1}:**") | |
st.markdown(f'{doc.page_content[:300]}{"..." if len(doc.page_content) > 300 else ""}') | |
st.divider() | |
# System information | |
if st.session_state.rag_pipeline: | |
system_info = st.session_state.rag_pipeline.get_system_info() | |
render_system_info(system_info) | |
if __name__ == "__main__": | |
main() |