Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import pipeline | |
from similarity_search import get_relevant_context # Import function from similarity_search.py | |
from bs4 import BeautifulSoup # For stripping HTML/XML tags | |
import spacy # Import spaCy for NLP tasks | |
# Load the spaCy English pipeline once per server process. Streamlit
# re-executes this whole script on every widget interaction, so a bare
# module-level ``spacy.load`` would reload the model on every rerun;
# ``st.cache_resource`` keeps one shared instance instead.
# (Download the model first via 'python -m spacy download en_core_web_sm'.)
@st.cache_resource(show_spinner=False)
def _load_spacy_model(model_name="en_core_web_sm"):
    """Return the spaCy pipeline ``model_name``, loaded once and then cached."""
    return spacy.load(model_name)


nlp = _load_spacy_model()
# Load the RoBERTa model used for extractive question answering.
def load_qa_model(model_name="deepset/roberta-base-squad2"):
    """Build and return a Hugging Face ``question-answering`` pipeline.

    Args:
        model_name: Model identifier passed to ``transformers.pipeline``.
            Defaults to ``"deepset/roberta-base-squad2"``.

    Returns:
        The pipeline object created by ``transformers.pipeline``.

    Raises:
        RuntimeError: If the pipeline cannot be created. The underlying
            exception is chained as ``__cause__`` so its traceback is kept.
    """
    print("Loading QA model...")
    try:
        qa_model = pipeline("question-answering", model=model_name)
    except Exception as e:
        print(f"Error loading QA model: {e}")
        # Chain the cause so the real failure (network, missing weights, ...)
        # stays visible in the traceback.
        raise RuntimeError("Failed to load the QA model.") from e
    print("QA model loaded.")
    return qa_model
# Default stop words dropped by clean_text(remove_stop_words=True). Built once
# at import time instead of on every call; all entries are lower-case because
# words are compared via ``word.lower()``.
DEFAULT_STOP_WORDS = frozenset({
    "the", "a", "an", "of", "and", "to", "in", "for", "on", "at",
    "by", "with", "about", "as", "from",
})


# Clean the context text (remove HTML/XML tags and, optionally, stop words).
def clean_text(context, remove_stop_words=False, *, stop_words=None):
    """Strip HTML/XML markup from ``context`` and optionally drop stop words.

    Args:
        context: Raw text, possibly containing HTML/XML tags.
        remove_stop_words: When True, split the text on whitespace, drop every
            word whose lower-cased form is a stop word, and re-join the rest
            with single spaces.
        stop_words: Optional iterable of words to drop (matched
            case-insensitively). Defaults to ``DEFAULT_STOP_WORDS``.

    Returns:
        The cleaned text.
    """
    clean_context = BeautifulSoup(context, "html.parser").get_text()
    if not remove_stop_words:
        return clean_context
    if stop_words is None:
        stop_set = DEFAULT_STOP_WORDS
    else:
        stop_set = {word.lower() for word in stop_words}
    return " ".join(word for word in clean_context.split() if word.lower() not in stop_set)
# Pronouns that are never useful as retrieval topics. Stored lower-case
# because tokens are compared via ``token.text.lower()``.
_EXCLUDED_PRONOUNS = frozenset({
    "i", "you", "he", "she", "it", "they", "we", "them",
    "this", "that", "these", "those",
})


# Extract proper nouns or pronouns from the question for context retrieval.
def extract_topic_from_question(question):
    """Return the key terms of ``question`` as a space-separated string.

    Proper nouns (``PROPN``) and pronouns (``PRON``) are preferred, skipping
    spaCy stop words and the pronouns in ``_EXCLUDED_PRONOUNS``. If none
    remain, falls back to every token that is not a stop word, punctuation,
    or whitespace. The result may be an empty string.
    """
    doc = nlp(question)
    key_terms = [
        token.text
        for token in doc
        if token.pos_ in ("PROPN", "PRON")
        and token.text.lower() not in _EXCLUDED_PRONOUNS
        and not token.is_stop
    ]
    if key_terms:
        return " ".join(key_terms)
    # Fallback: keep content words only, so e.g. a trailing "?" does not end
    # up in the retrieval query.
    return " ".join(
        token.text
        for token in doc
        if not (token.is_stop or token.is_punct or token.is_space)
    )
# Retrieve context for a question and extract the answer from it.
def answer_question_with_context(question, qa_model):
    """Retrieve context for ``question`` and extract an answer from it.

    Args:
        question: The user's natural-language question.
        qa_model: Callable accepting ``question=`` and ``context=`` keyword
            arguments and returning a dict with an ``'answer'`` key, such as
            the pipeline returned by ``load_qa_model``.

    Returns:
        A tuple ``(answer, context)``. ``context`` is the cleaned text the
        model read. It is ``""`` when no usable context was found or an error
        occurred; in those cases ``answer`` is a human-readable message.
        Errors are reported through the return value, never raised.
    """
    no_context = ("No context found for answering.", "")
    try:
        print(f"Question: {question}")
        topic = extract_topic_from_question(question)
        print(
            f"Extracted topic (proper nouns or pronouns): {topic}"
            if topic
            else "No proper nouns or pronouns extracted."
        )
        context = get_relevant_context(question, topic)
        print(f"Retrieved Context: {context}")
        # Guard against a None result as well as blank text.
        if not context or not context.strip():
            return no_context
        # Strip markup only. Stop words are kept on purpose: the extractive QA
        # model returns a span of this text, so deleting words such as "of" or
        # "the" would degrade its reading and corrupt the returned answer.
        context = clean_text(context, remove_stop_words=False)
        if not context.strip():
            # Markup-only context: nothing left for the model to read.
            return no_context
        result = qa_model(question=question, context=context)
        return result.get("answer", "No answer found."), context
    except Exception as e:
        print(f"Error during question answering: {e}")
        return f"Error during question answering: {e}", ""
# Streamlit UI

# "About" section rendered below the question form.
_ABOUT_MARKDOWN = """
### About the Application

This is a **Retrieval-Augmented Generation (RAG)** application that answers questions by dynamically retrieving context from a dataset. Here's how it works:

1. **Dynamic Topic Extraction**: The application analyzes the user's question and extracts key topics (such as proper nouns or pronouns) to guide the search for context.
2. **Context Retrieval**: Based on the extracted topic, the app searches the dataset for the most relevant documents (a few hundred).
3. **Answer Extraction**: A RoBERTa question-answering model reads the retrieved context and extracts the passage that best answers the question, so every answer is taken from the context itself.
4. **Customization**: If the application doesn't find relevant context automatically, you can enter your own context and the answer is extracted from it instead.

The application uses a **RoBERTa-based extractive question-answering model** (`deepset/roberta-base-squad2`), so answers are grounded in the retrieved context rather than in the model's pre-trained knowledge alone.

**Dataset Used**: The application dynamically pulls relevant documents from the google_natural_questions_answerability dataset.
"""

# "Buy Me a Coffee" support button (raw HTML, rendered with unsafe_allow_html=True).
_BUY_ME_A_COFFEE_HTML = """
<div style="text-align: center;">
    <p>If you find this project useful, consider buying me a coffee to support further development! ☕️</p>
    <a href="https://buymeacoffee.com/adnanailabs">
        <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me a Coffee" style="height: 50px;">
    </a>
</div>
"""


def _cached_qa_model():
    """Return the QA pipeline, loading it at most once per server process.

    Streamlit re-executes the script on every interaction, so calling
    ``load_qa_model`` directly would reload the model on every click.
    ``st.cache_resource`` identifies its cache by the wrapped function's
    module, name and source, so wrapping ``load_qa_model`` on each rerun
    still reuses the same cached pipeline.
    """
    return st.cache_resource(load_qa_model, show_spinner=False)()


def _render_result(result, log):
    """Display a stored QA result and ask for context when none was found.

    ``result`` is the dict saved by ``main`` with ``"question"``, ``"answer"``
    and ``"context"`` keys. When its context is blank, a text area asks the
    user for context; once text is entered, the answer is re-extracted from it.
    """
    answer = result["answer"]
    context = result["context"]
    if not context.strip():
        st.warning("I couldn't find any relevant context for this question. Please enter it below:")
        # Keyed so the typed text survives the rerun Streamlit triggers when
        # the user submits it.
        manual_context = st.text_area(
            "Enter your context here:", "", height=200, max_chars=1000, key="manual_context"
        )
        if manual_context.strip():
            context = manual_context
            try:
                with st.spinner("Answering from your context..."):
                    qa_output = _cached_qa_model()(question=result["question"], context=manual_context)
                answer = qa_output.get("answer", "No answer found.")
            except Exception as e:
                st.error(f"An error occurred: {e}")
                log.text(f"Error: {e}")
        else:
            context = "I couldn't find any relevant context, and you didn't provide one either. Maybe next time!"
    st.subheader("Answer:")
    st.write(answer)
    st.subheader("Context Used for Answering:")
    # Read-only display of the text the answer was extracted from; no
    # character limit because retrieved context is often long.
    st.text_area("Context:", context, height=200, disabled=True)


def main():
    """Render the RAG question-answering page.

    Streamlit re-runs this function on every widget interaction, and
    ``st.button`` is True only on the run triggered by the click. The latest
    result is therefore stored in ``st.session_state["qa_result"]`` and
    re-rendered on every run. This keeps the answer on screen and lets the
    manual-context box (shown when retrieval finds nothing) feed its text
    back into the model.
    """
    st.title("RAG Question Answering with Context Retrieval")
    st.markdown("**Dataset Used:** _google_natural_questions_answerability_ ", unsafe_allow_html=True)

    question = st.text_input("Enter your question:", "What is the capital of Italy?")  # Default question
    log = st.empty()  # Placeholder for one-line progress / error messages

    if st.button("Get Answer"):
        # A new request invalidates the previous result and any manual context.
        for stale_key in ("qa_result", "manual_context"):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
        if not question:
            st.error("Please provide a question.")
        else:
            try:
                log.text("Loading QA model...")
                with st.spinner("Loading QA model... Please wait."):
                    qa_model = _cached_qa_model()
                log.text("Retrieving context...")
                with st.spinner("Retrieving context..."):
                    answer, context = answer_question_with_context(question, qa_model)
                st.session_state["qa_result"] = {
                    "question": question,
                    "answer": answer,
                    "context": context,
                }
                log.empty()
            except Exception as e:
                st.error(f"An error occurred: {e}")
                log.text(f"Error: {e}")  # Log error in place

    result = st.session_state.get("qa_result")
    if result is not None:
        _render_result(result, log)

    # Display information about the application and the support button.
    st.markdown(_ABOUT_MARKDOWN)
    st.markdown(_BUY_ME_A_COFFEE_HTML, unsafe_allow_html=True)


if __name__ == "__main__":
    main()