dataprincess committed on
Commit
b2efd5e
·
verified ·
1 Parent(s): 1667f47

combined rag.py with app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -3
app.py CHANGED
@@ -1,7 +1,114 @@
1
  # Required imports
 
 
 
 
 
 
2
  import streamlit as st
3
- from rag import handle_query
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def main():
7
  st.title("Ask Anjibot 2.0")
@@ -22,8 +129,6 @@ def main():
22
  response = st.write_stream(handle_query(prompt))
23
  st.session_state.messages.append({"role": "assistant", "content": response})
24
 
25
- append_to_sheet(prompt, response)
26
-
27
  if __name__ == "__main__":
28
  main()
29
 
 
1
# Required imports
import json
import time
from typing import Any

import streamlit as st
from groq import Groq
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
 
9
 
10
# Module-level configuration.

# API credentials are read from Streamlit's secrets store at import time.
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
GROQ_API_KEY = st.secrets["GROQ_API_KEY"]

# JSON file holding the chunks that get embedded and upserted.
FILE_PATH = "anjibot_chunks.json"

# Pinecone index name and the vector dimension it is created with.
INDEX_NAME = "groq-llama-3-rag"
DIMENSIONS = 768

# Number of chunks embedded and upserted per batch.
BATCH_SIZE = 384
17
+
18
+
19
def load_data(file_path: str) -> dict:
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
22
+
23
+
24
def initialize_pinecone(api_key: str, index_name: str, dims: int) -> Any:
    """Return a handle to the Pinecone index ``index_name``, creating it if needed.

    Args:
        api_key: Pinecone API key used to build the client.
        index_name: Name of the index to open (and create when absent).
        dims: Vector dimension used when the index has to be created.

    Returns:
        The object returned by ``Pinecone.Index(index_name)``.

    Note:
        The readiness poll below has no timeout; it blocks until the index
        reports ``ready``.
    """
    pc = Pinecone(api_key=api_key)

    existing_indexes = {index_info["name"] for index_info in pc.list_indexes()}

    # Create the index only when it does not exist yet.
    if index_name not in existing_indexes:
        spec = ServerlessSpec(cloud="aws", region='us-east-1')
        pc.create_index(index_name, dimension=dims, metric='cosine', spec=spec)

    # Poll once per second until the index reports ready. Running this
    # unconditionally also covers an index that already exists but is still
    # initializing; for a ready index it costs a single describe call.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

    return pc.Index(index_name)
39
+
40
+
41
def upsert_data_to_pinecone(index: Any, data: dict) -> None:
    """Embed the chunks in ``data`` and upsert them into ``index`` in batches.

    ``data`` is read as a dict of parallel sequences: ``data["id"]`` supplies
    the vector ids and ``data["metadata"]`` supplies one mapping per id with
    ``"title"`` and ``"content"`` keys. Each chunk is embedded as
    ``"<title>: <content>"``.

    Args:
        index: Object exposing ``upsert(vectors=...)``.
        data: The chunk data described above.

    Raises:
        ValueError: If a batch yields a different number of embeddings than
            ids (e.g. ``data["metadata"]`` is shorter than ``data["id"]``).
    """
    encoder = SentenceTransformer('dwzhu/e5-base-4k')
    total = len(data['id'])

    for start in tqdm(range(0, total, BATCH_SIZE)):
        end = min(total, start + BATCH_SIZE)

        # Slice every parallel sequence to the current batch window.
        batch = {key: values[start:end] for key, values in data.items()}

        chunks = [f'{meta["title"]}: {meta["content"]}' for meta in batch["metadata"]]
        embeds = encoder.encode(chunks)

        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if len(embeds) != end - start:
            raise ValueError(
                f"expected {end - start} embeddings for batch starting at "
                f"{start}, got {len(embeds)}"
            )

        index.upsert(vectors=list(zip(batch["id"], embeds, batch["metadata"])))
61
+
62
+
63
def get_docs(query: str, index: Any, encoder: Any, top_k: int) -> list[str]:
    """Return the ``content`` of the ``top_k`` index matches for ``query``.

    Args:
        query: Text to search for.
        index: Object exposing ``query(vector=..., top_k=..., include_metadata=...)``.
        encoder: Object whose ``encode(query)`` result exposes ``tolist()``.
        top_k: Number of matches to request from the index.

    Returns:
        ``match["metadata"]["content"]`` for every entry of the query
        result's ``"matches"``, in the order the index returned them.
    """
    query_vector = encoder.encode(query).tolist()
    result = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    return [match["metadata"]['content'] for match in result["matches"]]
67
+
68
+
69
def get_response(query: str, docs: list[str], groq_client: Any) -> str:
    """Ask the Groq chat model to answer ``query`` using ``docs`` as context.

    Args:
        query: The user's question, sent as the user message.
        docs: Context passages; they are joined with ``"\\n---\\n"`` and
            appended after the ``CONTEXT:`` header of the system message.
        groq_client: Object exposing ``chat.completions.create(model=..., messages=...)``.

    Returns:
        The ``content`` of the first choice in the completion response.
    """
    # NOTE(review): the prompt text contains typos ("sarcastica",
    # "If you don't the answer"); it is left untouched here because it is
    # runtime text sent to the model.
    instructions = (
        "You are Anjibot, the AI course rep of 400 Level Computer Science department. You are always helpful, jovial, can be sarcastica but still sweet.\n"
        "Provide the answer to class related queries using\n"
        "context provided below.\n"
        "If you don't the answer to the user's question based on your pretrained knowledge and the context provided, just direct the user to Anji the human course rep.\n"
        "Anji's phone number: 08145170886.\n\n"
        "CONTEXT:\n"
    )
    # The join must be a separate expression: adjacent string literals are
    # concatenated before a method call applies, so writing
    # `"...CONTEXT:\n" "\n---\n".join(docs)` would use the whole prompt as
    # the separator instead of prefixing it.
    system_message = instructions + "\n---\n".join(docs)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]

    chat_response = groq_client.chat.completions.create(
        model="llama3-70b-8192",
        messages=messages,
    )
    return chat_response.choices[0].message.content
89
+
90
+
91
@st.cache_resource
def _load_rag_resources() -> tuple:
    """Build the Pinecone index handle, encoder, and Groq client once.

    Cached with ``st.cache_resource`` so the JSON load, the embedding and
    upsert of the whole corpus, and the model load are not repeated for
    every query. Changes to the data file are picked up only after the
    cache is cleared or the app restarts.
    """
    data = load_data(FILE_PATH)
    index = initialize_pinecone(PINECONE_API_KEY, INDEX_NAME, DIMENSIONS)
    upsert_data_to_pinecone(index, data)

    encoder = SentenceTransformer('dwzhu/e5-base-4k')
    groq_client = Groq(api_key=GROQ_API_KEY)
    return index, encoder, groq_client


def handle_query(user_query: str) -> str:
    """Answer ``user_query`` from the five best-matching indexed chunks.

    Args:
        user_query: The question typed by the user.

    Returns:
        The model's answer as a single string.

    NOTE(review): the visible caller passes this return value to
    ``st.write_stream``; confirm that a plain ``str`` is accepted there, or
    switch the caller to ``st.write``.
    """
    index, encoder, groq_client = _load_rag_resources()

    docs = get_docs(user_query, index, encoder, top_k=5)
    return get_response(user_query, docs, groq_client)
112
 
113
  def main():
114
  st.title("Ask Anjibot 2.0")
 
129
  response = st.write_stream(handle_query(prompt))
130
  st.session_state.messages.append({"role": "assistant", "content": response})
131
 
 
 
132
  if __name__ == "__main__":
133
  main()
134