Spaces:
Sleeping
Sleeping
Muhammad Adnan
committed on
Commit
·
ae6eb20
1
Parent(s):
3623388
Initial commit of Streamlit app
Browse files- app.py +147 -0
- data_ret.py +57 -0
- requirements.txt +8 -0
- similarity_search.py +94 -0
app.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import pipeline
|
3 |
+
from similarity_search import get_relevant_context # Import function from similarity_search.py
|
4 |
+
from bs4 import BeautifulSoup # For stripping HTML/XML tags
|
5 |
+
import spacy # Import spaCy for NLP tasks
|
6 |
+
|
7 |
+
# Load the spaCy English pipeline once at import time; extract_topic_from_question()
# runs every question through it.
# The model must be installed first: python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")
|
9 |
+
|
10 |
+
# Load the Roberta model for question answering
|
11 |
+
def load_qa_model():
    """Build the question-answering pipeline used by the app.

    Creates a ``transformers`` ``pipeline`` for the "question-answering"
    task with the ``deepset/roberta-base-squad2`` checkpoint and logs
    progress to stdout.

    Returns:
        The question-answering pipeline object.

    Raises:
        RuntimeError: If the pipeline cannot be created. The original
            exception is chained as ``__cause__`` so the real reason
            stays visible in the traceback.
    """
    # NOTE(review): main() calls this on every "Get Answer" click; consider
    # wrapping it with st.cache_resource so the checkpoint is loaded once.
    print("Loading QA model...")
    try:
        qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
    except Exception as e:
        print(f"Error loading QA model: {e}")
        # Chain the underlying error instead of discarding it.
        raise RuntimeError("Failed to load the QA model.") from e
    print("QA model loaded.")
    return qa_model
|
20 |
+
|
21 |
+
# Function to clean the context text (remove HTML tags and optional stop words)
|
22 |
+
def clean_text(context, remove_stop_words=False):
    """Strip markup from *context* and optionally drop a few common stop words.

    Args:
        context: Text that may contain HTML/XML tags.
        remove_stop_words: When true, words from a small fixed list
            (matched case-insensitively) are removed. The text is split on
            whitespace and re-joined with single spaces in that case.

    Returns:
        The tag-free text, with the listed stop words removed if requested.
    """
    plain_text = BeautifulSoup(context, "html.parser").get_text()
    if not remove_stop_words:
        return plain_text

    ignored_words = {
        "the", "a", "an", "of", "and", "to", "in", "for",
        "on", "at", "by", "with", "about", "as", "from",
    }
    kept_words = (word for word in plain_text.split() if word.lower() not in ignored_words)
    return " ".join(kept_words)
|
31 |
+
|
32 |
+
# Function to extract proper nouns or pronouns from the question for context retrieval
|
33 |
+
def extract_topic_from_question(question):
    """Derive a short search topic from a natural-language question.

    The question is tagged with the module-level spaCy pipeline ``nlp``.
    Proper nouns and pronouns that are neither stop words nor in a small
    exclusion list are preferred; if none are found, every token that is
    not a stop word, punctuation mark or whitespace is kept instead.

    Args:
        question: The user's question.

    Returns:
        The selected tokens joined by single spaces (possibly an empty
        string when nothing is left).
    """
    doc = nlp(question)

    # Compared against lower-cased token text, so every entry must be lower case.
    excluded_pronouns = {
        'i', 'you', 'he', 'she', 'it', 'they', 'we', 'them',
        'this', 'that', 'these', 'those',
    }

    topic_tokens = [
        token.text
        for token in doc
        if token.pos_ in ('PROPN', 'PRON')
        and token.text.lower() not in excluded_pronouns
        and not token.is_stop
    ]
    if topic_tokens:
        return " ".join(topic_tokens)

    # Fallback: keep the remaining content tokens. Punctuation and whitespace
    # tokens are dropped because the topic is later used as a literal
    # substring filter, where "photosynthesis ?" would match nothing.
    content_tokens = [
        token.text
        for token in doc
        if not (token.is_stop or token.is_punct or token.is_space)
    ]
    return " ".join(content_tokens)
|
57 |
+
|
58 |
+
# Retrieve context for a question and run the QA model on it (with debug prints).
def answer_question_with_context(question, qa_model):
    """Answer *question* using context retrieved for its extracted topic.

    Args:
        question: The user's question.
        qa_model: Callable invoked as ``qa_model(question=..., context=...)``
            whose result supports ``.get('answer', ...)``.

    Returns:
        A ``(answer, context)`` tuple. ``context`` is the cleaned text given
        to the model, or ``""`` when no context was found or an error
        occurred (in which case ``answer`` carries the message).
    """
    try:
        print(question)

        # Topic = proper nouns / pronouns (or leftover non-stop-word tokens).
        topic = extract_topic_from_question(question)
        if topic:
            print(f"Extracted topic (proper nouns or pronouns): {topic}")
        else:
            print("No proper nouns or pronouns extracted.")

        retrieved = get_relevant_context(question, topic)
        print(f"Retrieved Context: {retrieved}")  # Debug: show what retrieval returned

        if not retrieved.strip():
            return "No context found for answering.", ""

        # Strip markup and the fixed stop-word list before querying the model.
        cleaned = clean_text(retrieved, remove_stop_words=True)

        prediction = qa_model(question=question, context=cleaned)
        return prediction.get('answer', 'No answer found.'), cleaned
    except Exception as e:
        message = f"Error during question answering: {e}"
        print(message)  # Debug: log the error in the terminal
        return message, ""
|
82 |
+
|
83 |
+
# Streamlit UI
|
84 |
+
def main():
    """Render the Streamlit page: question input, answer/context display, about text.

    When "Get Answer" is clicked with a non-empty question, the QA model is
    loaded, context is retrieved and the answer plus the context used are
    shown. Any exception raised in that flow is reported on the page.
    """
    st.title("RAG Question Answering with Context Retrieval")

    # Question box, pre-filled with a sample question.
    question = st.text_input("Enter your question:", "What is the capital of Italy?")

    # Placeholder whose text is overwritten as the request progresses.
    log = st.empty()

    clicked = st.button("Get Answer")
    if clicked and not question:
        st.error("Please provide a question.")
    elif clicked:
        try:
            log.text("Loading QA model...")
            with st.spinner("Loading QA model... Please wait."):
                qa_model = load_qa_model()

            log.text("Retrieving context...")
            with st.spinner("Retrieving context..."):
                answer, context = answer_question_with_context(question, qa_model)

            if not context.strip():
                # NOTE(review): text typed into this box is only displayed below;
                # it is never passed back to the QA model in this function.
                st.warning("I couldn't find any relevant context for this question. Please enter it below:")
                context = st.text_area("Enter your context here:", "", height=200, max_chars=1000)
                if not context.strip():
                    context = "I couldn't find any relevant context, and you didn't provide one either. Maybe next time!"

            st.subheader("Answer:")
            st.write(answer)

            st.subheader("Context Used for Answering:")
            # Editable box showing the context that was used.
            st.text_area("Context:", context, height=200, max_chars=1000, key="context_input", disabled=False)
        except Exception as e:
            st.error(f"An error occurred: {e}")
            log.text(f"Error: {e}")  # Replace the status line with the error

    # Static description of the application, always shown.
    about_markdown = """
### About the Application
This is a **Retrieval-Augmented Generation (RAG)** application that answers questions by dynamically retrieving context from a dataset. Here's how it works:

1. **Dynamic Topic Extraction**: The application analyzes the user's question and extracts key topics (such as proper nouns or pronouns) to understand the context of the query.
2. **Context Retrieval**: Based on the extracted topic, the app searches for the most relevant documents (a few hundred) in the dataset.
3. **Answer Generation**: Using the retrieved context, an AI model (like RoBERTa) is used to generate the most accurate answer possible. The model combines the context with its internal knowledge to provide a robust and informed response.
4. **Customization**: If the application doesn't find enough relevant context automatically, you can manually input additional context to improve the answer.

The application leverages **Roberta-based question-answering models** to generate answers based on the context retrieved. This helps provide more accurate, context-specific answers compared to traditional approaches that rely solely on pre-trained model knowledge.

**Dataset Used**: The application dynamically pulls relevant documents from a dataset (e.g., academic papers, FAQ pages, product manuals, etc.), helping answer user questions more effectively.
"""
    st.markdown(about_markdown)
|
145 |
+
|
146 |
+
# Build the Streamlit UI only when this file is executed as the entry script.
if __name__ == "__main__":
    main()
|
data_ret.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
|
3 |
+
# Load the 'train' split once at import time; search_relevant_data() filters
# this module-level dataset on every call.
dataset = load_dataset('tom-010/google_natural_questions_answerability', split='train')
|
5 |
+
|
6 |
+
# Function to filter based on a query/topic and return relevant data
|
7 |
+
def search_relevant_data(topic="Artificial Intelligence", max_words=100, top_n=100):
    """Return excerpts of dataset rows that mention *topic*.

    A row matches when *topic* occurs (case-insensitively, as a plain
    substring) in its 'question', 'answer' or 'text' field. At most
    *top_n* matching rows are kept, in dataset order.

    Args:
        topic: Phrase to search for. An empty, blank or ``None`` topic
            yields no results.
        max_words: Maximum number of whitespace-separated words kept from
            the 'answer' and 'text' fields.
        top_n: Maximum number of matching rows to return.

    Returns:
        A list of dicts with the keys "question", "answer" and "text".
        All three values are strings; missing or ``None`` fields become "".
    """
    # An empty needle is a substring of every string, so it would "match"
    # the whole dataset and return arbitrary rows.
    if not topic or not topic.strip():
        return []

    needle = topic.lower()  # lower-case once instead of three times per row

    def mentions_topic(row):
        return any(
            row[field] is not None and needle in row[field].lower()
            for field in ('question', 'answer', 'text')
        )

    def first_words(value):
        # None/empty -> "", otherwise the first `max_words` words.
        return ' '.join((value or '').split()[:max_words])

    filtered_data = dataset.filter(mentions_topic)
    # Never ask for more rows than the filter produced.
    filtered_data = filtered_data.select(range(min(top_n, len(filtered_data))))

    relevant_documents = []
    for entry in filtered_data:
        if not isinstance(entry, dict):
            print("Unexpected entry format:", entry)
            continue
        relevant_documents.append({
            # Normalise None to "" so callers can concatenate the fields safely.
            "question": entry.get('question') or "",
            "answer": first_words(entry.get('answer')),
            "text": first_words(entry.get('text')),
        })

    return relevant_documents
|
54 |
+
|
55 |
+
# Sample search query
|
56 |
+
#relevant_data = search_relevant_data("vatican city") # Change to the desired query/topic
|
57 |
+
#print(f"Found {len(relevant_data)} relevant documents.")
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.20.0
|
2 |
+
transformers==4.33.0
|
3 |
+
sentence-transformers==2.2.0
|
4 |
+
scipy==1.10.0
|
5 |
+
numpy==1.24.2
|
6 |
+
datasets==2.9.0
|
7 |
+
beautifulsoup4==4.12.0
|
8 |
+
spacy==3.5.0
|
similarity_search.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer
|
2 |
+
from scipy.spatial.distance import cosine
|
3 |
+
import numpy as np
|
4 |
+
from data_ret import search_relevant_data # Assuming this function fetches the data from some source
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
# Load the Sentence Transformer model for similarity search
|
8 |
+
# Module-level cache: get_relevant_context() asks for the model on every
# question, and re-instantiating it each time is needlessly slow.
_RETRIEVER_MODEL = None


def load_similarity_model():
    """Return the sentence-embedding model used for similarity search.

    The ``all-mpnet-base-v2`` SentenceTransformer is created on the first
    call and reused afterwards. Status messages are written to the
    Streamlit page.

    Returns:
        The shared ``SentenceTransformer`` instance.
    """
    global _RETRIEVER_MODEL
    if _RETRIEVER_MODEL is None:
        st.write("Loading similarity model...")  # Show status on Streamlit
        _RETRIEVER_MODEL = SentenceTransformer("all-mpnet-base-v2")
    st.write("Similarity model loaded.")
    return _RETRIEVER_MODEL
|
13 |
+
|
14 |
+
# Create embeddings for the retrieved documents
|
15 |
+
def create_embeddings(documents, model):
    """Embed each document's question, answer and text as one string.

    Args:
        documents: List of dicts; the 'question', 'answer' and 'text'
            values are joined with single spaces. Missing or ``None``
            fields are treated as empty strings.
        model: Object whose ``encode(text)`` returns the embedding.

    Returns:
        A NumPy array built from the per-document embeddings, or an empty
        array when *documents* is empty.
    """
    if not documents:
        st.write("No documents provided for embedding.")
        return np.array([])  # Return empty array if no documents

    total = len(documents)
    st.write(f"Creating embeddings for {total} documents...")  # Show progress
    progress_bar = st.progress(0)

    embeddings = []
    for done, doc in enumerate(documents, start=1):
        fields = (doc.get(name) or "" for name in ("question", "answer", "text"))
        embeddings.append(model.encode(" ".join(fields)))
        # Count completed documents (1-based) so the bar ends at exactly 1.0
        # instead of stopping one step short.
        progress_bar.progress(min(1.0, done / total))

    embeddings = np.array(embeddings)
    st.write(f"Embeddings created with shape: {embeddings.shape}")
    return embeddings
|
38 |
+
|
39 |
+
# Retrieve documents based on the question embedding
|
40 |
+
def retrieve_documents(question_embedding, document_embeddings, top_k=5):
    """Rank documents by cosine similarity to the question.

    Args:
        question_embedding: 1-D embedding of the question.
        document_embeddings: NumPy array iterated row by row, one
            embedding per document.
        top_k: Maximum number of document indices to return.

    Returns:
        Indices into *document_embeddings* of the most similar documents,
        most similar first. An empty list is returned when *top_k* is not
        positive or there are no embeddings.
    """
    # Guard first: with top_k == 0 the slice [-0:] below would select every
    # document instead of none, and negative values slice from the wrong end.
    if top_k <= 0:
        return []

    if document_embeddings.size == 0:
        st.write("No document embeddings available for retrieval.")
        return []

    st.write("Calculating similarities between question and documents...")
    # scipy's `cosine` is a distance, so similarity = 1 - distance.
    similarities = np.array(
        [1 - cosine(question_embedding, doc_embedding) for doc_embedding in document_embeddings]
    )

    # argsort is ascending: take the last top_k entries and reverse them.
    return similarities.argsort()[-top_k:][::-1]
|
51 |
+
|
52 |
+
# Main function to get the context from the most relevant documents based on topic and question
|
53 |
+
def get_relevant_context(question, topic):
    """Build a context string from the documents most similar to *question*.

    Documents mentioning *topic* are fetched with ``search_relevant_data``,
    embedded, ranked against the question embedding, and the 'answer' and
    'text' fields of the top matches are concatenated.

    Args:
        question: The user's question, embedded for similarity ranking.
        topic: Search phrase used to pre-filter the dataset.

    Returns:
        The combined context, or ``""`` when nothing usable was found or an
        error occurred. Failures are reported on the Streamlit page rather
        than returned, so callers can test ``context.strip()`` and never
        mistake a status message for real context.
    """
    try:
        st.write("Searching for relevant documents based on the topic...")
        relevant_documents = search_relevant_data(topic)  # Use dynamic topic for search query

        st.write(f"Found {len(relevant_documents)} relevant documents.")
        if not relevant_documents:
            return ""

        retriever_model = load_similarity_model()

        # Create document embeddings and show progress
        document_embeddings = create_embeddings(relevant_documents, retriever_model)
        if document_embeddings.size == 0:
            st.write("No embeddings created for relevant documents.")
            return ""

        st.write("Generating question embedding and retrieving relevant documents...")
        question_embedding = retriever_model.encode(question)
        relevant_doc_indices = retrieve_documents(question_embedding, document_embeddings)
        if len(relevant_doc_indices) == 0:
            st.write("No relevant documents found after embedding.")
            return ""

        # Extract context from the top relevant documents, skipping blank ones.
        contexts = []
        for idx in relevant_doc_indices:
            doc = relevant_documents[idx]
            context = (doc.get('answer') or '') + " " + (doc.get('text') or '')
            if context.strip():
                contexts.append(context)

        if not contexts:
            st.write("No valid contexts available for answering.")
            return ""

        # Return the combined context for question answering
        return " ".join(contexts)

    except Exception as e:
        st.write(f"Error processing question: {str(e)}")
        return ""
|