# JirasakJo's picture
# Update app.py
# 9314e12 verified
import streamlit as st
import json
import os
import requests
import json
import base64
from datetime import datetime, timedelta
import subprocess
from huggingface_hub import HfApi
from pathlib import Path
from calendar_rag import (
create_default_config,
AcademicCalendarRAG,
PipelineConfig,
ModelConfig,
RetrieverConfig,
CacheConfig,
ProcessingConfig,
LocalizationConfig
)
# Custom CSS for enhanced styling
def load_custom_css():
st.markdown("""
<style>
/* General body styling */
body {
font-family: "Arial", sans-serif !important;
color: #000000 !important;
background-color: white !important;
line-height: 1.7 !important;
}
/* Main container styling */
.main {
padding: 2rem;
color: #000000;
background-color: white;
}
/* Headers styling */
h1 {
color: #000000;
font-size: 2.8rem !important;
font-weight: 700 !important;
margin-bottom: 1.5rem !important;
text-align: center;
padding: 1rem 0;
border-bottom: 3px solid #1E3A8A;
}
h3, h4 {
color: #000000;
font-weight: 600 !important;
font-size: 1.6rem !important;
margin-top: 1.5rem !important;
}
/* Chat message styling */
.chat-message {
padding: 1.5rem;
border-radius: 10px;
margin: 1rem 0;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
font-size: 1.1rem !important;
line-height: 1.6 !important;
font-family: "Arial", sans-serif !important;
color: #000000 !important;
}
.user-message {
background-color: #F3F4F6 !important;
}
.assistant-message {
background-color: #EFF6FF !important;
}
/* Status indicators */
.status-indicator {
padding: 0.5rem 1rem;
border-radius: 6px;
font-weight: 500;
font-size: 1.2rem;
color: #000000;
}
.status-online {
background-color: #DEF7EC;
color: #03543F;
}
.status-offline {
background-color: #FDE8E8;
color: rgb(255, 255, 255);
}
</style>
""", unsafe_allow_html=True)
def clear_conversation_context():
    """Reset conversational state while preserving the on-screen transcript.

    Empties the RAG pipeline's internal conversation history (when a
    pipeline exists) and the session-level context memory.  The visible
    ``chat_history`` is deliberately left untouched.
    """
    pipeline = st.session_state.get('pipeline')
    if pipeline:
        pipeline.conversation_history = []
    st.session_state.context_memory = []
    # chat_history is kept so the UI continues to show past messages.
def initialize_pipeline():
    """Build the RAG pipeline and seed it with any remembered conversation.

    Returns the ready pipeline, or ``None`` when calendar.json is missing
    or any other initialization step fails (an error is shown in the UI).
    """
    try:
        # Prefer the environment variable; fall back to Streamlit secrets.
        api_key = os.getenv('OPENAI_API_KEY') or st.secrets['OPENAI_API_KEY']

        pipeline = AcademicCalendarRAG(create_default_config(api_key))

        try:
            # Load raw calendar data and feed it to the pipeline.
            with open("calendar.json", "r", encoding="utf-8") as f:
                pipeline.load_data(json.load(f))

            # Rehydrate the model-side conversation history from any
            # context memory carried in the Streamlit session.
            memory = st.session_state.get('context_memory')
            if memory:
                history = []
                for exchange in memory:
                    history.append({"role": "user", "content": exchange["query"]})
                    history.append({"role": "assistant", "content": exchange["response"]})
                pipeline.conversation_history = history

            return pipeline
        except FileNotFoundError:
            st.error("calendar.json not found. Please ensure the file exists in the same directory.")
            return None
    except Exception as e:
        st.error(f"Error initializing pipeline: {str(e)}")
        return None
def load_qa_history():
    """Load QA history directly from the GitHub repository.

    Returns the parsed history list, or an empty list when the file is
    missing, unreadable, or the request fails.
    """
    try:
        # GitHub API configuration.  SECURITY FIX: the personal access
        # token was hard-coded here — a committed credential is leaked.
        # It is now read from the environment; requests without a token
        # still work for public repos (rate-limited).
        REPO_OWNER = "jirasaksaimekJijo"
        REPO_NAME = "swu-chat-bot-project"
        FILE_PATH = "qa_history.json"
        github_token = os.getenv("GITHUB_TOKEN")

        # Set up GitHub API request (contents endpoint returns base64).
        api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/contents/{FILE_PATH}"
        headers = {"Accept": "application/vnd.github.v3+json"}
        if github_token:
            headers["Authorization"] = f"token {github_token}"

        # Timeout added so a hung API call cannot freeze the UI forever.
        response = requests.get(api_url, headers=headers, timeout=30)

        if response.status_code == 200:
            # The contents API wraps the file in base64.
            content_data = response.json()
            file_content = base64.b64decode(content_data["content"]).decode("utf-8")
            return json.loads(file_content)

        st.warning(f"Failed to fetch QA history: {response.status_code} - {response.reason}")
        # Return empty list if file doesn't exist or can't be accessed.
        return []
    except Exception as e:
        st.error(f"Error loading QA history from GitHub: {str(e)}")
        return []
def _coerce_answer_text(answer):
    """Reduce an answer of unknown runtime type to a plain string.

    Handles the three shapes seen in this app: a dict carrying an
    'answer' key, a Document-like object exposing ``.content``, and
    anything else (stringified).
    """
    if isinstance(answer, dict):
        return answer.get('answer', str(answer))
    if hasattr(answer, 'content'):
        return answer.content
    return str(answer)


def save_qa_history(history_entry):
    """Append a QA entry to qa_history.json on GitHub, with a local backup.

    Fetches the current file (for its SHA and existing entries), appends
    ``history_entry``, writes a best-effort local copy, then PUTs the
    updated content back.  Returns True on a successful GitHub update,
    False otherwise.
    """
    try:
        # GitHub API configuration.  SECURITY FIX: the personal access
        # token was hard-coded in source — a committed credential is
        # leaked and must be revoked.  Read it from the environment.
        REPO_OWNER = "jirasaksaimekJijo"
        REPO_NAME = "swu-chat-bot-project"
        FILE_PATH = "qa_history.json"
        github_token = os.getenv("GITHUB_TOKEN", "")

        api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/contents/{FILE_PATH}"
        headers = {
            "Accept": "application/vnd.github.v3+json",
            "Authorization": f"token {github_token}"
        }

        # Fetch the current file: its SHA is required by the update API.
        response = requests.get(api_url, headers=headers, timeout=30)

        history_data = []
        sha = None

        if response.status_code == 200:
            content_data = response.json()
            sha = content_data["sha"]
            try:
                # Decode and parse existing content (may be empty).
                file_content = base64.b64decode(content_data["content"]).decode("utf-8")
                if file_content.strip():
                    history_data = json.loads(file_content)
                if not isinstance(history_data, list):
                    st.warning("Existing history data is not a list. Initializing new list.")
                    history_data = []
            except Exception as e:
                st.warning(f"Error parsing existing history: {e}. Initializing new list.")
        elif response.status_code == 404:
            # File doesn't exist yet — it will be created by the PUT below.
            st.info("Creating new QA history file.")
        else:
            st.error(f"Failed to check existing history: {response.status_code} - {response.reason}")

        # Normalize the answer field before persisting.  Entries missing
        # the expected keys are appended untouched (original behavior).
        if isinstance(history_entry, dict) and all(
            key in history_entry for key in ("timestamp", "query", "answer")
        ):
            history_entry["answer"] = _coerce_answer_text(history_entry["answer"])
        history_data.append(history_entry)

        # Best-effort local backup; a failure here must not block the push.
        try:
            local_path = Path("qa_history.json")
            with open(local_path, "w", encoding="utf-8") as f:
                json.dump(history_data, f, ensure_ascii=False, indent=2)
        except Exception as local_err:
            st.warning(f"Failed to save local backup: {local_err}")

        # Prepare the create/update payload (contents API wants base64).
        updated_content = json.dumps(history_data, ensure_ascii=False, indent=2)
        payload = {
            "message": "Update QA history",
            "content": base64.b64encode(updated_content.encode('utf-8')).decode('utf-8'),
        }
        if sha:  # Updating an existing file requires its blob SHA.
            payload["sha"] = sha

        update_response = requests.put(api_url, headers=headers, json=payload, timeout=30)

        if update_response.status_code in (200, 201):
            return True
        st.error(f"Failed to update QA history: {update_response.status_code} - {update_response.text}")
        return False
    except Exception as e:
        import traceback
        st.error(f"Error in save_qa_history: {str(e)}")
        st.error(f"Traceback: {traceback.format_exc()}")
        return False
def add_to_qa_history(query: str, answer: str):
    """Validate a QA pair, normalize the answer, persist it, and return the entry.

    Returns the saved history dict, or ``None`` when validation fails or
    an unexpected error occurs (reported via the UI).
    """
    try:
        # Reject incomplete pairs up front.
        if not query or not answer:
            st.warning("Empty query or answer detected, skipping history update")
            return None

        # Normalize the answer to plain text regardless of runtime type:
        # dict with an 'answer' key, Document-like with .content, or other.
        if isinstance(answer, dict):
            answer_text = answer.get('answer', str(answer))
        elif hasattr(answer, 'content'):
            answer_text = answer.content
        else:
            answer_text = str(answer)

        # NOTE(review): timestamps are shifted by a fixed +5 hours —
        # presumably a timezone workaround; confirm against deployment TZ.
        entry = {
            "timestamp": (datetime.now() + timedelta(hours=5)).strftime("%Y-%m-%dT%H:%M:%S"),
            "query": query,
            "answer": answer_text
        }

        save_qa_history(entry)
        return entry
    except Exception as e:
        st.error(f"Error in add_to_qa_history: {str(e)}")
        return None
def add_to_history(role: str, message: str):
    """Record a chat message; on assistant turns, persist the QA pair and context."""
    st.session_state.chat_history.append((role, message))

    # Only an assistant reply completes a QA pair worth persisting.
    if role != "assistant" or len(st.session_state.chat_history) < 2:
        return

    # The matching user question is the message immediately before this one.
    user_query = st.session_state.chat_history[-2][1]
    add_to_qa_history(user_query, message)

    if 'context_memory' not in st.session_state:
        st.session_state.context_memory = []

    # Store plain answer text in context memory, not the full response dict.
    if isinstance(message, dict) and "answer" in message:
        response_content = message["answer"]
    else:
        response_content = message

    st.session_state.context_memory.append({
        "query": user_query,
        "response": response_content,
        "timestamp": (datetime.now() + timedelta(hours=5)).strftime("%Y-%m-%dT%H:%M:%S")
    })

    # Cap memory at the last 10 exchanges to bound prompt size.
    if len(st.session_state.context_memory) > 10:
        st.session_state.context_memory = st.session_state.context_memory[-10:]
def display_chat_history():
    """Render the accumulated chat transcript, including reference documents.

    User turns render as plain bubbles.  Assistant turns may be either a
    dict ({'answer': ..., 'documents': [...]}) produced by handle_submit,
    or a plain string (e.g. an error notice).
    """
    for role, content in st.session_state.chat_history:
        if role == "user":
            st.markdown(f"""
            <div class="chat-message user-message">
                <strong>🧑 คำถาม:</strong><br>
                {content}
            </div>
            """, unsafe_allow_html=True)
        else:
            if isinstance(content, dict):
                # Structured response: show the answer text, then sources.
                assistant_response = content.get('answer', '❌ ไม่มีข้อมูลคำตอบ')
                st.markdown(f"""
                <div class="chat-message assistant-message">
                    <strong>🤖 คำตอบ:</strong><br>
                    {assistant_response}
                </div>
                """, unsafe_allow_html=True)

                # Show reference documents like in main()
                if content.get('documents'):
                    with st.expander("📚 ข้อมูลอ้างอิง", expanded=False):
                        # doc is assumed Document-like with a .content
                        # attribute — TODO confirm against calendar_rag.
                        for i, doc in enumerate(content['documents'], 1):
                            st.markdown(f"""
                            <div style="padding: 1rem; background-color: #000000; border-radius: 8px; margin: 0.5rem 0;">
                                <strong>เอกสารที่ {i}:</strong><br>
                                {doc.content}
                            </div>
                            """, unsafe_allow_html=True)
            else:
                # Plain-string assistant message (e.g. an error message).
                st.markdown(f"""
                <div class="chat-message assistant-message">
                    <strong>🤖 คำตอบ:</strong><br>
                    {content}
                </div>
                """, unsafe_allow_html=True)
# Module-level guard: ensure context memory exists before any handler runs.
if 'context_memory' not in st.session_state:
    st.session_state.context_memory = []
def handle_submit(user_query: str):
    """Validate, contextualize, and answer a user query, updating chat state.

    Builds a RAG-format history from the recent transcript, rewrites
    reference-style questions ("what did I ask before?") to embed the
    previous user question explicitly, runs the pipeline, and records
    both turns via add_to_history().  Always reruns the app afterwards.
    """
    if not user_query:
        st.warning("⚠️ กรุณาระบุคำถาม")
        return

    user_query = user_query.strip()

    # Prevent duplicate submissions of the exact same text.
    if st.session_state.chat_history and st.session_state.chat_history[-1][1] == user_query:
        return

    try:
        st.session_state.processing_query = True

        # Add the user message to the visible chat history first.
        add_to_history("user", user_query)

        # Convert the Streamlit chat history to RAG format, keeping up to
        # 5 previous exchanges (10 messages) plus the current question.
        if len(st.session_state.chat_history) > 10:
            recent = st.session_state.chat_history[-11:]
        else:
            recent = st.session_state.chat_history

        rag_conversation_history = []
        for role, content in recent:
            rag_role = "user" if role == "user" else "assistant"
            # Assistant entries may be dicts; keep only the answer text.
            if isinstance(content, dict) and "answer" in content:
                rag_content = content["answer"]
            else:
                rag_content = content
            rag_conversation_history.append({"role": rag_role, "content": rag_content})

        if 'context_memory' not in st.session_state:
            st.session_state.context_memory = []

        with st.spinner("🔍 กำลังค้นหาคำตอบ..."):
            # Debug logging to verify how much context is being sent.
            print(f"Processing query with {len(rag_conversation_history)} context messages")

            # Detect questions that refer back to an earlier exchange.
            reference_keywords = ["ก่อนหน้านี้", "ก่อนหน้า", "ที่ผ่านมา", "คำถามก่อนหน้า", "คำถามที่แล้ว",
                                  "previous", "earlier", "before", "last time", "last question"]
            is_reference_question = any(keyword in user_query.lower() for keyword in reference_keywords)

            # For reference questions, splice the previous user question
            # into the prompt so the model has the context explicitly.
            if is_reference_question and len(rag_conversation_history) >= 3:
                previous_questions = [msg["content"] for msg in rag_conversation_history[:-2]
                                      if msg["role"] == "user"]
                if previous_questions:
                    prev_question = previous_questions[-1]
                    user_query = f"คำถามนี้อ้างอิงถึงคำถามก่อนหน้า '{prev_question}' โปรดพิจารณาบริบทนี้ในการตอบ: {user_query}"
                    print(f"Enhanced reference query: {user_query}")

            result = st.session_state.pipeline.process_query(
                query=user_query,
                conversation_history=rag_conversation_history
            )

        # Mirror main()'s response structure.
        response_dict = {
            "answer": result.get("answer", ""),
            "documents": result.get("relevant_docs", [])
        }

        # BUG FIX: add_to_history() already appends this exchange to
        # context_memory; the direct append that used to follow here
        # duplicated every entry (halving the effective 10-entry cap)
        # and recorded the rewritten query with a different timestamp
        # format.  The redundant append has been removed.
        add_to_history("assistant", response_dict)
    except Exception as e:
        error_msg = f"❌ เกิดข้อผิดพลาด: {str(e)}"
        add_to_history("assistant", error_msg)
        st.error(f"Query processing error: {e}")
    finally:
        st.session_state.processing_query = False
        st.rerun()
def create_chat_input():
    """Render the question form with submit and combined clear actions."""
    with st.form(key="chat_form", clear_on_submit=True):
        # Styled label for the query box (rendered as raw HTML).
        st.markdown("""
        <label for="query_input" style="font-size: 1.2rem; font-weight: 600; margin-bottom: 1rem; display: block;">
            <span style="color: #ffffff; border-left: 4px solid #ffffff; padding-left: 0.8rem;">
                โปรดระบุคำถามเกี่ยวกับปฏิทินการศึกษา:
            </span>
        </label>
        """, unsafe_allow_html=True)

        question = st.text_input(
            "",
            key="query_input",
            placeholder="เช่น: วิชาเลือกมีอะไรบ้าง?"
        )

        # Two equal-width buttons: ask, or wipe history + context.
        submit_col, clear_col = st.columns([5, 5])
        with submit_col:
            submitted = st.form_submit_button(
                "📤 ส่งคำถาม",
                type="primary",
                use_container_width=True
            )
        with clear_col:
            cleared = st.form_submit_button(
                "🗑️ ล้างประวัติและบริบทสนทนา",
                type="secondary",
                use_container_width=True
            )

        if submitted:
            handle_submit(question)

        if cleared:
            st.session_state.chat_history = []
            clear_conversation_context()
            st.info("ล้างประวัติและบริบทสนทนาแล้ว")
            st.rerun()
def main():
    """Application entry point: configure the page, init state, render UI."""
    # Page config — must be the first Streamlit call in the script run.
    st.set_page_config(
        page_title="Academic Calendar Assistant",
        page_icon="📅",
        layout="wide",
        initial_sidebar_state="collapsed"
    )

    # Load custom CSS
    load_custom_css()

    # Initialize session states so later handlers can assume they exist.
    if 'pipeline' not in st.session_state:
        st.session_state.pipeline = None
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'context_memory' not in st.session_state:
        st.session_state.context_memory = []
    if 'processing_query' not in st.session_state:
        st.session_state.processing_query = False

    # Lazily build the RAG pipeline once per session.
    if st.session_state.pipeline is None:
        with st.spinner("กำลังเริ่มต้นระบบ..."):
            st.session_state.pipeline = initialize_pipeline()

    # Header
    st.markdown("""
    <div style="text-align: center; padding: 2rem 0;">
        <h1>🎓 ผู้ช่วยค้นหาข้อมูลหลักสูตรและปฏิทินการศึกษา</h1>
        <p style="font-size: 1.2rem; color: #666;">บัณฑิตวิทยาลัย มหาวิทยาลัยศรีนครินทรวิโรฒ</p>
    </div>
    """, unsafe_allow_html=True)

    # Two-column layout: chat on the left, info/status panel on the right.
    chat_col, info_col = st.columns([7, 3])

    with chat_col:
        display_chat_history()
        create_chat_input()

    # Info column
    with info_col:
        st.markdown("""
        <div style="background-color: #F9FAFB; padding: 1.5rem; border-radius: 12px; margin-bottom: 2rem;">
            <h3 style="color: #1E3A8A;">ℹ️ เกี่ยวกับระบบ</h3>
            <p style="color: #000000;">
                ระบบนี้ใช้เทคโนโลยี <strong>RAG (Retrieval-Augmented Generation)</strong>
                ในการค้นหาและตอบคำถามเกี่ยวกับหลักสูตรและปฏิทินการศึกษา
            </p>
            <h4 style="color: #1E3A8A; margin-top: 1rem;">สามารถสอบถามข้อมูลเกี่ยวกับ:</h4>
            <ul style="list-style-type: none; padding-left: 0;">
                <li style="color: #000000; margin-bottom: 0.5rem;">📚 รายวิชาในหลักสูตร</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">📝 การลงทะเบียนเรียน</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">📅 กำหนดการต่างๆ</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">💰 ค่าธรรมเนียมการศึกษา</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">📋 ขั้นตอนการสมัคร</li>
            </ul>
        </div>
        """, unsafe_allow_html=True)

        # Status panel: clock (shifted +5h, matching the history
        # timestamps — presumably a timezone workaround; confirm) and a
        # badge reflecting whether the pipeline initialized.
        st.markdown("""
        <div style="background-color: #f9fafb; padding: 1.5rem; border-radius: 12px;">
            <h3 style="color: #1E3A8A;">🔄 สถานะระบบ</h3>
            <div style="margin-top: 1rem;">
                <p><strong style="color: #000000;">⏰ เวลาปัจจุบัน:</strong><br>
                <span style="color: #000000;">{}</span></p>
                <p><strong style="color: #000000;">📡 สถานะระบบ:</strong><br>
                <span class="status-indicator {}">
                    {} {}
                </span></p>
            </div>
        </div>
        """.format(
            (datetime.now() + timedelta(hours=5)).strftime('%Y-%m-%d %H:%M:%S'),
            "status-online" if st.session_state.pipeline else "status-offline",
            "🟢" if st.session_state.pipeline else "🔴",
            "พร้อมใช้งาน" if st.session_state.pipeline else "ไม่พร้อมใช้งาน"
        ), unsafe_allow_html=True)
# Standard script entry point.
if __name__ == "__main__":
    main()