patronmoses committed
Commit 75632fd · verified · 1 Parent(s): e26f9a8

Update app.py

Files changed (1)
  1. app.py +195 -21
app.py CHANGED
@@ -10,52 +10,226 @@ from dotenv import load_dotenv
 import re # Import regular expressions for expand_query_with_llm_app
 
 # --- Page Configuration (MUST BE THE FIRST STREAMLIT COMMAND) ---
-st.set_page_config(page_title="RAG BITS Tutor", page_icon="🎓") # Set page title and icon
+st.set_page_config(page_title="RAG BITS Tutor", page_icon="🎓")
 
-# --- Konfiguration und Modell-Laden (am besten ausserhalb von Funktionen, um Caching zu nutzen) ---
-@st.cache_resource # Wichtig für das Caching von grossen Modellen und Daten
+# --- Configuration and Model Loading (best outside functions to use caching) ---
+@st.cache_resource # Important for caching large models and data
 def load_models_and_data():
-    # Lade Umgebungsvariablen (falls .env Datei im Space vorhanden ist)
+    # Load environment variables (if .env file is present in the Space)
     load_dotenv()
-    groq_api_key_app = os.getenv("GROQ_API_KEY") # Stelle sicher, dass der Key im Space verfügbar ist (siehe Schritt 3)
+    groq_api_key_app = os.getenv("GROQ_API_KEY")
 
-    # Pfade zum Index und den Chunks
-    output_folder = "faiss_index_bits" # Muss im HF Space vorhanden sein
+    # Paths to the index and chunks
+    output_folder = "faiss_index_bits" # Must be present in the HF Space
     index_path = os.path.join(output_folder, "bits_tutor.index")
     chunks_path = os.path.join(output_folder, "bits_chunks.pkl")
 
-    # Lade FAISS Index
+    # Load FAISS Index
     if not os.path.exists(index_path):
-        st.error(f"FAISS Index nicht gefunden unter: {index_path}") # This is a Streamlit command
+        st.error(f"FAISS Index not found at: {index_path}")
         return None, None, None, None
     index_loaded = faiss.read_index(index_path)
 
-    # Lade Chunks
+    # Load Chunks
     if not os.path.exists(chunks_path):
-        st.error(f"Chunks-Datei nicht gefunden unter: {chunks_path}") # This is a Streamlit command
+        st.error(f"Chunks file not found at: {chunks_path}")
         return None, None, None, None
     with open(chunks_path, "rb") as f:
         chunks_loaded = pickle.load(f)
 
-    # Lade Embedding-Modell
+    # Load Embedding Model
     embedding_model_name_app = "Sahajtomar/German-semantic"
     embedding_model_loaded = SentenceTransformer(embedding_model_name_app)
 
-    # Initialisiere Groq Client
+    # Initialize Groq Client
     if not groq_api_key_app:
-        st.error("GROQ_API_KEY nicht gefunden. Bitte im Hugging Face Space als Secret hinzufügen.") # This is a Streamlit command
+        st.error("GROQ_API_KEY not found. Please add it as a Secret in the Hugging Face Space settings.")
        return None, None, None, None
     groq_client_loaded = Groq(api_key=groq_api_key_app)
 
     return index_loaded, chunks_loaded, embedding_model_loaded, groq_client_loaded
 
-# Lade Modelle und Daten beim Start der App
-# Wichtig: Die Funktion load_models_and_data() verwendet st.error(), was ein Streamlit-Befehl ist.
-# Daher muss st.set_page_config() VOR dem ersten möglichen Aufruf von st.error() stehen.
+# Load models and data when the app starts
+# Important: The function load_models_and_data() uses st.error(), which is a Streamlit command.
+# Therefore, st.set_page_config() must be called BEFORE the first possible call to st.error().
 faiss_index, chunks_data, embedding_model, groq_client = load_models_and_data()
 
-# ... (Rest deines app.py Skripts bleibt gleich) ...
-
-# --- Streamlit UI (kommt nach load_models_and_data) ---
-st.title("🎓 RAG Study Tutor for Business IT Strategy")
-# ... etc. ...
+# --- Core RAG Functions (adapted for Streamlit app) ---
+
+def retrieve_relevant_chunks_app(query, k=5):
+    # This function uses the globally loaded embedding_model, faiss_index, and chunks_data
+    if embedding_model is None or faiss_index is None or chunks_data is None:
+        st.warning("Models or data not loaded correctly. Cannot retrieve chunks.")
+        return []
+    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
+    distances, indices = faiss_index.search(query_embedding, k)
+    retrieved_chunks_data = [(chunks_data[i], distances[0][j]) for j, i in enumerate(indices[0])]
+    return retrieved_chunks_data
+
+def generate_answer_app(query, retrieved_chunks_data):
+    # This function uses the globally loaded groq_client
+    if groq_client is None:
+        st.error("Groq Client not initialized. Cannot generate answer.")
+        return "Error: LLM client not available."
+
+    context = "\n\n".join([chunk_text for chunk_text, dist in retrieved_chunks_data])
+
+    # This prompt_template remains in German as it instructs the LLM for the German-speaking tutor
+    prompt_template = f"""Beantworte die folgende Frage ausschließlich basierend auf dem bereitgestellten Kontext aus den Lehrmaterialien zur Business IT Strategie.
+Antworte auf Deutsch.
+
+Kontext:
+{context}
+
+Frage: {query}
+
+Antwort:
+"""
+    try:
+        chat_completion = groq_client.chat.completions.create(
+            messages=[{"role": "user", "content": prompt_template}],
+            model="llama3-70b-8192",
+            temperature=0.3,
+        )
+        return chat_completion.choices[0].message.content
+    except Exception as e:
+        # Developer-facing error, UI will show a generic message or this
+        st.error(f"Error during LLM request in generate_answer_app: {e}")
+        return "An error occurred while generating the answer."
+
+def expand_query_with_llm_app(original_query, llm_client_app):
+    # This function uses the passed llm_client_app (which should be the global groq_client)
+    if llm_client_app is None:
+        st.warning("LLM client for query expansion not initialized.")
+        return [original_query]
+
+    # This prompt_template remains in German
+    prompt_template_expansion = f"""Gegeben ist die folgende Nutzerfrage zum Thema "Business IT Strategie": "{original_query}"
+
+Bitte generiere 2-3 alternative Formulierungen dieser Frage ODER eine Liste von 3-5 sehr relevanten Schlüsselbegriffen/Konzepten,
+die helfen würden, in einer Wissensdatenbank nach Antworten zu dieser Frage zu suchen.
+Formatiere die Ausgabe klar, z.B. als nummerierte Liste für alternative Fragen oder als kommaseparierte Liste für Schlüsselbegriffe.
+Gib NUR die alternativen Formulierungen oder die Schlüsselbegriffe aus. Keine Einleitungssätze.
+"""
+    try:
+        chat_completion = llm_client_app.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt_template_expansion,
+                }
+            ],
+            model="llama3-8b-8192",
+            temperature=0.5,
+        )
+        expanded_terms_text = chat_completion.choices[0].message.content
+
+        cleaned_queries = []
+        potential_queries = expanded_terms_text.split('\n')
+
+        for line in potential_queries:
+            line = line.strip()
+            line = re.sub(r"^\s*\d+\.\s*", "", line)
+            line = re.sub(r"^\s*[-\*]\s*", "", line)
+            line = line.strip()
+
+            if not line or \
+               line.lower().startswith("here are") or \
+               line.lower().startswith("sicher, hier sind") or \
+               line.lower().startswith("alternative formulierungen:") or \
+               line.lower().startswith("*alternative formulierungen:**") or \
+               len(line) < 5:
+                continue
+            cleaned_queries.append(line)
+
+        if len(cleaned_queries) == 1 and ',' in cleaned_queries[0] and len(cleaned_queries[0].split(',')) > 1:
+            final_expanded_list = [term.strip() for term in cleaned_queries[0].split(',') if term.strip() and len(term.strip()) > 4]
+        else:
+            final_expanded_list = cleaned_queries
+
+        all_queries = [original_query]
+        for q_exp in final_expanded_list:
+            is_duplicate = False
+            for q_all in all_queries:
+                if q_all.lower() == q_exp.lower():
+                    is_duplicate = True
+                    break
+            if not is_duplicate:
+                all_queries.append(q_exp)
+
+        return all_queries[:4]
+
+    except Exception as e:
+        st.warning(f"Error during Query Expansion with LLM: {e}")
+        return [original_query]
+
+def retrieve_with_expanded_queries_app(original_query, llm_client_app, retrieve_func, k_per_expansion=2):
+    expanded_queries = expand_query_with_llm_app(original_query, llm_client_app)
+
+    # For UI feedback / debugging
+    # st.write(f"Using the following queries for retrieval after expansion:")
+    # for i, eq_query in enumerate(expanded_queries):
+    #     st.caption(f"  ExpQuery {i}: {eq_query}")
+
+    all_retrieved_chunks_data = []
+    for eq_query in expanded_queries:
+        retrieved_for_eq = retrieve_func(eq_query, k=k_per_expansion)
+        all_retrieved_chunks_data.extend(retrieved_for_eq)
+
+    unique_chunks_dict = {}
+    for chunk_text, distance in all_retrieved_chunks_data:
+        if chunk_text not in unique_chunks_dict or distance < unique_chunks_dict[chunk_text]:
+            unique_chunks_dict[chunk_text] = distance
+
+    sorted_unique_chunks_data = sorted(unique_chunks_dict.items(), key=lambda item: item[1])
+
+    final_chunks_for_context = sorted_unique_chunks_data[:5]
+
+    # For UI feedback / debugging
+    # st.write(f"\n{len(final_chunks_for_context)} unique chunks were selected for the context.")
+    return final_chunks_for_context
+
+# --- Streamlit UI ---
+st.title("🎓 RAG Study Tutor for Business IT Strategy")
+st.write("Ask your questions about the content of the lecture notes and case studies (in German).")
+
+# User query input field (remains German for the user)
+user_query_streamlit = st.text_input("Deine Frage:", "")
+
+# Option to use query expansion
+use_expansion = st.checkbox("Use Query Expansion (may improve results for some questions)", value=True)
+
+if user_query_streamlit:
+    # Check if models and data are loaded successfully before proceeding
+    if faiss_index and chunks_data and embedding_model and groq_client:
+        with st.spinner("Searching for relevant information and generating answer..."): # Loading spinner
+            retrieved_chunks = []
+            if use_expansion:
+                st.caption("Query expansion is active...")
+                retrieved_chunks = retrieve_with_expanded_queries_app(user_query_streamlit, groq_client, retrieve_relevant_chunks_app, k_per_expansion=2)
+            else:
+                st.caption("Direct retrieval...")
+                retrieved_chunks = retrieve_relevant_chunks_app(user_query_streamlit, k=3)
+
+            if retrieved_chunks:
+                # Optional display of retrieved context snippets (for debugging or transparency)
+                # with st.expander("Show retrieved context snippets (German)"):
+                #     for i, (chunk, dist) in enumerate(retrieved_chunks):
+                #         st.caption(f"Chunk {i+1} (Distance: {dist:.2f})")
+                #         st.markdown(f"_{chunk[:200]}..._")
+                #         st.divider()
+
+                answer = generate_answer_app(user_query_streamlit, retrieved_chunks)
+                st.subheader("Tutor's Answer:")
+                st.markdown(answer) # Displaying German answer
+            else:
+                st.warning("No relevant information could be found for your query.")
+    else:
+        st.error("The application could not be initialized correctly. Please check for error messages related to model or data loading.")
+
+st.sidebar.header("About this Project")
+st.sidebar.info(
+    "This RAG application was developed as part of the 'AI Applications' module. "
+    "It uses Sentence Transformers for embeddings, FAISS for vector search, "
+    "and an LLM via Groq for answer generation."
+)
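
The three sketches below illustrate, outside the diff, the main techniques this commit adds. First, the retrieval path that retrieve_relevant_chunks_app implements: embed the query, search the FAISS index, and map the returned indices back to chunk texts. This is a minimal sketch, assuming faiss-cpu and sentence-transformers are installed; the two chunks are hypothetical stand-ins for the pickled course material, and the flat L2 index is an assumption about how bits_tutor.index was built (the app itself loads the saved index with faiss.read_index).

```python
import faiss
from sentence_transformers import SentenceTransformer

# Hypothetical stand-ins for the chunks stored in bits_chunks.pkl
chunks = [
    "Eine IT-Strategie richtet die IT an den Unternehmenszielen aus.",
    "FAISS ermöglicht schnelle Ähnlichkeitssuche über Vektoren.",
]

# Same embedding model name as in app.py (downloads on first use)
model = SentenceTransformer("Sahajtomar/German-semantic")

# Build a small index in memory; assumption: a flat L2 index like the saved one
embeddings = model.encode(chunks, convert_to_numpy=True)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# Embed the query and fetch the k nearest chunks, as the app does
query_embedding = model.encode(["Was ist eine IT-Strategie?"], convert_to_numpy=True)
distances, indices = index.search(query_embedding, 2)
for j, i in enumerate(indices[0]):
    print(f"{distances[0][j]:.3f}  {chunks[i]}")
```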
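Second, the cleanup loop in expand_query_with_llm_app is easiest to follow on a canned response. A sketch with a hypothetical LLM output (the real text comes from the Groq chat completion); it applies the same regex stripping and preamble filtering as the committed code, condensed here into one tuple-based startswith check:

```python
import re

# Hypothetical LLM output; in the app this comes from the Groq API
expanded_terms_text = """Sicher, hier sind alternative Formulierungen:
1. Was versteht man unter Business IT Alignment?
2. Wie wird eine IT-Strategie entwickelt?
- IT-Governance und IT-Alignment"""

cleaned_queries = []
for line in expanded_terms_text.split("\n"):
    line = line.strip()
    line = re.sub(r"^\s*\d+\.\s*", "", line)   # strip "1. " style numbering
    line = re.sub(r"^\s*[-\*]\s*", "", line)   # strip "- " / "* " bullets
    line = line.strip()
    # Skip empty lines, preamble sentences, and very short fragments,
    # mirroring the filter in expand_query_with_llm_app
    if not line or line.lower().startswith(("here are", "sicher, hier sind")) or len(line) < 5:
        continue
    cleaned_queries.append(line)

print(cleaned_queries)
# ['Was versteht man unter Business IT Alignment?',
#  'Wie wird eine IT-Strategie entwickelt?',
#  'IT-Governance und IT-Alignment']
```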
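Third, the merge step in retrieve_with_expanded_queries_app: results from all expanded queries are pooled, duplicate chunks are collapsed to their best (smallest) distance, and the pool is ranked ascending so the most similar chunks feed the LLM context. A toy illustration with made-up chunks and distances:

```python
# Chunk texts and distances here are invented for the example
all_retrieved_chunks_data = [
    ("IT-Strategie richtet die IT an den Unternehmenszielen aus.", 0.42),
    ("FAISS ermöglicht schnelle Vektorsuche.", 0.55),
    ("IT-Strategie richtet die IT an den Unternehmenszielen aus.", 0.31),  # duplicate, better hit
]

# Keep the smallest (best) distance per unique chunk text
unique_chunks_dict = {}
for chunk_text, distance in all_retrieved_chunks_data:
    if chunk_text not in unique_chunks_dict or distance < unique_chunks_dict[chunk_text]:
        unique_chunks_dict[chunk_text] = distance

# Smallest distance first; the app keeps the top 5 for the LLM context
final_chunks_for_context = sorted(unique_chunks_dict.items(), key=lambda item: item[1])[:5]
print(final_chunks_for_context)
# [('IT-Strategie richtet die IT an den Unternehmenszielen aus.', 0.31),
#  ('FAISS ermöglicht schnelle Vektorsuche.', 0.55)]
```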