# JirasakJo — Update app.py (commit 9314e12, verified; 25.9 kB)
# NOTE: the lines above were Hugging Face file-viewer chrome captured along
# with the source; kept here as a comment so the module remains importable.
import streamlit as st
import json
import os
import requests
import json
import base64
from datetime import datetime, timedelta
import subprocess
from huggingface_hub import HfApi
from pathlib import Path
from calendar_rag import (
create_default_config,
AcademicCalendarRAG,
PipelineConfig,
ModelConfig,
RetrieverConfig,
CacheConfig,
ProcessingConfig,
LocalizationConfig
)
# Custom CSS for enhanced styling
def load_custom_css():
    """Inject the app-wide CSS theme into the Streamlit page.

    Styles the page body, headers, chat bubbles and the online/offline
    status badges. Call once near the top of main().
    """
    st.markdown("""
        <style>
        /* General body styling */
        body {
            font-family: "Arial", sans-serif !important;
            color: #000000 !important;
            background-color: white !important;
            line-height: 1.7 !important;
        }
        /* Main container styling */
        .main {
            padding: 2rem;
            color: #000000;
            background-color: white;
        }
        /* Headers styling */
        h1 {
            color: #000000;
            font-size: 2.8rem !important;
            font-weight: 700 !important;
            margin-bottom: 1.5rem !important;
            text-align: center;
            padding: 1rem 0;
            border-bottom: 3px solid #1E3A8A;
        }
        h3, h4 {
            color: #000000;
            font-weight: 600 !important;
            font-size: 1.6rem !important;
            margin-top: 1.5rem !important;
        }
        /* Chat message styling */
        .chat-message {
            padding: 1.5rem;
            border-radius: 10px;
            margin: 1rem 0;
            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
            font-size: 1.1rem !important;
            line-height: 1.6 !important;
            font-family: "Arial", sans-serif !important;
            color: #000000 !important;
        }
        .user-message {
            background-color: #F3F4F6 !important;
        }
        .assistant-message {
            background-color: #EFF6FF !important;
        }
        /* Status indicators */
        .status-indicator {
            padding: 0.5rem 1rem;
            border-radius: 6px;
            font-weight: 500;
            font-size: 1.2rem;
            color: #000000;
        }
        .status-online {
            background-color: #DEF7EC;
            color: #03543F;
        }
        .status-offline {
            background-color: #FDE8E8;
            /* BUG FIX: was rgb(255, 255, 255) — white text on the pale pink
               badge was unreadable. Dark red mirrors .status-online's scheme. */
            color: #9B1C1C;
        }
        </style>
    """, unsafe_allow_html=True)
def clear_conversation_context():
    """Drop the RAG pipeline's conversation history and the session context
    memory, leaving st.session_state.chat_history intact for on-screen display."""
    active_pipeline = st.session_state.pipeline if 'pipeline' in st.session_state else None
    if active_pipeline:
        active_pipeline.conversation_history = []
    # Reset the context memory; the visible chat log is deliberately kept.
    st.session_state.context_memory = []
def initialize_pipeline():
    """Build an AcademicCalendarRAG pipeline from calendar.json.

    Seeds the pipeline's conversation history from any exchanges saved in
    st.session_state.context_memory. Returns the pipeline, or None (after
    showing an on-page error) when the data file is missing or setup fails.
    """
    try:
        # API key: environment first, then Streamlit secrets.
        api_key = os.getenv('OPENAI_API_KEY') or st.secrets['OPENAI_API_KEY']
        pipeline = AcademicCalendarRAG(create_default_config(api_key))
        try:
            with open("calendar.json", "r", encoding="utf-8") as f:
                pipeline.load_data(json.load(f))
        except FileNotFoundError:
            st.error("calendar.json not found. Please ensure the file exists in the same directory.")
            return None
        # Replay saved exchanges into the pipeline's history format.
        saved = st.session_state.context_memory if 'context_memory' in st.session_state else None
        if saved:
            replayed = []
            for exchange in saved:
                replayed.append({"role": "user", "content": exchange["query"]})
                replayed.append({"role": "assistant", "content": exchange["response"]})
            pipeline.conversation_history = replayed
        return pipeline
    except Exception as e:
        st.error(f"Error initializing pipeline: {str(e)}")
        return None
def load_qa_history():
    """Fetch qa_history.json from the project's GitHub repository.

    Returns the parsed list of QA entries, or [] when the file is missing,
    the request fails, or any error occurs (an on-page warning/error is shown).
    """
    try:
        # GitHub API configuration
        REPO_OWNER = "jirasaksaimekJijo"
        REPO_NAME = "swu-chat-bot-project"
        FILE_PATH = "qa_history.json"
        # SECURITY FIX: the personal access token was hard-coded in source
        # (a leaked credential). Read it from the environment / secrets instead.
        github_token = os.getenv("GITHUB_TOKEN") or st.secrets.get("GITHUB_TOKEN", "")
        # Set up GitHub API request (requests/base64/json are imported at module top).
        api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/contents/{FILE_PATH}"
        headers = {"Accept": "application/vnd.github.v3+json"}
        if github_token:
            headers["Authorization"] = f"token {github_token}"
        response = requests.get(api_url, headers=headers)
        if response.status_code == 200:
            # The contents API returns the file body base64-encoded.
            content_data = response.json()
            file_content = base64.b64decode(content_data["content"]).decode("utf-8")
            return json.loads(file_content)
        st.warning(f"Failed to fetch QA history: {response.status_code} - {response.reason}")
        # Return empty list if file doesn't exist or can't be accessed
        return []
    except Exception as e:
        st.error(f"Error loading QA history from GitHub: {str(e)}")
        return []
def save_qa_history(history_entry):
    """Append a QA entry to qa_history.json on GitHub, with a local backup.

    Fetches the current file (if any) via the GitHub contents API, appends
    `history_entry` (its "answer" normalized to a plain string), writes a
    local backup copy, then PUTs the updated JSON back. Returns True on
    success, False otherwise.
    """
    try:
        # GitHub API configuration
        REPO_OWNER = "jirasaksaimekJijo"
        REPO_NAME = "swu-chat-bot-project"
        FILE_PATH = "qa_history.json"
        # SECURITY FIX: the personal access token was hard-coded in source
        # (a leaked credential). Read it from the environment / secrets instead.
        github_token = os.getenv("GITHUB_TOKEN") or st.secrets.get("GITHUB_TOKEN", "")
        if not github_token:
            st.warning("GITHUB_TOKEN is not configured; QA history push will likely fail.")
        api_url = f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/contents/{FILE_PATH}"
        headers = {"Accept": "application/vnd.github.v3+json"}
        if github_token:
            headers["Authorization"] = f"token {github_token}"
        # Fetch existing file to get current content and SHA (SHA is required
        # by the contents API to update an existing file).
        response = requests.get(api_url, headers=headers)
        history_data = []
        sha = None
        if response.status_code == 200:
            content_data = response.json()
            sha = content_data["sha"]
            try:
                file_content = base64.b64decode(content_data["content"]).decode("utf-8")
                if file_content.strip():  # Make sure content is not empty
                    history_data = json.loads(file_content)
                if not isinstance(history_data, list):
                    st.warning("Existing history data is not a list. Initializing new list.")
                    history_data = []
            except Exception as e:
                st.warning(f"Error parsing existing history: {e}. Initializing new list.")
        elif response.status_code == 404:
            # File doesn't exist yet; it will be created by the PUT below.
            st.info("Creating new QA history file.")
        else:
            st.error(f"Failed to check existing history: {response.status_code} - {response.reason}")
        # Normalize the entry's "answer" field to a plain string before appending.
        if isinstance(history_entry, dict) and all(key in history_entry for key in ["timestamp", "query", "answer"]):
            if isinstance(history_entry["answer"], dict):
                history_entry["answer"] = history_entry["answer"].get('answer', str(history_entry["answer"]))
            elif hasattr(history_entry["answer"], 'content'):
                # Document-like object: keep only its text content.
                history_entry["answer"] = history_entry["answer"].content
            else:
                history_entry["answer"] = str(history_entry["answer"])
        history_data.append(history_entry)
        # Best-effort local backup; failure here must not abort the push.
        try:
            with open(Path("qa_history.json"), "w", encoding="utf-8") as f:
                json.dump(history_data, f, ensure_ascii=False, indent=2)
        except Exception as local_err:
            st.warning(f"Failed to save local backup: {local_err}")
        # Push the updated JSON back to GitHub (create or update).
        updated_content = json.dumps(history_data, ensure_ascii=False, indent=2)
        data = {
            "message": "Update QA history",
            "content": base64.b64encode(updated_content.encode('utf-8')).decode('utf-8'),
        }
        if sha:  # Updating an existing file requires its SHA.
            data["sha"] = sha
        update_response = requests.put(api_url, headers=headers, json=data)
        if update_response.status_code in [200, 201]:
            return True
        st.error(f"Failed to update QA history: {update_response.status_code} - {update_response.text}")
        return False
    except Exception as e:
        import traceback
        st.error(f"Error in save_qa_history: {str(e)}")
        st.error(f"Traceback: {traceback.format_exc()}")
        return False
def add_to_qa_history(query: str, answer: str):
    """Normalize an answer, wrap it with a timestamp and persist it.

    Delegates persistence to save_qa_history(). Returns the saved entry
    dict, or None when either input is empty or an error occurs.
    """
    try:
        if not (query and answer):
            st.warning("Empty query or answer detected, skipping history update")
            return None
        # Normalize the answer to a plain string regardless of input type.
        if isinstance(answer, dict):
            normalized = answer.get('answer', str(answer))
        elif hasattr(answer, 'content'):
            # Document-like object: keep only its text content.
            normalized = answer.content
        else:
            normalized = str(answer)
        # Timestamp shifted +5h — same local-time approximation used app-wide.
        entry = {
            "timestamp": (datetime.now() + timedelta(hours=5)).strftime("%Y-%m-%dT%H:%M:%S"),
            "query": query,
            "answer": normalized,
        }
        save_qa_history(entry)
        return entry
    except Exception as e:
        st.error(f"Error in add_to_qa_history: {str(e)}")
        return None
def add_to_history(role: str, message: str):
    """Append a chat message to the session history.

    On assistant turns the preceding user message is paired with the reply:
    the pair is persisted via add_to_qa_history() and also recorded in the
    context memory, which is capped at the ten most recent exchanges.
    """
    st.session_state.chat_history.append((role, message))
    # Only a completed exchange (user message followed by this reply) is saved.
    if role != "assistant" or len(st.session_state.chat_history) < 2:
        return
    user_query = st.session_state.chat_history[-2][1]
    add_to_qa_history(user_query, message)
    if 'context_memory' not in st.session_state:
        st.session_state.context_memory = []
    # Structured responses carry their text under the "answer" key.
    response_content = message["answer"] if isinstance(message, dict) and "answer" in message else message
    st.session_state.context_memory.append({
        "query": user_query,
        "response": response_content,
        "timestamp": (datetime.now() + timedelta(hours=5)).strftime("%Y-%m-%dT%H:%M:%S"),
    })
    # Keep only the last 10 exchanges to bound memory/latency.
    del st.session_state.context_memory[:-10]
def display_chat_history():
    """Render every (role, message) pair in st.session_state.chat_history.

    User messages render as plain bubbles. Assistant messages may be dicts
    carrying 'answer' plus optional 'documents' (reference docs shown in an
    expander); any other assistant payload is rendered as-is.
    """
    for role, content in st.session_state.chat_history:
        if role == "user":
            st.markdown(f"""
            <div class="chat-message user-message">
                <strong>🧑 คำถาม:</strong><br>
                {content}
            </div>
            """, unsafe_allow_html=True)
        else:
            if isinstance(content, dict):
                assistant_response = content.get('answer', '❌ ไม่มีข้อมูลคำตอบ')
                st.markdown(f"""
                <div class="chat-message assistant-message">
                    <strong>🤖 คำตอบ:</strong><br>
                    {assistant_response}
                </div>
                """, unsafe_allow_html=True)
                # Show reference documents like in main()
                if content.get('documents'):
                    with st.expander("📚 ข้อมูลอ้างอิง", expanded=False):
                        for i, doc in enumerate(content['documents'], 1):
                            # BUG FIX: card background was #000000 (black),
                            # making the dark default text unreadable. Use the
                            # same light grey as .user-message instead.
                            st.markdown(f"""
                            <div style="padding: 1rem; background-color: #F3F4F6; border-radius: 8px; margin: 0.5rem 0;">
                                <strong>เอกสารที่ {i}:</strong><br>
                                {doc.content}
                            </div>
                            """, unsafe_allow_html=True)
            else:
                st.markdown(f"""
                <div class="chat-message assistant-message">
                    <strong>🤖 คำตอบ:</strong><br>
                    {content}
                </div>
                """, unsafe_allow_html=True)
# Module-level safety net: make sure the context memory exists before any
# handler (which may run before main() finishes initializing) touches it.
if 'context_memory' not in st.session_state:
    st.session_state.context_memory = []
def handle_submit(user_query: str):
    """Process one user query end-to-end.

    Validates the input, records it in the chat history, builds a RAG-format
    conversation history (up to the last ~5 exchanges), enhances queries that
    refer back to earlier questions, runs the pipeline, and records the
    assistant response. Always clears the processing flag and reruns the app.
    """
    if not user_query:
        st.warning("⚠️ กรุณาระบุคำถาม")
        return
    user_query = user_query.strip()
    # Prevent duplicate submissions: skip when the latest history entry is
    # already this exact text.
    if st.session_state.chat_history and st.session_state.chat_history[-1][1] == user_query:
        return
    try:
        st.session_state.processing_query = True
        # Add user message to chat history
        add_to_history("user", user_query)
        # Convert the Streamlit chat history format to RAG format,
        # including up to 5 previous exchanges of context.
        rag_conversation_history = []
        history_to_include = (st.session_state.chat_history[-11:]
                              if len(st.session_state.chat_history) > 10
                              else st.session_state.chat_history)
        for role, content in history_to_include:
            rag_role = "user" if role == "user" else "assistant"
            # Structured responses carry their text under "answer".
            if isinstance(content, dict) and "answer" in content:
                rag_content = content["answer"]
            else:
                rag_content = content
            rag_conversation_history.append({"role": rag_role, "content": rag_content})
        if 'context_memory' not in st.session_state:
            st.session_state.context_memory = []
        with st.spinner("🔍 กำลังค้นหาคำตอบ..."):
            # Debug logging to verify how much context is being passed.
            print(f"Processing query with {len(rag_conversation_history)} context messages")
            # Special handling for questions that refer to earlier questions.
            reference_keywords = ["ก่อนหน้านี้", "ก่อนหน้า", "ที่ผ่านมา", "คำถามก่อนหน้า", "คำถามที่แล้ว",
                                  "previous", "earlier", "before", "last time", "last question"]
            is_reference_question = any(keyword in user_query.lower() for keyword in reference_keywords)
            # If this is a reference question, embed the previous user question
            # into the query so the model has explicit context.
            if is_reference_question and len(rag_conversation_history) >= 3:
                previous_questions = [msg["content"] for msg in rag_conversation_history[:-2]
                                      if msg["role"] == "user"]
                if previous_questions:
                    prev_question = previous_questions[-1]
                    enhanced_query = f"คำถามนี้อ้างอิงถึงคำถามก่อนหน้า '{prev_question}' โปรดพิจารณาบริบทนี้ในการตอบ: {user_query}"
                    print(f"Enhanced reference query: {enhanced_query}")
                    user_query = enhanced_query
            result = st.session_state.pipeline.process_query(
                query=user_query,
                conversation_history=rag_conversation_history
            )
            # Create response with same structure as main()
            response_dict = {
                "answer": result.get("answer", ""),
                "documents": result.get("relevant_docs", [])
            }
            # Update chat history; add_to_history() also records this exchange
            # in st.session_state.context_memory.
            # BUG FIX: the original appended a second, differently-timestamped
            # copy of the exchange to context_memory here, duplicating every
            # exchange in memory — removed.
            add_to_history("assistant", response_dict)
    except Exception as e:
        error_msg = f"❌ เกิดข้อผิดพลาด: {str(e)}"
        add_to_history("assistant", error_msg)
        st.error(f"Query processing error: {e}")
    finally:
        st.session_state.processing_query = False
        st.rerun()
def create_chat_input():
    """Render the chat input form: question box, submit button, and a
    combined clear-history-and-context button.

    Submit forwards the query to handle_submit(); clear wipes both the
    visible chat history and the conversation context, then reruns the app.
    """
    with st.form(key="chat_form", clear_on_submit=True):
        # Styled HTML label (the text_input's own label is left blank).
        st.markdown("""
        <label for="query_input" style="font-size: 1.2rem; font-weight: 600; margin-bottom: 1rem; display: block;">
            <span style="color: #ffffff; border-left: 4px solid #ffffff; padding-left: 0.8rem;">
                โปรดระบุคำถามเกี่ยวกับปฏิทินการศึกษา:
            </span>
        </label>
        """, unsafe_allow_html=True)
        query = st.text_input(
            "",
            key="query_input",
            placeholder="เช่น: วิชาเลือกมีอะไรบ้าง?"
        )
        # Two equal-width columns: submit on the left, clear on the right.
        col1, col2 = st.columns([5, 5])
        with col1:
            submitted = st.form_submit_button(
                "📤 ส่งคำถาม",
                type="primary",
                use_container_width=True
            )
        with col2:
            clear_all_button = st.form_submit_button(
                "🗑️ ล้างประวัติและบริบทสนทนา",
                type="secondary",
                use_container_width=True
            )
        if submitted:
            handle_submit(query)
        if clear_all_button:
            # Clear chat history
            st.session_state.chat_history = []
            # Clear conversation context
            clear_conversation_context()
            st.info("ล้างประวัติและบริบทสนทนาแล้ว")
            st.rerun()
def main():
    """Application entry point.

    Configures the Streamlit page, initializes session state and the RAG
    pipeline (once per session), then renders the header, the chat column
    and the info/status sidebar column.
    """
    # Page config
    st.set_page_config(
        page_title="Academic Calendar Assistant",
        page_icon="📅",
        layout="wide",
        initial_sidebar_state="collapsed"
    )
    # Load custom CSS
    load_custom_css()
    # Initialize session states
    if 'pipeline' not in st.session_state:
        st.session_state.pipeline = None
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'context_memory' not in st.session_state:
        st.session_state.context_memory = []
    if 'processing_query' not in st.session_state:
        st.session_state.processing_query = False
    # Initialize pipeline once per session (None means not yet created
    # or a previous initialization failed).
    if st.session_state.pipeline is None:
        with st.spinner("กำลังเริ่มต้นระบบ..."):
            st.session_state.pipeline = initialize_pipeline()
    # Header
    st.markdown("""
        <div style="text-align: center; padding: 2rem 0;">
            <h1>🎓 ผู้ช่วยค้นหาข้อมูลหลักสูตรและปฏิทินการศึกษา</h1>
            <p style="font-size: 1.2rem; color: #666;">บัณฑิตวิทยาลัย มหาวิทยาลัยศรีนครินทรวิโรฒ</p>
        </div>
    """, unsafe_allow_html=True)
    # Layout: 70% chat column, 30% info column.
    chat_col, info_col = st.columns([7, 3])
    with chat_col:
        display_chat_history()
        create_chat_input()
    # Info column
    with info_col:
        st.markdown("""
        <div style="background-color: #F9FAFB; padding: 1.5rem; border-radius: 12px; margin-bottom: 2rem;">
            <h3 style="color: #1E3A8A;">ℹ️ เกี่ยวกับระบบ</h3>
            <p style="color: #000000;">
                ระบบนี้ใช้เทคโนโลยี <strong>RAG (Retrieval-Augmented Generation)</strong>
                ในการค้นหาและตอบคำถามเกี่ยวกับหลักสูตรและปฏิทินการศึกษา
            </p>
            <h4 style="color: #1E3A8A; margin-top: 1rem;">สามารถสอบถามข้อมูลเกี่ยวกับ:</h4>
            <ul style="list-style-type: none; padding-left: 0;">
                <li style="color: #000000; margin-bottom: 0.5rem;">📚 รายวิชาในหลักสูตร</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">📝 การลงทะเบียนเรียน</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">📅 กำหนดการต่างๆ</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">💰 ค่าธรรมเนียมการศึกษา</li>
                <li style="color: #000000; margin-bottom: 0.5rem;">📋 ขั้นตอนการสมัคร</li>
            </ul>
        </div>
        """, unsafe_allow_html=True)
        # System status card: time shown with the app-wide +5h offset;
        # online/offline badge reflects whether the pipeline initialized.
        st.markdown("""
        <div style="background-color: #f9fafb; padding: 1.5rem; border-radius: 12px;">
            <h3 style="color: #1E3A8A;">🔄 สถานะระบบ</h3>
            <div style="margin-top: 1rem;">
                <p><strong style="color: #000000;">⏰ เวลาปัจจุบัน:</strong><br>
                <span style="color: #000000;">{}</span></p>
                <p><strong style="color: #000000;">📡 สถานะระบบ:</strong><br>
                <span class="status-indicator {}">
                    {} {}
                </span></p>
            </div>
        </div>
        """.format(
            (datetime.now() + timedelta(hours=5)).strftime('%Y-%m-%d %H:%M:%S'),
            "status-online" if st.session_state.pipeline else "status-offline",
            "🟢" if st.session_state.pipeline else "🔴",
            "พร้อมใช้งาน" if st.session_state.pipeline else "ไม่พร้อมใช้งาน"
        ), unsafe_allow_html=True)
# Script entry point.
if __name__ == "__main__":
    main()