Spaces:
Sleeping
Sleeping
# app.py | |
import streamlit as st | |
import os | |
import faiss | |
import pickle | |
from sentence_transformers import SentenceTransformer | |
from groq import Groq | |
from dotenv import load_dotenv | |
import re # Import regular expressions for expand_query_with_llm_app | |
# --- Page Configuration (MUST BE THE FIRST STREAMLIT COMMAND) ---
st.set_page_config(page_title="RAG BITS Tutor", page_icon="🎓")


# --- Configuration and Model Loading ---
# Streamlit re-executes this whole script on every widget interaction, so the
# heavy resources are cached once per server process with st.cache_resource
# instead of being reloaded on every rerun.
@st.cache_resource
def load_models_and_data():
    """Load the FAISS index, text chunks, embedding model and Groq client.

    Returns:
        ``(index, chunks, embedding_model, groq_client)`` on success. If a
        required file or the ``GROQ_API_KEY`` is missing, an error is shown
        with ``st.error`` and ``(None, None, None, None)`` is returned.
    """
    # Load environment variables (if a .env file is present in the Space).
    load_dotenv()
    groq_api_key_app = os.getenv("GROQ_API_KEY")

    # Paths to the index and chunks; the folder must be present in the HF Space.
    output_folder = "faiss_index_bits"
    index_path = os.path.join(output_folder, "bits_tutor.index")
    chunks_path = os.path.join(output_folder, "bits_chunks.pkl")

    # Validate every prerequisite before doing any expensive loading, so a
    # missing secret is reported without first loading the index and model.
    if not os.path.exists(index_path):
        st.error(f"FAISS Index not found at: {index_path}")
        return None, None, None, None
    if not os.path.exists(chunks_path):
        st.error(f"Chunks file not found at: {chunks_path}")
        return None, None, None, None
    if not groq_api_key_app:
        st.error("GROQ_API_KEY not found. Please add it as a Secret in the Hugging Face Space settings.")
        return None, None, None, None

    index_loaded = faiss.read_index(index_path)

    # SECURITY: pickle.load can execute arbitrary code. This is acceptable only
    # because the file ships with the Space; never point it at user input.
    with open(chunks_path, "rb") as f:
        chunks_loaded = pickle.load(f)

    # NOTE(review): presumably this must be the same model that built
    # bits_tutor.index (same vector dimension / embedding space) — confirm
    # against the indexing script.
    embedding_model_name_app = "Sahajtomar/German-semantic"
    embedding_model_loaded = SentenceTransformer(embedding_model_name_app)

    groq_client_loaded = Groq(api_key=groq_api_key_app)
    return index_loaded, chunks_loaded, embedding_model_loaded, groq_client_loaded


# Load models and data when the app starts (served from the cache on reruns).
# load_models_and_data() may call st.error(), a Streamlit command, so
# st.set_page_config() above must run before it.
faiss_index, chunks_data, embedding_model, groq_client = load_models_and_data()
if faiss_index is None:
    # Do not keep a failed load cached: clear it so the next rerun retries
    # (e.g. after the missing file or secret has been added).
    load_models_and_data.clear()
# --- Core RAG Functions (adapted for Streamlit app) ---
def retrieve_relevant_chunks_app(query, k=5):
    """Return the chunks nearest to ``query`` from the global FAISS index.

    Uses the module-level ``embedding_model``, ``faiss_index`` and
    ``chunks_data`` loaded at startup.

    Args:
        query: Search text to embed and look up.
        k: Number of nearest neighbours to request from FAISS.

    Returns:
        A list of ``(chunk_text, distance)`` tuples in the order FAISS returns
        them. It may hold fewer than ``k`` items when the index contains fewer
        than ``k`` vectors. It is empty (with a ``st.warning``) when the
        resources are not loaded.
    """
    if embedding_model is None or faiss_index is None or chunks_data is None:
        st.warning("Models or data not loaded correctly. Cannot retrieve chunks.")
        return []

    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = faiss_index.search(query_embedding, k)

    # FAISS pads missing neighbours with index -1. Python would read
    # chunks_data[-1] as the last chunk, yielding a bogus hit, so skip those slots.
    return [
        (chunks_data[idx], float(dist))
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]
def generate_answer_app(
    query,
    retrieved_chunks_data,
    *,
    model="llama3-70b-8192",
    temperature=0.3,
):
    """Generate a German answer to ``query`` grounded in the retrieved chunks.

    Uses the module-level ``groq_client``.

    Args:
        query: The user's question.
        retrieved_chunks_data: Iterable of ``(chunk_text, distance)`` pairs.
            Only the texts are used; they are joined into the prompt context.
        model: Groq model ID used for generation.
        temperature: Sampling temperature for the completion.

    Returns:
        The model's answer text. If the client is missing or the request
        fails, an error is shown via ``st.error`` and an error string is
        returned instead.
    """
    if groq_client is None:
        st.error("Groq Client not initialized. Cannot generate answer.")
        return "Error: LLM client not available."

    context = "\n\n".join(chunk_text for chunk_text, _distance in retrieved_chunks_data)
    # The prompt stays in German because it instructs the German-speaking tutor.
    prompt_template = f"""Beantworte die folgende Frage ausschließlich basierend auf dem bereitgestellten Kontext aus den Lehrmaterialien zur Business IT Strategie.
Antworte auf Deutsch.
Kontext:
{context}
Frage: {query}
Antwort:
"""
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template}],
            # NOTE(review): Groq retires model IDs over time; confirm the
            # default is still served (callers can override via ``model``).
            model=model,
            temperature=temperature,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # Boundary handler: surface the error in the UI and return a generic
        # message instead of crashing the Streamlit script.
        st.error(f"Error during LLM request in generate_answer_app: {e}")
        return "An error occurred while generating the answer."
# Leading list markers stripped from each line of the expansion output:
# numbered items ("1." or "1)") and bullets ("-", "*", "•", including "**").
_NUMBERED_ITEM_RE = re.compile(r"^\s*\d+[.)]\s*")
_BULLET_ITEM_RE = re.compile(r"^\s*[-*•]+\s*")
# Lower-cased prefixes of preamble/header lines the model sometimes emits
# despite being told not to; such lines are not search queries.
_EXPANSION_PREAMBLE_PREFIXES = (
    "here are",
    "sicher, hier sind",
    "alternative formulierungen:",
    "*alternative formulierungen:**",
)
# Lines (or comma-separated terms) shorter than this are discarded as noise.
_MIN_EXPANSION_LEN = 5


def _parse_expansion_output(text, original_query, max_queries):
    """Turn raw expansion text into a de-duplicated list of search queries.

    The original query always comes first. Case-insensitive duplicates are
    dropped, and the list is truncated to ``max_queries`` entries.
    """
    cleaned = []
    for raw_line in text.splitlines():
        line = _NUMBERED_ITEM_RE.sub("", raw_line.strip())
        line = _BULLET_ITEM_RE.sub("", line).strip()
        if len(line) < _MIN_EXPANSION_LEN or line.lower().startswith(_EXPANSION_PREAMBLE_PREFIXES):
            continue
        cleaned.append(line)

    # A single comma-separated line is a keyword list: split it into terms.
    if len(cleaned) == 1 and "," in cleaned[0]:
        candidates = [
            term.strip()
            for term in cleaned[0].split(",")
            if len(term.strip()) >= _MIN_EXPANSION_LEN
        ]
    else:
        candidates = cleaned

    queries = [original_query]
    seen = {original_query.lower()}
    for candidate in candidates:
        key = candidate.lower()
        if key not in seen:
            seen.add(key)
            queries.append(candidate)
    return queries[:max_queries]


def expand_query_with_llm_app(
    original_query,
    llm_client_app,
    *,
    model="llama3-8b-8192",
    temperature=0.5,
    max_queries=4,
):
    """Ask the LLM for alternative phrasings / keywords of ``original_query``.

    Args:
        original_query: The user's question.
        llm_client_app: Groq-style client exposing ``chat.completions.create``.
            If it is None, a warning is shown and only the original query is
            used.
        model: Model ID used for the expansion request.
        temperature: Sampling temperature for the expansion request.
        max_queries: Maximum number of queries returned, including the
            original (expected to be >= 1).

    Returns:
        A list that starts with ``original_query``, followed by de-duplicated
        expansions. On any LLM error it is just ``[original_query]``, with a
        ``st.warning``.
    """
    if llm_client_app is None:
        st.warning("LLM client for query expansion not initialized.")
        return [original_query]

    # The prompt stays in German.
    prompt_template_expansion = f"""Gegeben ist die folgende Nutzerfrage zum Thema "Business IT Strategie": "{original_query}"
Bitte generiere 2-3 alternative Formulierungen dieser Frage ODER eine Liste von 3-5 sehr relevanten Schlüsselbegriffen/Konzepten,
die helfen würden, in einer Wissensdatenbank nach Antworten zu dieser Frage zu suchen.
Formatiere die Ausgabe klar, z.B. als nummerierte Liste für alternative Fragen oder als kommaseparierte Liste für Schlüsselbegriffe.
Gib NUR die alternativen Formulierungen oder die Schlüsselbegriffe aus. Keine Einleitungssätze.
"""
    try:
        chat_completion = llm_client_app.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template_expansion}],
            model=model,
            temperature=temperature,
        )
        expanded_terms_text = chat_completion.choices[0].message.content
    except Exception as e:
        # Best effort: expansion is optional, so fall back to the original query.
        st.warning(f"Error during Query Expansion with LLM: {e}")
        return [original_query]

    # The message content can be None; treat that as "no expansions".
    return _parse_expansion_output(expanded_terms_text or "", original_query, max_queries)
def retrieve_with_expanded_queries_app(
    original_query,
    llm_client_app,
    retrieve_func,
    k_per_expansion=2,
    *,
    max_chunks=5,
):
    """Retrieve context chunks for the original query and its LLM expansions.

    Args:
        original_query: The user's question.
        llm_client_app: Client forwarded to ``expand_query_with_llm_app``.
        retrieve_func: Callable ``(query, k=...)`` returning a list of
            ``(chunk_text, distance)`` tuples.
        k_per_expansion: Number of chunks requested per (expanded) query.
        max_chunks: Maximum number of unique chunks returned.

    Returns:
        Up to ``max_chunks`` ``(chunk_text, distance)`` tuples. They are
        de-duplicated by chunk text (keeping the smallest distance) and sorted
        by ascending distance.
    """
    expanded_queries = expand_query_with_llm_app(original_query, llm_client_app)

    # De-duplicate across all queries, keeping the smallest distance per chunk.
    best_distance = {}
    for expanded_query in expanded_queries:
        for chunk_text, distance in retrieve_func(expanded_query, k=k_per_expansion):
            if chunk_text not in best_distance or distance < best_distance[chunk_text]:
                best_distance[chunk_text] = distance

    # NOTE(review): ascending order assumes a lower distance means more similar
    # (e.g. an L2 index). If bits_tutor.index uses inner product, this ranking
    # is inverted — confirm against the indexing script.
    return sorted(best_distance.items(), key=lambda item: item[1])[:max_chunks]
# --- Streamlit UI ---
st.title("🎓 RAG Study Tutor for Business IT Strategy")
st.write("Ask your questions about the content of the lecture notes and case studies (in German).")

# User query input field (the label stays German for the German-speaking user).
user_query_streamlit = st.text_input("Deine Frage:", "")
# Option to use query expansion.
use_expansion = st.checkbox("Use Query Expansion (may improve results for some questions)", value=True)

# Whitespace-only input counts as "no question", so it triggers no retrieval
# or LLM calls.
user_query = user_query_streamlit.strip()
if user_query:
    # Use explicit None checks rather than truthiness of third-party objects;
    # the loader returns either all resources or all None. An empty chunk list
    # is unusable too, so it also counts as "not initialized".
    resources_ready = (
        faiss_index is not None
        and embedding_model is not None
        and groq_client is not None
        and bool(chunks_data)
    )
    if resources_ready:
        with st.spinner("Searching for relevant information and generating answer..."):
            if use_expansion:
                st.caption("Query expansion is active...")
                retrieved_chunks = retrieve_with_expanded_queries_app(
                    user_query, groq_client, retrieve_relevant_chunks_app, k_per_expansion=2
                )
            else:
                st.caption("Direct retrieval...")
                retrieved_chunks = retrieve_relevant_chunks_app(user_query, k=3)

            if retrieved_chunks:
                answer = generate_answer_app(user_query, retrieved_chunks)
                st.subheader("Tutor's Answer:")
                st.markdown(answer)  # The answer text is in German.
            else:
                st.warning("No relevant information could be found for your query.")
    else:
        st.error(
            "The application could not be initialized correctly. "
            "Please check for error messages related to model or data loading."
        )

st.sidebar.header("About this Project")
st.sidebar.info(
    "This RAG application was developed as part of the 'AI Applications' module. "
    "It uses Sentence Transformers for embeddings, FAISS for vector search, "
    "and an LLM via Groq for answer generation."
)