File size: 10,466 Bytes
e26f9a8
 
 
 
 
 
 
 
 
 
 
 
75632fd
e26f9a8
75632fd
 
e26f9a8
75632fd
e26f9a8
75632fd
e26f9a8
75632fd
 
e26f9a8
 
 
75632fd
e26f9a8
75632fd
e26f9a8
 
 
75632fd
e26f9a8
75632fd
e26f9a8
 
 
 
75632fd
e26f9a8
 
 
75632fd
e26f9a8
75632fd
e26f9a8
 
 
 
 
75632fd
 
 
e26f9a8
 
75632fd
e26f9a8
75632fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# app.py

import streamlit as st
import os
import faiss
import pickle
from sentence_transformers import SentenceTransformer
from groq import Groq
from dotenv import load_dotenv
import re # Import regular expressions for expand_query_with_llm_app

# --- Page configuration ---
# Streamlit requires set_page_config to be the first Streamlit command of the
# script, so it must run before the cached loader below can call st.error().
st.set_page_config(
    page_title="RAG BITS Tutor",
    page_icon="🎓",
)

# --- Configuration and Model Loading (cached so Streamlit reruns reuse the objects) ---
@st.cache_resource  # Cache the large model/index objects across Streamlit reruns.
def load_models_and_data():
    """Load the FAISS index, text chunks, embedding model and Groq client.

    All cheap preconditions (index file, chunks file, GROQ_API_KEY) are checked
    before anything expensive is loaded. A misconfigured deployment therefore
    fails immediately instead of first reading the index and loading the
    SentenceTransformer model only to throw them away.

    Returns:
        ``(index, chunks, embedding_model, groq_client)`` on success, or
        ``(None, None, None, None)`` after displaying an ``st.error`` when a
        precondition fails.

    NOTE: ``st.cache_resource`` caches the failure tuple too. After fixing the
    configuration, clear the cache or restart the app.
    """
    # Load environment variables from a .env file, if one is present.
    load_dotenv()
    groq_api_key_app = os.getenv("GROQ_API_KEY")

    # Paths to the index and chunks; this folder must ship with the app (HF Space).
    output_folder = "faiss_index_bits"
    index_path = os.path.join(output_folder, "bits_tutor.index")
    chunks_path = os.path.join(output_folder, "bits_chunks.pkl")

    # --- Cheap precondition checks: fail fast before any heavy loading ---
    if not os.path.exists(index_path):
        st.error(f"FAISS Index not found at: {index_path}")
        return None, None, None, None
    if not os.path.exists(chunks_path):
        st.error(f"Chunks file not found at: {chunks_path}")
        return None, None, None, None
    if not groq_api_key_app:
        st.error("GROQ_API_KEY not found. Please add it as a Secret in the Hugging Face Space settings.")
        return None, None, None, None

    # --- Expensive loading, only reached when every precondition holds ---
    index_loaded = faiss.read_index(index_path)

    # SECURITY NOTE: pickle.load can execute arbitrary code. It is acceptable
    # here only because the path is hard-coded to a file shipped with the app;
    # never point it at user-supplied data.
    with open(chunks_path, "rb") as f:
        chunks_loaded = pickle.load(f)

    embedding_model_name_app = "Sahajtomar/German-semantic"
    embedding_model_loaded = SentenceTransformer(embedding_model_name_app)

    groq_client_loaded = Groq(api_key=groq_api_key_app)

    return index_loaded, chunks_loaded, embedding_model_loaded, groq_client_loaded

# Load models and data once at startup (the loader is cached by st.cache_resource,
# so Streamlit reruns reuse the same objects).
# load_models_and_data() may call st.error(), a Streamlit command, so this
# must run after st.set_page_config(), which Streamlit requires to come first.
(
    faiss_index,
    chunks_data,
    embedding_model,
    groq_client,
) = load_models_and_data()

# --- Core RAG Functions (adapted for Streamlit app) ---

def retrieve_relevant_chunks_app(query, k=5):
    """Return up to ``k`` chunks nearest to ``query`` from the global FAISS index.

    Uses the module-level ``embedding_model``, ``faiss_index`` and ``chunks_data``
    loaded at startup.

    Args:
        query: Text to embed and search for.
        k: Maximum number of hits to return; values <= 0 yield ``[]``.

    Returns:
        A list of ``(chunk, distance)`` tuples in the order FAISS returned
        them. The list may be shorter than ``k`` if the index holds fewer
        vectors. It is empty when the resources are not loaded.
    """
    if embedding_model is None or faiss_index is None or chunks_data is None:
        st.warning("Models or data not loaded correctly. Cannot retrieve chunks.")
        return []
    if k <= 0:
        return []

    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = faiss_index.search(query_embedding, k)

    # FAISS pads missing results with label -1 when fewer than k vectors are
    # available. Without this filter, chunks_data[-1] would silently return
    # the last chunk as a spurious hit.
    return [
        (chunks_data[idx], dist)
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]

def generate_answer_app(query, retrieved_chunks_data):
    """Ask the Groq LLM to answer ``query`` using only the retrieved chunks.

    Uses the module-level ``groq_client`` created at startup.

    Args:
        query: The user's question, inserted verbatim into the prompt.
        retrieved_chunks_data: Iterable of ``(chunk_text, distance)`` pairs.
            Only the text is used; chunks are joined by blank lines to form
            the context.

    Returns:
        The content of the model's first choice. On failure, an English error
        string: when the client is missing, or when the request raises.
    """
    if groq_client is None:
        st.error("Groq Client not initialized. Cannot generate answer.")
        return "Error: LLM client not available."

    context = "\n\n".join(text for text, _distance in retrieved_chunks_data)

    # The prompt is deliberately German: it instructs the LLM to act as a
    # German-speaking tutor that answers only from the provided context.
    prompt = f"""Beantworte die folgende Frage ausschließlich basierend auf dem bereitgestellten Kontext aus den Lehrmaterialien zur Business IT Strategie.
Antworte auf Deutsch.

Kontext:
{context}

Frage: {query}

Antwort:
"""
    messages = [{"role": "user", "content": prompt}]

    try:
        # NOTE(review): Groq retires model IDs over time; confirm that
        # "llama3-70b-8192" is still served.
        response = groq_client.chat.completions.create(
            messages=messages,
            model="llama3-70b-8192",
            temperature=0.3,
        )
        return response.choices[0].message.content
    except Exception as e:
        # Show the raw error for the developer; return a generic message to the UI.
        st.error(f"Error during LLM request in generate_answer_app: {e}")
        return "An error occurred while generating the answer."

# Leading list markers the LLM tends to emit: "1." / "1)" numbering and "-" / "*" bullets.
_NUMBERED_PREFIX_RE = re.compile(r"^\s*\d+[.)]\s*")
_BULLET_PREFIX_RE = re.compile(r"^\s*[-\*]\s*")
# Lowercased prefixes of preamble lines that are not search queries.
_PREAMBLE_PREFIXES = (
    "here are",
    "sicher, hier sind",
    "alternative formulierungen:",
    "*alternative formulierungen:**",
)
# Quote characters the LLM may wrap alternatives in (ASCII and German typographic).
_QUOTE_CHARS = "\"„“”"


def _clean_expansion_line(line):
    """Return ``line`` without list markers/quotes, or "" if it is not a usable query."""
    line = _NUMBERED_PREFIX_RE.sub("", line.strip())
    line = _BULLET_PREFIX_RE.sub("", line).strip()
    if line.lower().startswith(_PREAMBLE_PREFIXES):
        return ""
    # Header lines such as "**Schlüsselbegriffe:**" end with a colon once the
    # markdown asterisks are removed; they are labels, not queries.
    if line.strip("*").rstrip().endswith(":"):
        return ""
    # Strip wrapping quotes so a quoted copy of a query deduplicates against it.
    line = line.strip(_QUOTE_CHARS).strip()
    return line if len(line) >= 5 else ""


def _parse_expansion_output(text):
    """Turn the raw LLM expansion text into a list of candidate queries/keywords."""
    cleaned = [c for c in map(_clean_expansion_line, text.split("\n")) if c]
    # A single comma-separated line is treated as a keyword list.
    if len(cleaned) == 1 and "," in cleaned[0]:
        return [term.strip() for term in cleaned[0].split(",") if len(term.strip()) > 4]
    return cleaned


def expand_query_with_llm_app(original_query, llm_client_app, max_queries=4):
    """Expand ``original_query`` into alternative queries/keywords via an LLM.

    Args:
        original_query: The user's question.
        llm_client_app: A Groq-style client exposing ``chat.completions.create``;
            ``None`` disables expansion.
        max_queries: Maximum number of queries returned, original included.
            Values < 1 are treated as 1 so the original is always kept.

    Returns:
        A list that starts with ``original_query``, followed by unique
        expansions (compared case-insensitively via ``casefold``). If the
        client is missing or the LLM request fails, the list is
        ``[original_query]``.
    """
    if llm_client_app is None:
        st.warning("LLM client for query expansion not initialized.")
        return [original_query]

    # This prompt stays in German because the tutor's material is German.
    prompt_template_expansion = f"""Gegeben ist die folgende Nutzerfrage zum Thema "Business IT Strategie": "{original_query}"

Bitte generiere 2-3 alternative Formulierungen dieser Frage ODER eine Liste von 3-5 sehr relevanten Schlüsselbegriffen/Konzepten, 
die helfen würden, in einer Wissensdatenbank nach Antworten zu dieser Frage zu suchen.
Formatiere die Ausgabe klar, z.B. als nummerierte Liste für alternative Fragen oder als kommaseparierte Liste für Schlüsselbegriffe.
Gib NUR die alternativen Formulierungen oder die Schlüsselbegriffe aus. Keine Einleitungssätze.
"""
    try:
        chat_completion = llm_client_app.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template_expansion}],
            model="llama3-8b-8192",
            temperature=0.5,
        )
        # content may be None; treat that as "no expansions" rather than failing.
        expanded_terms_text = chat_completion.choices[0].message.content or ""
    except Exception as e:
        # Expansion is best-effort: fall back to the original query only.
        st.warning(f"Error during Query Expansion with LLM: {e}")
        return [original_query]

    all_queries = [original_query]
    seen = {original_query.casefold()}
    for candidate in _parse_expansion_output(expanded_terms_text):
        key = candidate.casefold()
        if key not in seen:
            seen.add(key)
            all_queries.append(candidate)

    return all_queries[:max(1, max_queries)]

def retrieve_with_expanded_queries_app(original_query, llm_client_app, retrieve_func, k_per_expansion=2, max_chunks=5):
    """Retrieve context chunks for ``original_query`` and its LLM expansions.

    Args:
        original_query: The user's question.
        llm_client_app: Client passed to ``expand_query_with_llm_app``.
        retrieve_func: Callable ``(query, k=...) -> [(chunk_text, distance), ...]``.
        k_per_expansion: Hits requested per (expanded) query.
        max_chunks: Maximum number of unique chunks returned.

    Returns:
        Up to ``max_chunks`` unique ``(chunk_text, distance)`` pairs, sorted by
        ascending distance. Each chunk keeps the smallest distance it got
        across all queries. This assumes smaller means more similar
        (L2-style metric); verify if the index metric changes.
    """
    expanded_queries = expand_query_with_llm_app(original_query, llm_client_app)

    # Merge hits from all queries, keeping the best (smallest) distance per chunk.
    best_distance = {}
    for eq_query in expanded_queries:
        for chunk_text, distance in retrieve_func(eq_query, k=k_per_expansion):
            if chunk_text not in best_distance or distance < best_distance[chunk_text]:
                best_distance[chunk_text] = distance

    ranked = sorted(best_distance.items(), key=lambda item: item[1])
    return ranked[:max_chunks]

# --- Streamlit UI ---
st.title("🎓 RAG Study Tutor for Business IT Strategy")
st.write("Ask your questions about the content of the lecture notes and case studies (in German).")

# User query input field (label stays German for the German-speaking audience).
# Surrounding whitespace is stripped so a blank or whitespace-only entry does
# not trigger retrieval and LLM calls.
user_query_streamlit = st.text_input("Deine Frage:", "").strip()

# Option to use query expansion
use_expansion = st.checkbox("Use Query Expansion (may improve results for some questions)", value=True)

if user_query_streamlit:
    # Only proceed when the startup loader returned every resource.
    if faiss_index and chunks_data and embedding_model and groq_client:
        with st.spinner("Searching for relevant information and generating answer..."):
            if use_expansion:
                st.caption("Query expansion is active...")
                retrieved_chunks = retrieve_with_expanded_queries_app(
                    user_query_streamlit, groq_client, retrieve_relevant_chunks_app, k_per_expansion=2
                )
            else:
                st.caption("Direct retrieval...")
                retrieved_chunks = retrieve_relevant_chunks_app(user_query_streamlit, k=3)

            if retrieved_chunks:
                answer = generate_answer_app(user_query_streamlit, retrieved_chunks)
                st.subheader("Tutor's Answer:")
                st.markdown(answer)  # The answer is German text produced by the LLM.
            else:
                st.warning("No relevant information could be found for your query.")
    else:
        st.error("The application could not be initialized correctly. Please check for error messages related to model or data loading.")

st.sidebar.header("About this Project")
st.sidebar.info(
    "This RAG application was developed as part of the 'AI Applications' module. "
    "It uses Sentence Transformers for embeddings, FAISS for vector search, "
    "and an LLM via Groq for answer generation."
)