# Source: RAG-BITS-Tutor / app.py (Hugging Face Space by patronmoses)
# Last commit: 75632fd "Update app.py" (verified)
# app.py
import streamlit as st
import os
import faiss
import pickle
from sentence_transformers import SentenceTransformer
from groq import Groq
from dotenv import load_dotenv
import re # Import regular expressions for expand_query_with_llm_app
# --- Page configuration ---
# st.set_page_config() has to be the very first Streamlit command of the
# script, i.e. it must run before load_models_and_data() can call st.error().
st.set_page_config(
    page_icon="🎓",
    page_title="RAG BITS Tutor",
)
# --- Configuration and Model Loading (best outside functions to use caching) ---
@st.cache_resource  # Load heavy models/data once and reuse them across reruns.
def load_models_and_data():
    """Load the FAISS index, text chunks, embedding model and Groq client.

    Because of ``st.cache_resource`` the returned tuple is reused on later
    script reruns. Note that a failure result is cached as well, so after
    fixing a missing file or secret the cache must be cleared (or the app
    restarted) for the fix to take effect.

    Returns:
        ``(faiss_index, chunks, embedding_model, groq_client)`` on success.
        If a prerequisite is missing, an error is shown via ``st.error`` and
        ``(None, None, None, None)`` is returned.
    """
    # Load environment variables (if a .env file is present in the Space).
    load_dotenv()
    groq_api_key_app = os.getenv("GROQ_API_KEY")
    # Validate the secret first: there is no point in reading the index or
    # loading the embedding model if the app cannot answer anyway.
    if not groq_api_key_app:
        st.error("GROQ_API_KEY not found. Please add it as a Secret in the Hugging Face Space settings.")
        return None, None, None, None

    # Paths to the index and chunks (the folder must be present in the HF Space).
    output_folder = "faiss_index_bits"
    index_path = os.path.join(output_folder, "bits_tutor.index")
    chunks_path = os.path.join(output_folder, "bits_chunks.pkl")

    # Check that both artifacts exist before reading either of them.
    if not os.path.exists(index_path):
        st.error(f"FAISS Index not found at: {index_path}")
        return None, None, None, None
    if not os.path.exists(chunks_path):
        st.error(f"Chunks file not found at: {chunks_path}")
        return None, None, None, None

    index_loaded = faiss.read_index(index_path)
    # NOTE: pickle.load can execute arbitrary code. This is acceptable only
    # because the chunks file ships with the app itself; never point this at
    # user-supplied files.
    with open(chunks_path, "rb") as f:
        chunks_loaded = pickle.load(f)

    # Retrieval maps FAISS result positions directly onto chunks_loaded, so
    # both must describe the same corpus in the same order.
    if index_loaded.ntotal != len(chunks_loaded):
        st.warning(
            f"FAISS index contains {index_loaded.ntotal} vectors but "
            f"{len(chunks_loaded)} chunks were loaded; retrieved passages may be misaligned."
        )

    # Load the (German) embedding model used to encode queries.
    embedding_model_name_app = "Sahajtomar/German-semantic"
    embedding_model_loaded = SentenceTransformer(embedding_model_name_app)

    groq_client_loaded = Groq(api_key=groq_api_key_app)
    return index_loaded, chunks_loaded, embedding_model_loaded, groq_client_loaded
# Load models and data once at app start-up (the loader is cached).
# load_models_and_data() may call st.error(), a Streamlit command, so this
# line has to come after the st.set_page_config() call.
(
    faiss_index,
    chunks_data,
    embedding_model,
    groq_client,
) = load_models_and_data()
# --- Core RAG Functions (adapted for Streamlit app) ---
def retrieve_relevant_chunks_app(query, k=5):
    """Return up to ``k`` ``(chunk_text, distance)`` pairs for ``query``.

    Uses the module-level ``embedding_model``, ``faiss_index`` and
    ``chunks_data``. Pairs are returned in the order FAISS reports them.
    If the models/data are not loaded, a warning is shown and ``[]`` is
    returned.
    """
    if embedding_model is None or faiss_index is None or chunks_data is None:
        st.warning("Models or data not loaded correctly. Cannot retrieve chunks.")
        return []
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = faiss_index.search(query_embedding, k)
    # FAISS fills unused result slots with label -1 when fewer than k vectors
    # are available. Without this filter, chunks_data[-1] would silently
    # inject the last chunk as a bogus match.
    return [
        (chunks_data[idx], dist)
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]
def generate_answer_app(query, retrieved_chunks_data, *, model="llama3-70b-8192", temperature=0.3):
    """Generate a German answer to ``query`` grounded in the retrieved chunks.

    Args:
        query: The user's question.
        retrieved_chunks_data: Iterable of ``(chunk_text, distance)`` pairs;
            only the texts are used, joined by blank lines into the context.
        model: Groq chat model ID. NOTE(review): Groq retires model IDs over
            time; confirm this default is still served.
        temperature: Sampling temperature for the completion.

    Returns:
        The model's answer text. If the module-level ``groq_client`` is missing
        or the request fails, the error is shown via ``st.error`` and a short
        English fallback message is returned instead.
    """
    if groq_client is None:
        st.error("Groq Client not initialized. Cannot generate answer.")
        return "Error: LLM client not available."
    context = "\n\n".join(chunk_text for chunk_text, _distance in retrieved_chunks_data)
    # The prompt stays in German: it instructs the LLM acting as a
    # German-speaking tutor.
    prompt_template = f"""Beantworte die folgende Frage ausschließlich basierend auf dem bereitgestellten Kontext aus den Lehrmaterialien zur Business IT Strategie.
Antworte auf Deutsch.
Kontext:
{context}
Frage: {query}
Antwort:
"""
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template}],
            model=model,
            temperature=temperature,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # UI boundary: surface the error in the app and return a fallback text
        # instead of crashing the Streamlit script.
        st.error(f"Error during LLM request in generate_answer_app: {e}")
        return "An error occurred while generating the answer."
# Leading list markers the LLM tends to emit: "1." / "1)" and "-" / "*" / "•".
_EXPANSION_NUMBER_PREFIX_RE = re.compile(r"^\s*\d+[.)]\s*")
_EXPANSION_BULLET_PREFIX_RE = re.compile(r"^\s*[-*•]\s*")
# Lower-cased introductory phrases that are never useful as queries.
_EXPANSION_PREAMBLES = ("here are", "sicher, hier sind")


def _clean_expansion_line(line):
    """Strip list markers, markdown emphasis and surrounding quotes from one line."""
    line = _EXPANSION_NUMBER_PREFIX_RE.sub("", line.strip())
    line = _EXPANSION_BULLET_PREFIX_RE.sub("", line)
    return line.strip().strip("*_").strip().strip("\"'„“”").strip()


def _is_expansion_noise(line):
    """Return True for lines that are not usable queries (too short, preambles, headers)."""
    return (
        len(line) < 5
        or line.lower().startswith(_EXPANSION_PREAMBLES)
        # Section headers such as "Schlüsselbegriffe:" or
        # "**Alternative Formulierungen:**" end with a colon.
        or line.endswith(":")
    )


def _parse_expansion_output(text):
    """Turn raw LLM output into a list of candidate queries or keywords."""
    cleaned = [
        candidate
        for candidate in (_clean_expansion_line(raw) for raw in text.splitlines())
        if not _is_expansion_noise(candidate)
    ]
    # A single remaining line with commas is treated as a keyword list.
    if len(cleaned) == 1 and "," in cleaned[0]:
        terms = (term.strip() for term in cleaned[0].split(","))
        return [term for term in terms if len(term) > 4]
    return cleaned


def _merge_unique_queries(original_query, candidates, limit):
    """Prepend original_query, drop case-insensitive duplicates, keep at most limit items."""
    merged = [original_query]
    seen = {original_query.casefold()}
    for candidate in candidates:
        key = candidate.casefold()
        if key not in seen:
            seen.add(key)
            merged.append(candidate)
    return merged[:limit]


def expand_query_with_llm_app(original_query, llm_client_app, *, max_queries=4, model="llama3-8b-8192"):
    """Expand a user question with LLM-generated phrasings or keywords.

    Args:
        original_query: The user's question (German).
        llm_client_app: Groq client used for the expansion request; the global
            ``groq_client`` is expected here.
        max_queries: Maximum number of queries returned, including the original.
        model: Groq chat model ID used for the expansion.

    Returns:
        A list starting with ``original_query``, followed by unique expansions
        (compared case-insensitively), with at most ``max_queries`` entries. If the
        client is missing or the LLM request fails, a warning is shown and
        ``[original_query]`` is returned.
    """
    if llm_client_app is None:
        st.warning("LLM client for query expansion not initialized.")
        return [original_query]
    # The prompt stays in German, matching the German course material.
    prompt_template_expansion = f"""Gegeben ist die folgende Nutzerfrage zum Thema "Business IT Strategie": "{original_query}"
Bitte generiere 2-3 alternative Formulierungen dieser Frage ODER eine Liste von 3-5 sehr relevanten Schlüsselbegriffen/Konzepten,
die helfen würden, in einer Wissensdatenbank nach Antworten zu dieser Frage zu suchen.
Formatiere die Ausgabe klar, z.B. als nummerierte Liste für alternative Fragen oder als kommaseparierte Liste für Schlüsselbegriffe.
Gib NUR die alternativen Formulierungen oder die Schlüsselbegriffe aus. Keine Einleitungssätze.
"""
    try:
        chat_completion = llm_client_app.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template_expansion}],
            model=model,
            temperature=0.5,
        )
        # content may be None in edge cases; treat that as "no expansions".
        expanded_terms_text = chat_completion.choices[0].message.content or ""
    except Exception as e:
        # Best-effort feature: fall back to the original query on any API error.
        st.warning(f"Error during Query Expansion with LLM: {e}")
        return [original_query]
    candidates = _parse_expansion_output(expanded_terms_text)
    return _merge_unique_queries(original_query, candidates, max_queries)
def retrieve_with_expanded_queries_app(original_query, llm_client_app, retrieve_func, k_per_expansion=2, *, max_chunks=5):
    """Retrieve context chunks for the original query plus its LLM expansions.

    Args:
        original_query: The user's question.
        llm_client_app: LLM client passed on to ``expand_query_with_llm_app``.
        retrieve_func: Callable ``(query, k=...) -> [(chunk_text, distance), ...]``.
        k_per_expansion: Number of chunks requested per (expanded) query.
        max_chunks: Maximum number of unique chunks returned.

    Returns:
        Up to ``max_chunks`` ``(chunk_text, distance)`` tuples, unique by chunk
        text, sorted by ascending distance. Each chunk keeps the smallest
        distance seen across all queries.
    """
    expanded_queries = expand_query_with_llm_app(original_query, llm_client_app)

    best_distance = {}
    for expanded_query in expanded_queries:
        for chunk_text, distance in retrieve_func(expanded_query, k=k_per_expansion):
            # Distances from different queries are compared directly; this
            # assumes a metric where smaller means closer (e.g. L2), consistent
            # with the ascending sort below.
            if chunk_text not in best_distance or distance < best_distance[chunk_text]:
                best_distance[chunk_text] = distance

    ranked_chunks = sorted(best_distance.items(), key=lambda item: item[1])
    return ranked_chunks[:max_chunks]
# --- Streamlit UI ---
# Set to True to show the retrieved context snippets below the answer
# (useful for debugging retrieval quality or for transparency).
SHOW_RETRIEVED_CONTEXT = False

st.title("🎓 RAG Study Tutor for Business IT Strategy")
st.write("Ask your questions about the content of the lecture notes and case studies (in German).")
# User query input field (the label stays German for the German-speaking user).
user_query_streamlit = st.text_input("Deine Frage:", "")
# Option to use query expansion.
use_expansion = st.checkbox("Use Query Expansion (may improve results for some questions)", value=True)

# Whitespace-only input should not trigger a search or an LLM call.
user_query = user_query_streamlit.strip()
# Explicit None checks, matching the helper functions; the chunk list must
# also be non-empty for retrieval to return anything meaningful.
models_ready = (
    faiss_index is not None
    and embedding_model is not None
    and groq_client is not None
    and bool(chunks_data)
)

if user_query:
    if models_ready:
        with st.spinner("Searching for relevant information and generating answer..."):
            if use_expansion:
                st.caption("Query expansion is active...")
                retrieved_chunks = retrieve_with_expanded_queries_app(
                    user_query, groq_client, retrieve_relevant_chunks_app, k_per_expansion=2
                )
            else:
                st.caption("Direct retrieval...")
                retrieved_chunks = retrieve_relevant_chunks_app(user_query, k=3)

            if retrieved_chunks:
                if SHOW_RETRIEVED_CONTEXT:
                    with st.expander("Show retrieved context snippets (German)"):
                        for i, (chunk, dist) in enumerate(retrieved_chunks, start=1):
                            st.caption(f"Chunk {i} (Distance: {dist:.2f})")
                            st.markdown(f"_{chunk[:200]}..._")
                            st.divider()
                answer = generate_answer_app(user_query, retrieved_chunks)
                st.subheader("Tutor's Answer:")
                st.markdown(answer)  # The answer is in German.
            else:
                st.warning("No relevant information could be found for your query.")
    else:
        st.error("The application could not be initialized correctly. Please check for error messages related to model or data loading.")

st.sidebar.header("About this Project")
st.sidebar.info(
    "This RAG application was developed as part of the 'AI Applications' module. "
    "It uses Sentence Transformers for embeddings, FAISS for vector search, "
    "and an LLM via Groq for answer generation."
)