dataprincess commited on
Commit
f093d4b
·
verified ·
1 Parent(s): 20768f0
Files changed (1) hide show
  1. app.py +31 -36
app.py CHANGED
@@ -9,8 +9,7 @@ from tqdm.auto import tqdm
9
  import streamlit as st
10
  import re
11
 
12
-
13
- # Constants (hardcoded)
14
  FILE_PATH = "anjibot_chunks.json"
15
  BATCH_SIZE = 384
16
  INDEX_NAME = "groq-llama-3-rag"
@@ -55,44 +54,42 @@ for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
55
  index.upsert(vectors=to_upsert)
56
 
57
  def extract_course_code(text) -> list[str]:
 
58
  pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
59
  match = re.findall(pattern, text, re.IGNORECASE)
60
  return match if match else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- def get_docs(query: str, top_k: int, batch_size: int = 5, threshold: float = 0.66) -> list[str]:
63
- queried_course_codes = extract_course_code(query)
64
 
65
- i = 0
66
- relevant_docs = []
67
-
68
- while True:
69
  xq = encoder.encode(query)
70
- res = index.query(vector=xq.tolist(), top_k=batch_size, include_metadata=True, offset=i)
71
-
72
- if len(res["matches"]) == 0:
73
- break
74
-
75
- for match in res["matches"]:
76
- similarity_score = match['score']
77
- content = match["metadata"]['content']
78
-
79
- if similarity_score >= threshold:
80
- if queried_course_codes:
81
- for course_code in queried_course_codes:
82
- if course_code in content:
83
- relevant_docs.append(content)
84
- break
85
-
86
- if relevant_docs:
87
- break
88
-
89
- i += batch_size
90
-
91
- if relevant_docs:
92
- return relevant_docs
93
- else:
94
- return ["No exact match found for the course code, even after searching with a higher similarity score."]
95
-
96
 
97
  def get_response(query: str, docs: list[str]) -> str:
98
  system_message = (
@@ -115,8 +112,6 @@ def get_response(query: str, docs: list[str]) -> str:
115
  )
116
  return chat_response.choices[0].message.content
117
 
118
-
119
-
120
  def handle_query(user_query: str):
121
 
122
  # Get relevant documents
 
9
  import streamlit as st
10
  import re
11
 
12
+ # Variables
 
13
  FILE_PATH = "anjibot_chunks.json"
14
  BATCH_SIZE = 384
15
  INDEX_NAME = "groq-llama-3-rag"
 
54
  index.upsert(vectors=to_upsert)
55
 
56
  def extract_course_code(text) -> list[str]:
57
+ # Improved pattern with correct case insensitivity and spacing allowance
58
  pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
59
  match = re.findall(pattern, text, re.IGNORECASE)
60
  return match if match else None
61
+
62
+ def get_docs(query: str, top_k: int) -> list[str]:
63
+ # Extract course code(s) from the query
64
+ course_code = extract_course_code(query)
65
+ exact_matches = []
66
+
67
+ if course_code:
68
+ # Normalize course_code to lowercase for case-insensitive matching
69
+ course_code = [code.lower() for code in course_code]
70
+
71
+ # Check for exact match in metadata
72
+ exact_matches = [
73
+ x['content'] for x in data['metadata']
74
+ if any(code in x['content'].lower() for code in course_code)
75
+ ]
76
 
77
+ # Calculate remaining slots if we have fewer than top_k exact matches
78
+ remaining_slots = top_k - len(exact_matches)
79
 
80
+ if remaining_slots > 0:
81
+ # Perform embedding search for either the entire top_k if no exact match, or the remaining slots
 
 
82
  xq = encoder.encode(query)
83
+ res = index.query(vector=xq.tolist(), top_k=remaining_slots if exact_matches else top_k, include_metadata=True)
84
+
85
+ # Add embedding-based matches (avoiding duplicates)
86
+ embedding_matches = [x["metadata"]['content'] for x in res["matches"]]
87
+
88
+ # Combine exact matches with embedding matches
89
+ exact_matches.extend(embedding_matches)
90
+
91
+ # Return the first top_k results
92
+ return exact_matches[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  def get_response(query: str, docs: list[str]) -> str:
95
  system_message = (
 
112
  )
113
  return chat_response.choices[0].message.content
114
 
 
 
115
  def handle_query(user_query: str):
116
 
117
  # Get relevant documents