dataprincess committed on
Commit
20768f0
·
verified ·
1 Parent(s): 5cf9a77

added regex

Browse files
Files changed (1) hide show
  1. app.py +41 -13
app.py CHANGED
@@ -7,15 +7,8 @@ from pinecone import Pinecone, ServerlessSpec
7
  from groq import Groq
8
  from tqdm.auto import tqdm
9
  import streamlit as st
 
10
 
11
- # Required imports
12
- import json
13
- import time
14
- import os
15
- from sentence_transformers import SentenceTransformer
16
- from pinecone import Pinecone, ServerlessSpec
17
- from groq import Groq
18
- from tqdm.auto import tqdm
19
 
20
  # Constants (hardcoded)
21
  FILE_PATH = "anjibot_chunks.json"
@@ -61,10 +54,45 @@ for i in tqdm(range(0, len(data['id']), BATCH_SIZE)):
61
  to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
62
  index.upsert(vectors=to_upsert)
63
 
64
- def get_docs(query: str, top_k: int) -> list[str]:
65
- xq = encoder.encode(query)
66
- res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)
67
- return [x["metadata"]['content'] for x in res["matches"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def get_response(query: str, docs: list[str]) -> str:
70
  system_message = (
@@ -92,7 +120,7 @@ def get_response(query: str, docs: list[str]) -> str:
92
  def handle_query(user_query: str):
93
 
94
  # Get relevant documents
95
- docs = get_docs(user_query, top_k=5)
96
 
97
  # Generate and return response
98
  response = get_response(user_query, docs=docs)
 
7
  from groq import Groq
8
  from tqdm.auto import tqdm
9
  import streamlit as st
10
+ import re
11
 
 
 
 
 
 
 
 
 
12
 
13
  # Constants (hardcoded)
14
  FILE_PATH = "anjibot_chunks.json"
 
54
  to_upsert = list(zip(batch["id"], embeds, batch["metadata"]))
55
  index.upsert(vectors=to_upsert)
56
 
57
def extract_course_code(text: str) -> list[str]:
    """Return every course code mentioned in *text*, in order of appearance.

    A course code is one of the prefixes ``ged``/``geds``, ``stat``/``stats``,
    ``math``/``maths``, ``cosc``, ``seng`` or ``itgy`` followed by optional
    whitespace and exactly three digits (for example ``"COSC 101"`` or
    ``"stat203"``). Matching is case-insensitive and each code is returned
    exactly as it was written in *text*.

    Args:
        text: Free-form text to scan, typically the user's query.

    Returns:
        The matched course codes, or an empty list when there are none.
        The empty list (rather than ``None``) keeps the result consistent
        with the declared return type while remaining falsy for callers
        that only test ``if codes:``.
    """
    pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
    return re.findall(pattern, text, re.IGNORECASE)
61
+
62
def _compile_course_code_matcher(course_codes: list[str]) -> "re.Pattern[str]":
    """Build one case-insensitive regex that matches any of *course_codes*.

    Whitespace between the letter prefix and the digits is made optional so
    that a query written as ``"cosc101"`` still matches a document that says
    ``"COSC 101"`` (and vice versa).
    """
    alternatives = []
    for code in course_codes:
        parts = re.fullmatch(r"\s*([A-Za-z]+)\s*(\d+)\s*", code)
        if parts:
            alternatives.append(
                rf"{re.escape(parts.group(1))}\s*{parts.group(2)}"
            )
        else:
            # Unexpected shape: fall back to matching the text literally.
            alternatives.append(re.escape(code))
    return re.compile(r"\b(?:" + "|".join(alternatives) + r")\b", re.IGNORECASE)


def get_docs(
    query: str,
    top_k: int,
    batch_size: int = 5,
    threshold: float = 0.66,
    max_candidates: int = 100,
) -> list[str]:
    """Retrieve document contents relevant to *query* from the vector index.

    The query is embedded with the module-level ``encoder`` and looked up in
    the module-level ``index``. A match is kept when its similarity score is
    at least *threshold* and, if the query mentions course codes, when its
    content mentions one of those codes (case-insensitively). Queries that
    mention no course code keep every match that clears the threshold.

    The index query takes no offset argument, so instead of paging the
    candidate window is widened by *batch_size* per round until something
    relevant is found or there is no point in looking further.

    Args:
        query: The user's question.
        top_k: Maximum number of documents to return.
        batch_size: How many additional candidates to request per round.
        threshold: Minimum similarity score a match needs to be considered.
        max_candidates: Upper bound on the number of candidates requested
            from the index in a single query; bounds the widening loop.

    Returns:
        Up to *top_k* document contents, or a single-element list holding a
        "no exact match" message when nothing qualifies.
    """
    queried_course_codes = extract_course_code(query)
    code_matcher = (
        _compile_course_code_matcher(queried_course_codes)
        if queried_course_codes
        else None
    )

    # The embedding does not change between rounds, so compute it once.
    query_vector = encoder.encode(query).tolist()

    step = max(1, batch_size)
    ceiling = max(1, max_candidates)
    limit = min(step, ceiling)
    already_seen = 0
    relevant_docs = []

    while True:
        res = index.query(vector=query_vector, top_k=limit, include_metadata=True)
        matches = res["matches"]

        found_below_threshold = False
        # Only inspect candidates that the previous, narrower round did not return.
        for match in matches[already_seen:]:
            if match["score"] < threshold:
                found_below_threshold = True
                continue
            content = match["metadata"]["content"]
            if code_matcher is None or code_matcher.search(content):
                relevant_docs.append(content)
        already_seen = len(matches)

        # Stop when we have results, when the index has nothing more to give,
        # when scores have dropped under the threshold (assumes a wider window
        # only adds lower-scoring matches -- TODO confirm for the index
        # metric), or when the safety ceiling is reached.
        if (
            relevant_docs
            or found_below_threshold
            or len(matches) < limit
            or limit >= ceiling
        ):
            break
        limit = min(limit + step, ceiling)

    if relevant_docs:
        return relevant_docs[:max(1, top_k)]
    return ["No exact match found for the course code, even after searching with a higher similarity score."]
95
+
96
 
97
  def get_response(query: str, docs: list[str]) -> str:
98
  system_message = (
 
120
  def handle_query(user_query: str):
121
 
122
  # Get relevant documents
123
+ docs = get_docs(user_query, top_k=5, threshold=0.66)
124
 
125
  # Generate and return response
126
  response = get_response(user_query, docs=docs)