# app.py
import streamlit as st
import os
import faiss
import pickle
from sentence_transformers import SentenceTransformer
from groq import Groq
from dotenv import load_dotenv
import re # Import regular expressions for expand_query_with_llm_app
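# Assumed runtime dependencies (typically pinned in the Space's requirements.txt):
# streamlit, faiss-cpu, sentence-transformers, groq, python-dotenv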
# --- Page Configuration (MUST BE THE FIRST STREAMLIT COMMAND) ---
st.set_page_config(page_title="RAG BITS Tutor", page_icon="🎓")
# --- Configuration and Model Loading (best outside functions to use caching) ---
@st.cache_resource # Important for caching large models and data
def load_models_and_data():
    # Load environment variables (if .env file is present in the Space)
    load_dotenv()
    groq_api_key_app = os.getenv("GROQ_API_KEY")
    # Paths to the index and chunks
    output_folder = "faiss_index_bits"  # Must be present in the HF Space
    index_path = os.path.join(output_folder, "bits_tutor.index")
    chunks_path = os.path.join(output_folder, "bits_chunks.pkl")
    # Load FAISS Index
    if not os.path.exists(index_path):
        st.error(f"FAISS Index not found at: {index_path}")
        return None, None, None, None
    index_loaded = faiss.read_index(index_path)
    # Load Chunks
    if not os.path.exists(chunks_path):
        st.error(f"Chunks file not found at: {chunks_path}")
        return None, None, None, None
    with open(chunks_path, "rb") as f:
        chunks_loaded = pickle.load(f)
    # Load Embedding Model
    embedding_model_name_app = "Sahajtomar/German-semantic"
    embedding_model_loaded = SentenceTransformer(embedding_model_name_app)
    # Initialize Groq Client
    if not groq_api_key_app:
        st.error("GROQ_API_KEY not found. Please add it as a Secret in the Hugging Face Space settings.")
        return None, None, None, None
    groq_client_loaded = Groq(api_key=groq_api_key_app)
    return index_loaded, chunks_loaded, embedding_model_loaded, groq_client_loaded
# Load models and data when the app starts
# Important: The function load_models_and_data() uses st.error(), which is a Streamlit command.
# Therefore, st.set_page_config() must be called BEFORE the first possible call to st.error().
faiss_index, chunks_data, embedding_model, groq_client = load_models_and_data()
# --- Core RAG Functions (adapted for Streamlit app) ---
def retrieve_relevant_chunks_app(query, k=5):
    # This function uses the globally loaded embedding_model, faiss_index, and chunks_data
    if embedding_model is None or faiss_index is None or chunks_data is None:
        st.warning("Models or data not loaded correctly. Cannot retrieve chunks.")
        return []
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = faiss_index.search(query_embedding, k)
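    # Note: faiss_index.search returns one row of (distances, indices) per query vector.
    # Assuming an L2-based index (as is typical), a smaller distance means a closer match,
    # which is why later code sorts chunks by ascending distance.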
    retrieved_chunks_data = [(chunks_data[i], distances[0][j]) for j, i in enumerate(indices[0])]
    return retrieved_chunks_data
def generate_answer_app(query, retrieved_chunks_data):
    # This function uses the globally loaded groq_client
    if groq_client is None:
        st.error("Groq Client not initialized. Cannot generate answer.")
        return "Error: LLM client not available."
    context = "\n\n".join([chunk_text for chunk_text, dist in retrieved_chunks_data])
    # This prompt_template remains in German as it instructs the LLM for the German-speaking tutor
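    # (English gloss of the prompt below: "Answer the following question based solely on the
    # provided context from the course materials on Business IT Strategy. Answer in German.
    # Context: ... Question: ... Answer:")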
    prompt_template = f"""Beantworte die folgende Frage ausschließlich basierend auf dem bereitgestellten Kontext aus den Lehrmaterialien zur Business IT Strategie.
Antworte auf Deutsch.
Kontext:
{context}
Frage: {query}
Antwort:
"""
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template}],
            model="llama3-70b-8192",
            temperature=0.3,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # Developer-facing error, UI will show a generic message or this
        st.error(f"Error during LLM request in generate_answer_app: {e}")
        return "An error occurred while generating the answer."
def expand_query_with_llm_app(original_query, llm_client_app):
    # This function uses the passed llm_client_app (which should be the global groq_client)
    if llm_client_app is None:
        st.warning("LLM client for query expansion not initialized.")
        return [original_query]
    # This prompt_template remains in German
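    # (English gloss of the expansion prompt below: "Given the following user question on
    # 'Business IT Strategie', generate 2-3 alternative phrasings of the question OR a list of
    # 3-5 highly relevant key terms/concepts that would help search a knowledge base for answers.
    # Format the output clearly, e.g. as a numbered list for alternative questions or a
    # comma-separated list for key terms. Output ONLY the alternatives or key terms, no preamble.")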
    prompt_template_expansion = f"""Gegeben ist die folgende Nutzerfrage zum Thema "Business IT Strategie": "{original_query}"
Bitte generiere 2-3 alternative Formulierungen dieser Frage ODER eine Liste von 3-5 sehr relevanten Schlüsselbegriffen/Konzepten,
die helfen würden, in einer Wissensdatenbank nach Antworten zu dieser Frage zu suchen.
Formatiere die Ausgabe klar, z.B. als nummerierte Liste für alternative Fragen oder als kommaseparierte Liste für Schlüsselbegriffe.
Gib NUR die alternativen Formulierungen oder die Schlüsselbegriffe aus. Keine Einleitungssätze.
"""
    try:
        chat_completion = llm_client_app.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt_template_expansion,
                }
            ],
            model="llama3-8b-8192",
            temperature=0.5,
        )
        expanded_terms_text = chat_completion.choices[0].message.content
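        # Post-process the raw LLM output: strip list numbering and bullet markers from each
        # line, then drop empty lines, very short fragments, and typical preamble phrases the
        # model sometimes adds despite the instructions (the filter list below is heuristic).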
        cleaned_queries = []
        potential_queries = expanded_terms_text.split('\n')
        for line in potential_queries:
            line = line.strip()
            line = re.sub(r"^\s*\d+\.\s*", "", line)
            line = re.sub(r"^\s*[-\*]\s*", "", line)
            line = line.strip()
            if not line or \
               line.lower().startswith("here are") or \
               line.lower().startswith("sicher, hier sind") or \
               line.lower().startswith("alternative formulierungen:") or \
               line.lower().startswith("*alternative formulierungen:**") or \
               len(line) < 5:
                continue
            cleaned_queries.append(line)
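        # If the model answered with a single comma-separated keyword list instead of one
        # alternative per line, split that line into individual terms.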
        if len(cleaned_queries) == 1 and ',' in cleaned_queries[0] and len(cleaned_queries[0].split(',')) > 1:
            final_expanded_list = [term.strip() for term in cleaned_queries[0].split(',') if term.strip() and len(term.strip()) > 4]
        else:
            final_expanded_list = cleaned_queries
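        # Keep the original query first, append expansions that are not case-insensitive
        # duplicates, and cap the total at 4 queries to limit retrieval cost.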
        all_queries = [original_query]
        for q_exp in final_expanded_list:
            is_duplicate = False
            for q_all in all_queries:
                if q_all.lower() == q_exp.lower():
                    is_duplicate = True
                    break
            if not is_duplicate:
                all_queries.append(q_exp)
        return all_queries[:4]
    except Exception as e:
        st.warning(f"Error during Query Expansion with LLM: {e}")
        return [original_query]
def retrieve_with_expanded_queries_app(original_query, llm_client_app, retrieve_func, k_per_expansion=2):
    expanded_queries = expand_query_with_llm_app(original_query, llm_client_app)
    # For UI feedback / debugging:
    # st.write(f"Using the following queries for retrieval after expansion:")
    # for i, eq_query in enumerate(expanded_queries):
    #     st.caption(f" ExpQuery {i}: {eq_query}")
    all_retrieved_chunks_data = []
    for eq_query in expanded_queries:
        retrieved_for_eq = retrieve_func(eq_query, k=k_per_expansion)
        all_retrieved_chunks_data.extend(retrieved_for_eq)
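    # Deduplicate chunks that were retrieved by more than one query, keeping the best
    # (smallest) distance per chunk, then sort ascending and keep the top 5 for the context.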
    unique_chunks_dict = {}
    for chunk_text, distance in all_retrieved_chunks_data:
        if chunk_text not in unique_chunks_dict or distance < unique_chunks_dict[chunk_text]:
            unique_chunks_dict[chunk_text] = distance
    sorted_unique_chunks_data = sorted(unique_chunks_dict.items(), key=lambda item: item[1])
    final_chunks_for_context = sorted_unique_chunks_data[:5]
    # For UI feedback / debugging:
    # st.write(f"\n{len(final_chunks_for_context)} unique chunks were selected for the context.")
    return final_chunks_for_context
# --- Streamlit UI ---
st.title("🎓 RAG Study Tutor for Business IT Strategy")
st.write("Ask your questions about the content of the lecture notes and case studies (in German).")
# User query input field (the label stays German for the user; "Deine Frage:" means "Your question:")
user_query_streamlit = st.text_input("Deine Frage:", "")
# Option to use query expansion
use_expansion = st.checkbox("Use Query Expansion (may improve results for some questions)", value=True)
if user_query_streamlit:
    # Check if models and data are loaded successfully before proceeding
    if faiss_index and chunks_data and embedding_model and groq_client:
        with st.spinner("Searching for relevant information and generating answer..."):  # Loading spinner
            retrieved_chunks = []
            if use_expansion:
                st.caption("Query expansion is active...")
                retrieved_chunks = retrieve_with_expanded_queries_app(user_query_streamlit, groq_client, retrieve_relevant_chunks_app, k_per_expansion=2)
            else:
                st.caption("Direct retrieval...")
                retrieved_chunks = retrieve_relevant_chunks_app(user_query_streamlit, k=3)
            if retrieved_chunks:
                # Optional display of retrieved context snippets (for debugging or transparency):
                # with st.expander("Show retrieved context snippets (German)"):
                #     for i, (chunk, dist) in enumerate(retrieved_chunks):
                #         st.caption(f"Chunk {i+1} (Distance: {dist:.2f})")
                #         st.markdown(f"_{chunk[:200]}..._")
                #         st.divider()
                answer = generate_answer_app(user_query_streamlit, retrieved_chunks)
                st.subheader("Tutor's Answer:")
                st.markdown(answer)  # Displaying German answer
            else:
                st.warning("No relevant information could be found for your query.")
    else:
        st.error("The application could not be initialized correctly. Please check for error messages related to model or data loading.")
st.sidebar.header("About this Project")
st.sidebar.info(
    "This RAG application was developed as part of the 'AI Applications' module. "
    "It uses Sentence Transformers for embeddings, FAISS for vector search, "
    "and an LLM via Groq for answer generation."
)