Update app.py
app.py
CHANGED
@@ -10,52 +10,226 @@ from dotenv import load_dotenv
 import re  # Import regular expressions for expand_query_with_llm_app

 # --- Page Configuration (MUST BE THE FIRST STREAMLIT COMMAND) ---
+st.set_page_config(page_title="RAG BITS Tutor", page_icon="🎓")

+# --- Configuration and Model Loading (best kept outside functions so caching applies) ---
+@st.cache_resource  # Important for caching large models and data
 def load_models_and_data():
+    # Load environment variables (if a .env file is present in the Space)
     load_dotenv()
+    groq_api_key_app = os.getenv("GROQ_API_KEY")

+    # Paths to the index and the chunks
+    output_folder = "faiss_index_bits"  # Must be present in the HF Space
     index_path = os.path.join(output_folder, "bits_tutor.index")
     chunks_path = os.path.join(output_folder, "bits_chunks.pkl")

+    # Load the FAISS index
     if not os.path.exists(index_path):
+        st.error(f"FAISS Index not found at: {index_path}")
         return None, None, None, None
     index_loaded = faiss.read_index(index_path)

+    # Load the chunks
     if not os.path.exists(chunks_path):
+        st.error(f"Chunks file not found at: {chunks_path}")
         return None, None, None, None
     with open(chunks_path, "rb") as f:
         chunks_loaded = pickle.load(f)

+    # Load the embedding model
     embedding_model_name_app = "Sahajtomar/German-semantic"
     embedding_model_loaded = SentenceTransformer(embedding_model_name_app)

+    # Initialize the Groq client
     if not groq_api_key_app:
+        st.error("GROQ_API_KEY not found. Please add it as a Secret in the Hugging Face Space settings.")
         return None, None, None, None
     groq_client_loaded = Groq(api_key=groq_api_key_app)

     return index_loaded, chunks_loaded, embedding_model_loaded, groq_client_loaded

+# Load models and data when the app starts.
+# Important: load_models_and_data() uses st.error(), which is a Streamlit command,
+# so st.set_page_config() must be called BEFORE the first possible call to st.error().
 faiss_index, chunks_data, embedding_model, groq_client = load_models_and_data()
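
A note on @st.cache_resource above: the decorated loader runs once per server process, and later script reruns reuse the returned objects, which is why the heavy SentenceTransformer and FAISS loads sit inside it. A minimal sketch of the pattern with a toy loader (not the app's real objects):

    import streamlit as st

    @st.cache_resource  # loader body runs once; reruns reuse the returned object
    def load_big_object():
        return list(range(1_000_000))  # stand-in for a model or an index

    obj = load_big_object()  # cheap on every rerun after the first
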
+# --- Core RAG Functions (adapted for Streamlit app) ---

+def retrieve_relevant_chunks_app(query, k=5):
+    # This function uses the globally loaded embedding_model, faiss_index, and chunks_data
+    if embedding_model is None or faiss_index is None or chunks_data is None:
+        st.warning("Models or data not loaded correctly. Cannot retrieve chunks.")
+        return []
+    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
+    distances, indices = faiss_index.search(query_embedding, k)
+    retrieved_chunks_data = [(chunks_data[i], distances[0][j]) for j, i in enumerate(indices[0])]
+    return retrieved_chunks_data
+
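
For context on the search call above: faiss_index.search takes an (n_queries, dim) float32 array and returns two (n_queries, k) arrays, which is why the list comprehension pairs indices[0] with distances[0][j]. A toy sketch using IndexFlatL2 (the app's real index type is whatever was serialized into bits_tutor.index):

    import numpy as np
    import faiss

    dim = 8
    index = faiss.IndexFlatL2(dim)  # toy index; exact L2 search
    index.add(np.random.rand(100, dim).astype("float32"))

    query = np.random.rand(1, dim).astype("float32")
    distances, indices = index.search(query, 5)  # both arrays have shape (1, 5)
    # indices[0][j] is the row id of the j-th nearest vector; distances[0][j] is its distance
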
+def generate_answer_app(query, retrieved_chunks_data):
+    # This function uses the globally loaded groq_client
+    if groq_client is None:
+        st.error("Groq Client not initialized. Cannot generate answer.")
+        return "Error: LLM client not available."
+
+    context = "\n\n".join([chunk_text for chunk_text, dist in retrieved_chunks_data])
+
+    # This prompt_template remains in German as it instructs the LLM for the German-speaking tutor
+    prompt_template = f"""Beantworte die folgende Frage ausschließlich basierend auf dem bereitgestellten Kontext aus den Lehrmaterialien zur Business IT Strategie.
+Antworte auf Deutsch.
+
+Kontext:
+{context}
+
+Frage: {query}
+
+Antwort:
+"""
+    try:
+        chat_completion = groq_client.chat.completions.create(
+            messages=[{"role": "user", "content": prompt_template}],
+            model="llama3-70b-8192",
+            temperature=0.3,
+        )
+        return chat_completion.choices[0].message.content
+    except Exception as e:
+        # Developer-facing error, UI will show a generic message or this
+        st.error(f"Error during LLM request in generate_answer_app: {e}")
+        return "An error occurred while generating the answer."
+
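
generate_answer_app expects the (chunk_text, distance) tuples the retrieval helpers return; a hypothetical call, with an invented query string, just to show the contract:

    # Hypothetical wiring, mirroring what the UI code further down does:
    chunks = retrieve_relevant_chunks_app("Was ist Business-IT-Alignment?", k=3)
    if chunks:
        st.markdown(generate_answer_app("Was ist Business-IT-Alignment?", chunks))
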
+def expand_query_with_llm_app(original_query, llm_client_app):
+    # This function uses the passed llm_client_app (which should be the global groq_client)
+    if llm_client_app is None:
+        st.warning("LLM client for query expansion not initialized.")
+        return [original_query]
+
+    # This prompt_template remains in German
+    prompt_template_expansion = f"""Gegeben ist die folgende Nutzerfrage zum Thema "Business IT Strategie": "{original_query}"
+
+Bitte generiere 2-3 alternative Formulierungen dieser Frage ODER eine Liste von 3-5 sehr relevanten Schlüsselbegriffen/Konzepten,
+die helfen würden, in einer Wissensdatenbank nach Antworten zu dieser Frage zu suchen.
+Formatiere die Ausgabe klar, z.B. als nummerierte Liste für alternative Fragen oder als kommaseparierte Liste für Schlüsselbegriffe.
+Gib NUR die alternativen Formulierungen oder die Schlüsselbegriffe aus. Keine Einleitungssätze.
+"""
+    try:
+        chat_completion = llm_client_app.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt_template_expansion,
+                }
+            ],
+            model="llama3-8b-8192",
+            temperature=0.5,
+        )
+        expanded_terms_text = chat_completion.choices[0].message.content
+
+        cleaned_queries = []
+        potential_queries = expanded_terms_text.split('\n')
+
+        for line in potential_queries:
+            line = line.strip()
+            line = re.sub(r"^\s*\d+\.\s*", "", line)  # strip leading "1. "-style numbering
+            line = re.sub(r"^\s*[-\*]\s*", "", line)  # strip leading "-" or "*" bullets
+            line = line.strip()
+
+            if not line or \
+               line.lower().startswith("here are") or \
+               line.lower().startswith("sicher, hier sind") or \
+               line.lower().startswith("alternative formulierungen:") or \
+               line.lower().startswith("**alternative formulierungen:**") or \
+               len(line) < 5:
+                continue
+            cleaned_queries.append(line)
+
+        # A single comma-separated line is treated as a keyword list
+        if len(cleaned_queries) == 1 and ',' in cleaned_queries[0] and len(cleaned_queries[0].split(',')) > 1:
+            final_expanded_list = [term.strip() for term in cleaned_queries[0].split(',') if term.strip() and len(term.strip()) > 4]
+        else:
+            final_expanded_list = cleaned_queries
+
+        # Keep the original query first and drop case-insensitive duplicates
+        all_queries = [original_query]
+        for q_exp in final_expanded_list:
+            is_duplicate = False
+            for q_all in all_queries:
+                if q_all.lower() == q_exp.lower():
+                    is_duplicate = True
+                    break
+            if not is_duplicate:
+                all_queries.append(q_exp)
+
+        return all_queries[:4]
+
+    except Exception as e:
+        st.warning(f"Error during Query Expansion with LLM: {e}")
+        return [original_query]
+
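
Concretely, the two re.sub calls in the loop strip list markers so only the query text survives; a small run on an invented sample of expansion output:

    # Invented sample of what the expansion LLM might return:
    sample = "1. Wie unterstützt die IT die Geschäftsstrategie?\n- Strategisches Alignment"
    for raw in sample.split("\n"):
        cleaned = re.sub(r"^\s*\d+\.\s*", "", raw.strip())      # "1. " -> ""
        cleaned = re.sub(r"^\s*[-\*]\s*", "", cleaned).strip()  # "- " / "* " -> ""
        print(cleaned)
    # -> "Wie unterstützt die IT die Geschäftsstrategie?" and "Strategisches Alignment"
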
+def retrieve_with_expanded_queries_app(original_query, llm_client_app, retrieve_func, k_per_expansion=2):
+    expanded_queries = expand_query_with_llm_app(original_query, llm_client_app)
+
+    # For UI feedback / debugging:
+    # st.write(f"Using the following queries for retrieval after expansion:")
+    # for i, eq_query in enumerate(expanded_queries):
+    #     st.caption(f" ExpQuery {i}: {eq_query}")
+
+    all_retrieved_chunks_data = []
+    for eq_query in expanded_queries:
+        retrieved_for_eq = retrieve_func(eq_query, k=k_per_expansion)
+        all_retrieved_chunks_data.extend(retrieved_for_eq)
+
+    # Deduplicate chunks, keeping the smallest (best) distance seen for each text
+    unique_chunks_dict = {}
+    for chunk_text, distance in all_retrieved_chunks_data:
+        if chunk_text not in unique_chunks_dict or distance < unique_chunks_dict[chunk_text]:
+            unique_chunks_dict[chunk_text] = distance
+
+    sorted_unique_chunks_data = sorted(unique_chunks_dict.items(), key=lambda item: item[1])
+
+    final_chunks_for_context = sorted_unique_chunks_data[:5]
+
+    # For UI feedback / debugging:
+    # st.write(f"\n{len(final_chunks_for_context)} unique chunks were selected for the context.")
+    return final_chunks_for_context
+
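
The dictionary pass in retrieve_with_expanded_queries_app is a min-by-distance dedup followed by a sort; a toy run with invented chunk texts and distances:

    hits = [("chunk A", 0.42), ("chunk B", 0.35), ("chunk A", 0.30)]
    best = {}
    for text, dist in hits:
        if text not in best or dist < best[text]:
            best[text] = dist
    print(sorted(best.items(), key=lambda kv: kv[1]))
    # -> [('chunk A', 0.30), ('chunk B', 0.35)]
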
+# --- Streamlit UI ---
+st.title("🎓 RAG Study Tutor for Business IT Strategy")
+st.write("Ask your questions about the content of the lecture notes and case studies (in German).")
+
+# User query input field (stays German for the user)
+user_query_streamlit = st.text_input("Deine Frage:", "")
+
+# Option to use query expansion
+use_expansion = st.checkbox("Use Query Expansion (may improve results for some questions)", value=True)
+
+if user_query_streamlit:
+    # Check that models and data were loaded successfully before proceeding
+    if faiss_index and chunks_data and embedding_model and groq_client:
+        with st.spinner("Searching for relevant information and generating an answer..."):  # Loading spinner
+            retrieved_chunks = []
+            if use_expansion:
+                st.caption("Query expansion is active...")
+                retrieved_chunks = retrieve_with_expanded_queries_app(user_query_streamlit, groq_client, retrieve_relevant_chunks_app, k_per_expansion=2)
+            else:
+                st.caption("Direct retrieval...")
+                retrieved_chunks = retrieve_relevant_chunks_app(user_query_streamlit, k=3)
+
+            if retrieved_chunks:
+                # Optional display of retrieved context snippets (for debugging or transparency):
+                # with st.expander("Show retrieved context snippets (German)"):
+                #     for i, (chunk, dist) in enumerate(retrieved_chunks):
+                #         st.caption(f"Chunk {i+1} (Distance: {dist:.2f})")
+                #         st.markdown(f"_{chunk[:200]}..._")
+                #         st.divider()
+
+                answer = generate_answer_app(user_query_streamlit, retrieved_chunks)
+                st.subheader("Tutor's Answer:")
+                st.markdown(answer)  # Displays the German answer
+            else:
+                st.warning("No relevant information could be found for your query.")
+    else:
+        st.error("The application could not be initialized correctly. Please check for error messages related to model or data loading.")
+
+st.sidebar.header("About this Project")
+st.sidebar.info(
+    "This RAG application was developed as part of the 'AI Applications' module. "
+    "It uses Sentence Transformers for embeddings, FAISS for vector search, "
+    "and an LLM via Groq for answer generation."
+)
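
To try the app outside the Space, the usual Streamlit workflow should apply: install the packages the imports suggest (streamlit, faiss-cpu, sentence-transformers, groq, python-dotenv; exact pins are not part of this commit), provide GROQ_API_KEY via a .env file or the environment, and start the app with: streamlit run app.py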