# Muhammad Adnan
# buy coffee
# 8072f91
import streamlit as st
from transformers import pipeline
from similarity_search import get_relevant_context # Import function from similarity_search.py
from bs4 import BeautifulSoup # For stripping HTML/XML tags
import spacy # Import spaCy for NLP tasks
# spaCy English pipeline; extract_topic_from_question() calls it to get
# part-of-speech tags and stop-word flags for each token of the question.
# Install it beforehand with:
#     python -m spacy download en_core_web_sm
SPACY_MODEL_NAME = "en_core_web_sm"
nlp = spacy.load(SPACY_MODEL_NAME)
# Hugging Face checkpoint used for extractive question answering.
QA_MODEL_NAME = "deepset/roberta-base-squad2"


@st.cache_resource(show_spinner=False)
def load_qa_model():
    """Create the RoBERTa question-answering pipeline, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the pipeline would be rebuilt on every "Get Answer"
    click. ``st.cache_resource`` keeps one shared instance. Exceptions are
    not cached, so a failed load is retried on the next call. The caller
    shows its own spinner, so the cache spinner is disabled.

    Returns:
        A ``transformers`` "question-answering" pipeline.

    Raises:
        RuntimeError: if the pipeline cannot be created. The underlying
            exception is chained as ``__cause__``.
    """
    print("Loading QA model...")
    try:
        qa_model = pipeline("question-answering", model=QA_MODEL_NAME)
    except Exception as e:
        print(f"Error loading QA model: {e}")
        # Chain the cause so the real failure (network, missing weights, ...)
        # stays visible in tracebacks.
        raise RuntimeError("Failed to load the QA model.") from e
    print("QA model loaded.")
    return qa_model
# Small hand-picked list of English function words that clean_text() drops
# when remove_stop_words=True. Built once at import time instead of per call.
DEFAULT_STOP_WORDS = frozenset(
    {
        "the", "a", "an", "of", "and", "to", "in", "for",
        "on", "at", "by", "with", "about", "as", "from",
    }
)


def clean_text(context, remove_stop_words=False, *, stop_words=None):
    """Strip HTML/XML markup from *context* and optionally drop stop words.

    Args:
        context: Raw text that may contain HTML/XML tags.
        remove_stop_words: When True, also remove stop words. The text is then
            split on whitespace and re-joined with single spaces, so the
            original line breaks and spacing are lost.
        stop_words: Optional iterable of words to remove instead of
            ``DEFAULT_STOP_WORDS``. Matching is case-insensitive because both
            the entries and the text's words are lower-cased before comparison.

    Returns:
        The cleaned text.
    """
    clean_context = BeautifulSoup(context, "html.parser").get_text()
    if not remove_stop_words:
        return clean_context

    if stop_words is None:
        blocked = DEFAULT_STOP_WORDS
    else:
        blocked = {word.lower() for word in stop_words}
    return " ".join(
        word for word in clean_context.split() if word.lower() not in blocked
    )
# Pronouns that carry no retrievable topic. Stored lower-case because the
# membership test below uses token.text.lower().
EXCLUDED_PRONOUNS = frozenset(
    {
        "i", "you", "he", "she", "it", "they",
        "we", "them", "this", "that", "these", "those",
    }
)


def extract_topic_from_question(question):
    """Return a space-separated topic string used to guide context retrieval.

    The question is tagged with the module-level spaCy pipeline ``nlp``.
    Tokens tagged PROPN or PRON are kept when they are neither in
    ``EXCLUDED_PRONOUNS`` nor spaCy stop words. If none survive, the fallback
    keeps every token that is not a stop word, punctuation, or whitespace.

    Args:
        question: The user's question text.

    Returns:
        The selected token texts joined by single spaces. This may be ``""``
        when every token is filtered out.
    """
    doc = nlp(question)

    topic_tokens = [
        token.text
        for token in doc
        if token.pos_ in ("PROPN", "PRON")
        and token.text.lower() not in EXCLUDED_PRONOUNS
        and not token.is_stop
    ]
    if topic_tokens:
        return " ".join(topic_tokens)

    # Fallback: keep the content words. Punctuation and whitespace tokens
    # (e.g. the trailing "?") are dropped so they do not leak into the
    # retrieval query.
    content_tokens = [
        token.text
        for token in doc
        if not (token.is_stop or token.is_punct or token.is_space)
    ]
    return " ".join(content_tokens)
# Retrieve context for a question and run extractive QA over it.
def answer_question_with_context(question, qa_model):
    """Retrieve context for *question* and extract an answer with *qa_model*.

    Steps:
    1. Derive a topic with ``extract_topic_from_question``.
    2. Fetch context with ``get_relevant_context(question, topic)``.
    3. Strip markup with ``clean_text``.
    4. Run the QA pipeline.

    Stop words are deliberately NOT removed from the context. An extractive
    model such as deepset/roberta-base-squad2 returns a span copied from the
    context, so removing words like "of" or "the" both degrades the model's
    input and mangles the answer (e.g. "University of Oxford" would come back
    as "University Oxford").

    Args:
        question: The user's question.
        qa_model: A transformers question-answering pipeline, called with
            ``question=`` and ``context=`` keywords.

    Returns:
        A tuple ``(answer, context)``. ``context`` is ``""`` when no usable
        context was found or when an error occurred. In the error case
        ``answer`` holds an error message; nothing is raised.
    """
    try:
        print(question)
        topic = extract_topic_from_question(question)
        print(f"Extracted topic (proper nouns or pronouns): {topic}" if topic else "No proper nouns or pronouns extracted.")

        context = get_relevant_context(question, topic)
        print(f"Retrieved Context: {context}")  # Debug: show the raw retrieval result
        # Guard against a None/empty return as well as whitespace-only text.
        if not context or not context.strip():
            return "No context found for answering.", ""

        context = clean_text(context)
        # Context made of markup only becomes empty after cleaning. The QA
        # pipeline rejects an empty context, so report "no context" instead.
        if not context.strip():
            return "No context found for answering.", ""

        result = qa_model(question=question, context=context)
        return result.get('answer', 'No answer found.'), context
    except Exception as e:
        # Top-level boundary for the UI: log, then surface the message as the answer.
        print(f"Error during question answering: {e}")
        return f"Error during question answering: {e}", ""
# Streamlit UI

# Static "About" section. The text is kept at column 0 so no Markdown line is
# read as an indented code block. Blank lines separate the intro paragraph,
# the numbered list and the closing paragraphs so each renders on its own.
ABOUT_MARKDOWN = """
### About the Application
This is a **Retrieval-Augmented Generation (RAG)** application that answers questions by dynamically retrieving context from a dataset. Here's how it works:

1. **Dynamic Topic Extraction**: The application analyzes the user's question and extracts key topics (such as proper nouns or pronouns) to understand the context of the query.
2. **Context Retrieval**: Based on the extracted topic, the app searches for the most relevant documents (a few hundred) in the dataset.
3. **Answer Extraction**: Using the retrieved context, an extractive question-answering model (RoBERTa fine-tuned on SQuAD 2.0) selects the span of the context that best answers the question. The answer always comes from the context, not from the model's own knowledge.
4. **Customization**: If the application doesn't find enough relevant context automatically, you can paste your own context into the optional context box and click "Get Answer" again.

The application leverages **RoBERTa-based question-answering models** to extract answers from the retrieved context. This helps provide more accurate, context-specific answers compared to traditional approaches that rely solely on pre-trained model knowledge.

**Dataset Used**: The application dynamically pulls relevant documents from the google_natural_questions_answerability dataset, helping answer user questions more effectively.
"""

# Raw HTML for the centred "Buy Me a Coffee" link (needs unsafe_allow_html).
COFFEE_BUTTON_HTML = """
<div style="text-align: center;">
<p>If you find this project useful, consider buying me a coffee to support further development! ☕️</p>
<a href="https://buymeacoffee.com/adnanailabs">
<img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me a Coffee" style="height: 50px;">
</a>
</div>
"""

# Shown in the context box when neither retrieval nor the user supplied any context.
NO_CONTEXT_PLACEHOLDER = "I couldn't find any relevant context, and you didn't provide one either. Maybe next time!"


def _answer_from_user_context(question, user_context, qa_model):
    """Answer *question* from context the user pasted, skipping retrieval.

    The text goes through ``clean_text`` (markup stripped, stop words kept)
    before reaching the QA pipeline.

    Returns:
        A tuple ``(answer, context)``. ``context`` is ``""`` when nothing is
        left after cleaning; in that case the pipeline is not called.
    """
    context = clean_text(user_context)
    if not context.strip():
        return "No context found for answering.", ""
    result = qa_model(question=question, context=context)
    return result.get('answer', 'No answer found.'), context


def main():
    """Render the app: question input, optional custom context, and the answer."""
    st.title("RAG Question Answering with Context Retrieval")
    # Plain Markdown only, so HTML rendering is not enabled here.
    st.markdown("**Dataset Used:** _google_natural_questions_answerability_ ")

    question = st.text_input("Enter your question:", "What is the capital of Italy?")  # Default question

    # Streamlit reruns the whole script on every interaction, and st.button()
    # is True only during the rerun triggered by the click. A text area
    # created inside the button branch vanishes as soon as the user types into
    # it, so its value can never reach the model. Rendering it up front lets
    # the user fill it in first and then click "Get Answer".
    user_context = st.text_area(
        "Optional: paste your own context (used instead of automatic retrieval):",
        "",
        height=200,
        max_chars=1000,
    )

    log = st.empty()  # single-line status placeholder, overwritten in place

    if st.button("Get Answer"):
        if not question.strip():
            st.error("Please provide a question.")
        else:
            try:
                log.text("Loading QA model...")
                with st.spinner("Loading QA model... Please wait."):
                    qa_model = load_qa_model()

                if user_context.strip():
                    log.text("Answering from your context...")
                    with st.spinner("Answering from your context..."):
                        answer, context = _answer_from_user_context(question, user_context, qa_model)
                else:
                    log.text("Retrieving context...")
                    with st.spinner("Retrieving context..."):
                        answer, context = answer_question_with_context(question, qa_model)
                log.empty()  # work finished: clear the status line

                if not context.strip():
                    st.warning(
                        "I couldn't find any relevant context for this question. "
                        "Paste one into the context box above and click \"Get Answer\" again."
                    )
                    context = NO_CONTEXT_PLACEHOLDER

                st.subheader("Answer:")
                st.write(answer)

                st.subheader("Context Used for Answering:")
                # Display only: edits made here are not sent back to the model
                # (use the optional context box above instead). No max_chars,
                # because retrieved context can exceed 1000 characters.
                st.text_area("Context:", context, height=200, key="context_input", disabled=False)
            except Exception as e:
                st.error(f"An error occurred: {e}")
                log.text(f"Error: {e}")  # Log error in place

    st.markdown(ABOUT_MARKDOWN)
    st.markdown(COFFEE_BUTTON_HTML, unsafe_allow_html=True)


if __name__ == "__main__":
    main()