Shriharsh committed on
Commit
c2198d6
·
verified ·
1 Parent(s): 1104992

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -9
app.py CHANGED
@@ -15,8 +15,8 @@ logging.basicConfig(
15
  logger = logging.getLogger()
16
 
17
  # Load models
18
- qa_model = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
19
- embedder = SentenceTransformer('all-MiniLM-L6-v2')
20
 
21
  # Helper function to extract text from PDF
22
  def extract_text_from_pdf(file_path):
@@ -24,19 +24,27 @@ def extract_text_from_pdf(file_path):
24
  with open(file_path, "rb") as file:
25
  pdf_reader = PyPDF2.PdfReader(file)
26
  for page in pdf_reader.pages:
27
- text += page.extract_text() + "\n"
28
  return text
29
 
30
  # Find the most relevant section in the document
 
31
  def find_relevant_section(query, sections, section_embeddings):
 
 
 
 
 
32
  stopwords = {"and", "the", "is", "for", "to", "a", "an", "of", "in", "on", "at", "with", "by", "it", "as", "so", "what"}
33
 
34
  # Semantic search
35
  query_embedding = embedder.encode(query, convert_to_tensor=True)
36
- similarities = util.cos_sim(query_embedding, section_embeddings)[0]
37
  best_idx = similarities.argmax().item()
38
  best_section = sections[best_idx]
39
  similarity_score = similarities[best_idx].item()
 
 
40
 
41
  SIMILARITY_THRESHOLD = 0.4
42
  if similarity_score >= SIMILARITY_THRESHOLD:
@@ -46,10 +54,14 @@ def find_relevant_section(query, sections, section_embeddings):
46
  logger.info(f"Low similarity ({similarity_score}). Falling back to keyword search.")
47
 
48
  # Keyword-based fallback search with stopword filtering
 
49
  query_words = {word for word in query.lower().split() if word not in stopwords}
50
  for section in sections:
51
  section_words = {word for word in section.lower().split() if word not in stopwords}
52
  common_words = query_words.intersection(section_words)
 
 
 
53
  if len(common_words) >= 2:
54
  logger.info(f"Keyword match found for query: {query} with common words: {common_words}")
55
  return section
@@ -57,11 +69,14 @@ def find_relevant_section(query, sections, section_embeddings):
57
  logger.info(f"No good keyword match found. Returning default fallback response.")
58
  return "I don’t have enough information to answer that."
59
 
60
- # Process the uploaded file with detailed logging
61
  def process_file(file, state):
 
 
62
  if file is None:
63
  logger.info("No file uploaded.")
64
  return [("Bot", "Please upload a file.")], state
 
 
65
 
66
  file_path = file.name
67
  if file_path.lower().endswith(".pdf"):
@@ -74,9 +89,12 @@ def process_file(file, state):
74
  else:
75
  logger.error(f"Unsupported file format: {file_path}")
76
  return [("Bot", "Unsupported file format. Please upload a PDF or TXT file.")], state
77
-
 
78
  sections = text.split('\n\n')
79
  section_embeddings = embedder.encode(sections, convert_to_tensor=True)
 
 
80
  state['document_text'] = text
81
  state['sections'] = sections
82
  state['section_embeddings'] = section_embeddings
@@ -87,56 +105,77 @@ def process_file(file, state):
87
  logger.info(f"Processed file: {file_path}")
88
  return state['chat_history'], state
89
 
90
- # Handle user input (queries and feedback)
91
def handle_input(user_input, state):
    """Dispatch one user message based on the current session mode.

    Modes: 'waiting_for_upload' (no document yet), 'waiting_for_query'
    (expecting a question), 'waiting_for_feedback' (expecting feedback on
    the last answer). Returns the (chat_history, state) pair for the UI.
    """
    mode = state['mode']
    history = state['chat_history']

    if mode == 'waiting_for_upload':
        # No document loaded yet — prompt the user to upload first.
        history.append(("Bot", "Please upload a file first."))
        logger.info("User attempted to interact without uploading a file.")
    elif mode == 'waiting_for_query':
        query = user_input
        state['current_query'] = query
        state['feedback_count'] = 0
        # Retrieve context, then answer (or echo the not-enough-info fallback).
        context = find_relevant_section(query, state['sections'], state['section_embeddings'])
        if context == "I don’t have enough information to answer that.":
            answer = context
        else:
            answer = qa_model(question=query, context=context)["answer"]
        state['last_answer'] = answer
        state['mode'] = 'waiting_for_feedback'
        history.append(("User", query))
        history.append(("Bot", f"Answer: {answer}\nPlease provide feedback: good, too vague, not helpful."))
        logger.info(f"Query: {query}, Answer: {answer}")
    elif mode == 'waiting_for_feedback':
        feedback = user_input.lower()
        history.append(("User", feedback))
        logger.info(f"Feedback: {feedback}")
        if feedback == "good" or state['feedback_count'] >= 2:
            # Either accepted or out of retries — go back to query mode.
            state['mode'] = 'waiting_for_query'
            if feedback == "good":
                history.append(("Bot", "Thank you for your feedback. You can ask another question."))
                logger.info("Feedback accepted as 'good'. Waiting for next query.")
            else:
                history.append(("Bot", "Maximum feedback iterations reached. You can ask another question."))
                logger.info("Max feedback iterations reached. Waiting for next query.")
        else:
            query = state['current_query']
            context = find_relevant_section(query, state['sections'], state['section_embeddings'])
            if feedback == "too vague":
                # Append a slice of the retrieved context for extra detail.
                adjusted_answer = f"{state['last_answer']}\n\n(More details:\n{context[:500]}...)"
            elif feedback == "not helpful":
                # Re-ask the QA model with an expanded question.
                adjusted_answer = qa_model(question=query + " Please provide more detailed information with examples.", context=context)['answer']
            else:
                history.append(("Bot", "Please provide valid feedback: good, too vague, not helpful."))
                logger.info(f"Invalid feedback received: {feedback}")
                return history, state
            state['last_answer'] = adjusted_answer
            state['feedback_count'] += 1
            history.append(("Bot", f"Updated answer: {adjusted_answer}\nPlease provide feedback: good, too vague, not helpful."))
            logger.info(f"Adjusted answer: {adjusted_answer}")
    return history, state
138
 
139
  # Function to return the up-to-date log file for download
 
140
  def get_log_file():
141
  # Flush all log handlers to ensure log file is current
142
  for handler in logger.handlers:
@@ -148,7 +187,7 @@ def get_log_file():
148
  logger.info("Log file downloaded by user.")
149
  return "support_bot_log.txt"
150
 
151
- # Initial state
152
  initial_state = {
153
  'document_text': None,
154
  'sections': None,
@@ -182,4 +221,4 @@ with gr.Blocks() as demo:
182
  # Set up download log button
183
  download_btn.click(fn=get_log_file, inputs=[], outputs=download_file)
184
 
185
- demo.launch(share=True)
 
15
  logger = logging.getLogger()
16
 
17
# Load models
# Extractive QA model: pulls an answer span out of a retrieved context passage.
qa_model = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
# Sentence embedder: turns text into vectors for cosine-similarity retrieval.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
20
 
21
  # Helper function to extract text from PDF
22
  def extract_text_from_pdf(file_path):
 
24
  with open(file_path, "rb") as file:
25
  pdf_reader = PyPDF2.PdfReader(file)
26
  for page in pdf_reader.pages:
27
+ text += page.extract_text() + "\n" # Extract text from each page and concatenate.
28
  return text
29
 
30
  # Find the most relevant section in the document
31
+
32
  def find_relevant_section(query, sections, section_embeddings):
33
+ """
34
+ 1. First, it performs a semantic search using cosine similarity.
35
+ 2. If the similarity score is below a threshold, it falls back to a keyword-based search.
36
+ """
37
+
38
  stopwords = {"and", "the", "is", "for", "to", "a", "an", "of", "in", "on", "at", "with", "by", "it", "as", "so", "what"}
39
 
40
  # Semantic search
41
  query_embedding = embedder.encode(query, convert_to_tensor=True)
42
+ similarities = util.cos_sim(query_embedding, section_embeddings)[0] # Compute cosine similarity between the query embedding and all section embeddings.
43
  best_idx = similarities.argmax().item()
44
  best_section = sections[best_idx]
45
  similarity_score = similarities[best_idx].item()
46
+
47
+ # Defining a threshold to determine if semantic search is confident enough.
48
 
49
  SIMILARITY_THRESHOLD = 0.4
50
  if similarity_score >= SIMILARITY_THRESHOLD:
 
54
  logger.info(f"Low similarity ({similarity_score}). Falling back to keyword search.")
55
 
56
  # Keyword-based fallback search with stopword filtering
57
+
58
  query_words = {word for word in query.lower().split() if word not in stopwords}
59
  for section in sections:
60
  section_words = {word for word in section.lower().split() if word not in stopwords}
61
  common_words = query_words.intersection(section_words)
62
+
63
+ # If at least two words match, return this section.
64
+
65
  if len(common_words) >= 2:
66
  logger.info(f"Keyword match found for query: {query} with common words: {common_words}")
67
  return section
 
69
  logger.info(f"No good keyword match found. Returning default fallback response.")
70
  return "I don’t have enough information to answer that."
71
 
 
72
  def process_file(file, state):
73
+ """Handles the uploaded file, processes its text, and prepares it for querying."""
74
+
75
  if file is None:
76
  logger.info("No file uploaded.")
77
  return [("Bot", "Please upload a file.")], state
78
+
79
+ # Determine file type and extract text accordingly.
80
 
81
  file_path = file.name
82
  if file_path.lower().endswith(".pdf"):
 
89
  else:
90
  logger.error(f"Unsupported file format: {file_path}")
91
  return [("Bot", "Unsupported file format. Please upload a PDF or TXT file.")], state
92
+
93
+ # Split document into sections and encode them into embeddings.
94
  sections = text.split('\n\n')
95
  section_embeddings = embedder.encode(sections, convert_to_tensor=True)
96
+
97
+ # Store extracted text and embeddings in the chatbot's state dictionary.
98
  state['document_text'] = text
99
  state['sections'] = sections
100
  state['section_embeddings'] = section_embeddings
 
105
  logger.info(f"Processed file: {file_path}")
106
  return state['chat_history'], state
107
 
108
+
109
def handle_input(user_input, state):
    """Process one user turn: a question or feedback on the last answer.

    Args:
        user_input: Raw text the user typed into the chat box.
        state: Mutable session dict carrying 'mode', 'chat_history',
            'sections', 'section_embeddings', 'current_query',
            'last_answer', and 'feedback_count'.

    Returns:
        A (chat_history, state) tuple for the Gradio UI to render.
    """
    if state['mode'] == 'waiting_for_upload':
        state['chat_history'].append(("Bot", "Please upload a file first."))
        logger.info("User attempted to interact without uploading a file.")

    elif state['mode'] == 'waiting_for_query':
        query = user_input
        state['current_query'] = query
        state['feedback_count'] = 0

        # Finding the best matching section.
        context = find_relevant_section(query, state['sections'], state['section_embeddings'])

        # Generating an answer using the QA model (or echoing the fallback).
        if context == "I don’t have enough information to answer that.":
            answer = context
        else:
            result = qa_model(question=query, context=context)
            answer = result["answer"]

        state['last_answer'] = answer
        state['mode'] = 'waiting_for_feedback'
        state['chat_history'].append(("User", query))
        state['chat_history'].append(("Bot", f"Answer: {answer}\nPlease provide feedback: good, too vague, not helpful."))
        logger.info(f"Query: {query}, Answer: {answer}")

    elif state['mode'] == 'waiting_for_feedback':
        # Fix: strip surrounding whitespace before matching — a trailing
        # newline/space from the textbox (e.g. "good\n") otherwise fails the
        # == "good" check and is treated as invalid feedback.
        feedback = user_input.strip().lower()
        state['chat_history'].append(("User", feedback))
        logger.info(f"Feedback: {feedback}")

        # Handling feedback responses: accept, retry-limit, or refine.
        if feedback == "good" or state['feedback_count'] >= 2:
            state['mode'] = 'waiting_for_query'

            if feedback == "good":
                state['chat_history'].append(("Bot", "Thank you for your feedback. You can ask another question."))
                logger.info("Feedback accepted as 'good'. Waiting for next query.")
            else:
                state['chat_history'].append(("Bot", "Maximum feedback iterations reached. You can ask another question."))
                logger.info("Max feedback iterations reached. Waiting for next query.")
        else:
            query = state['current_query']
            context = find_relevant_section(query, state['sections'], state['section_embeddings'])

            if feedback == "too vague":
                # Augment the previous answer with a slice of raw context.
                adjusted_answer = f"{state['last_answer']}\n\n(More details:\n{context[:500]}...)"
            elif feedback == "not helpful":
                # Re-query the QA model with an expanded question.
                adjusted_answer = qa_model(question=query + " Please provide more detailed information with examples.", context=context)['answer']
            else:
                state['chat_history'].append(("Bot", "Please provide valid feedback: good, too vague, not helpful."))
                logger.info(f"Invalid feedback received: {feedback}")
                return state['chat_history'], state

            state['last_answer'] = adjusted_answer
            state['feedback_count'] += 1
            state['chat_history'].append(("Bot", f"Updated answer: {adjusted_answer}\nPlease provide feedback: good, too vague, not helpful."))
            logger.info(f"Adjusted answer: {adjusted_answer}")

    return state['chat_history'], state
176
 
177
  # Function to return the up-to-date log file for download
178
+
179
  def get_log_file():
180
  # Flush all log handlers to ensure log file is current
181
  for handler in logger.handlers:
 
187
  logger.info("Log file downloaded by user.")
188
  return "support_bot_log.txt"
189
 
190
+ # Initial state setup
191
  initial_state = {
192
  'document_text': None,
193
  'sections': None,
 
221
  # Set up download log button
222
  download_btn.click(fn=get_log_file, inputs=[], outputs=download_file)
223
 
224
+ demo.launch(share=True)