SoumyaJ commited on
Commit
b3ecf94
·
verified ·
1 Parent(s): 69046d4

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +133 -130
tools.py CHANGED
@@ -1,130 +1,133 @@
1
- from langchain_astradb import AstraDBVectorStore
2
- from langchain_huggingface import HuggingFaceEndpointEmbeddings
3
- from langchain.tools.retriever import create_retriever_tool
4
- from langchain_huggingface import HuggingFaceEmbeddings
5
- import os
6
- import pandas as pd
7
- import requests
8
- import yaml
9
-
10
- HOLIDAY_KEYWORDS ={
11
- "christmas": ["christmas", "santa", "carol", "holiday"]}
12
-
13
- class RetrievalTool:
14
- def __init__(self):
15
- # self.embeddings = HuggingFaceEndpointEmbeddings(
16
- # model= "sentence-transformers/all-MiniLM-L6-v2",
17
- # task="feature-extraction",
18
- # huggingfacehub_api_token= os.environ["HF_TOKEN"])
19
-
20
- self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
21
-
22
- self.vector_store = AstraDBVectorStore(collection_name="program_astra",
23
- embedding=self.embeddings,
24
- api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
25
- token= os.environ["ASTRA_DB_APPLICATION_TOKEN"],
26
- namespace= os.environ["ASTRA_DB_NAMESPACE"])
27
-
28
- self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]
29
- self.config = yaml.safe_load(open("config.yaml"))
30
-
31
- def get_HolidayInCalendar(self,scheduleDate:str):
32
- schedule_date = pd.to_datetime(scheduleDate, errors='coerce')
33
- if schedule_date is pd.NaT:
34
- raise ValueError("Invalid date format. Please use YYYY-MM-DD.")
35
- schedule_date = schedule_date.date()
36
-
37
- year = schedule_date.year
38
- month = schedule_date.month
39
- holidaylist = []
40
-
41
- endpoint_url = f"https://calendarific.com/api/v2/holidays?api_key={self.calendarKey}&country='US'&year={year}&month={month}"
42
- response = requests.get(endpoint_url)
43
- eventName = ""
44
- if response.status_code == 200:
45
- holidays = response.json()['response']['holidays']
46
- holidaylist = list(map(lambda x: x['description'],holidays))
47
-
48
- if any("Christmas" in sentence for sentence in holidaylist):
49
- eventName = "christmas"
50
- elif any("New year" in sentence for sentence in holidaylist):
51
- eventName = "new year"
52
- return eventName
53
- else:
54
- return ""
55
-
56
- #This is what will be coming from the front end to populate genre and schedule date for astra filtering
57
- def buildDataToIncludeHolidayEvents(self, genre:str, scheduleDate:str):
58
- if not genre or not scheduleDate:
59
- raise ValueError("Genre and schedule date are required.")
60
- genre_data = genre.strip().lower()
61
-
62
- holidayEvent = self.get_HolidayInCalendar(scheduleDate)
63
-
64
- return genre_data, holidayEvent
65
-
66
-
67
- def get_retrievers(self, user_genres: list, holiday_event: str = None):
68
- astraConfig = self.config["astra_db"]
69
- astra_filter_genre = {"genre": {"$in": user_genres}}
70
-
71
- if holiday_event:
72
- keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), [])
73
-
74
- astra_filter_holiday = {
75
- "$or": [
76
- {"synopsis": {"$in": keywords}}
77
- ]}
78
-
79
- retriever_holiday = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_holiday, "k": astraConfig["holidaySearch"]["k"]})
80
- retriever_genre = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_genre, "k": astraConfig["genreSearchWithEvent"]["k"]})
81
-
82
- return retriever_genre,retriever_holiday
83
-
84
- else:
85
- retriever = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_genre, "k": astraConfig["genreSearchWithoutEvent"]["k"]})
86
- return retriever, None
87
-
88
- #This is the function to pull the relevant docs based on the genre and date
89
- def get_relevant_programmes(self, genre: str, scheduleDate: str)-> pd.DataFrame:
90
- """Retrieves relevant documents from the vector store based on the genre and date."""
91
- if not self.vector_store or not self.embeddings:
92
- raise ValueError("Vector store or embeddings not initialized.")
93
-
94
- genre, holidayEvents = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
95
-
96
- retriever_genre, retriever_holiday = self.get_retrievers([genre], holidayEvents)
97
- documents= []
98
-
99
- if retriever_holiday:
100
- documents.extend(retriever_holiday.invoke(holidayEvents))
101
-
102
- if retriever_genre:
103
- documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))
104
-
105
- if not documents:
106
- raise ValueError("No relevant documents found.")
107
- program_df = pd.DataFrame([doc.metadata for doc in documents])
108
-
109
- formatted_entries = []
110
- for _, row in program_df.iterrows():
111
- title = row['programme_title']
112
- duration = row['duration']
113
- rating = row['ratings']
114
- synopsis = " ".join(row['synopsis']) if isinstance(row['synopsis'], list) else str(row['synopsis'])
115
- genre = " ".join(row['genre']) if isinstance(row['genre'], list) else str(row['genre'])
116
-
117
- formatted_entries.append(
118
- f"programme_title: {title}, duration: {duration}, ratings: {rating}, synopsis: {synopsis}, genre: {genre}"
119
- )
120
-
121
- # Join all formatted strings with newline
122
- docs = "\n".join(formatted_entries)
123
- return docs, holidayEvents
124
-
125
- # if __name__ == "__main__":
126
- # tool = RetrievalTool()
127
- # tool2 = tool.get_relevant_documents("comedy", "2023-10-25")
128
- # print(tool2)
129
-
130
-
 
 
 
 
1
+ from langchain_astradb import AstraDBVectorStore
2
+ from langchain_huggingface import HuggingFaceEndpointEmbeddings
3
+ from langchain.tools.retriever import create_retriever_tool
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ import os
6
+ import pandas as pd
7
+ import requests
8
+ import yaml
9
+
10
+ HOLIDAY_KEYWORDS ={
11
+ "christmas": ["christmas", "santa", "carol", "holiday"]}
12
+
13
+
14
+
15
+ class RetrievalTool:
16
+ def __init__(self):
17
+ # self.embeddings = HuggingFaceEndpointEmbeddings(
18
+ # model= "sentence-transformers/all-MiniLM-L6-v2",
19
+ # task="feature-extraction",
20
+ # huggingfacehub_api_token= os.environ["HF_TOKEN"])
21
+
22
+ self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
23
+
24
+ self.vector_store = AstraDBVectorStore(collection_name="program_astra",
25
+ embedding=self.embeddings,
26
+ api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
27
+ token= os.environ["ASTRA_DB_APPLICATION_TOKEN"],
28
+ namespace= os.environ["ASTRA_DB_NAMESPACE"])
29
+
30
+ self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]
31
+ self.config = yaml.safe_load(open("config.yaml"))
32
+
33
+ def get_HolidayInCalendar(self,scheduleDate:str):
34
+ schedule_date = pd.to_datetime(scheduleDate, errors='coerce')
35
+ if schedule_date is pd.NaT:
36
+ raise ValueError("Invalid date format. Please use YYYY-MM-DD.")
37
+ schedule_date = schedule_date.date()
38
+
39
+ year = schedule_date.year
40
+ month = schedule_date.month
41
+ country = "US"
42
+ holidaylist = []
43
+
44
+ endpoint_url = f"https://calendarific.com/api/v2/holidays?api_key={self.calendarKey}&country={country}&year={year}&month={month}"
45
+ response = requests.get(endpoint_url)
46
+ eventName = ""
47
+ if response.status_code == 200:
48
+ holidays = response.json()['response']['holidays']
49
+ holidaylist = list(map(lambda x: x['description'],holidays))
50
+
51
+ if any("Christmas" in sentence for sentence in holidaylist):
52
+ eventName = "christmas"
53
+ elif any("New year" in sentence for sentence in holidaylist):
54
+ eventName = "new year"
55
+ return eventName
56
+ else:
57
+ return ""
58
+
59
+ #This is what will be coming from the front end to populate genre and schedule date for astra filtering
60
+ def buildDataToIncludeHolidayEvents(self, genre:str, scheduleDate:str):
61
+ if not genre or not scheduleDate:
62
+ raise ValueError("Genre and schedule date are required.")
63
+ genre_data = genre.strip().lower()
64
+
65
+ holidayEvent = self.get_HolidayInCalendar(scheduleDate)
66
+
67
+ return genre_data, holidayEvent
68
+
69
+
70
+ def get_retrievers(self, user_genres: list, holiday_event: str = None):
71
+ astraConfig = self.config["astra_db"]
72
+ astra_filter_genre = {"genre": {"$in": user_genres}}
73
+
74
+ if holiday_event:
75
+ keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), [])
76
+
77
+ astra_filter_holiday = {
78
+ "$or": [
79
+ {"synopsis": {"$in": keywords}}
80
+ ]}
81
+
82
+ retriever_holiday = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_holiday, "k": astraConfig["holidaySearch"]["k"]})
83
+ retriever_genre = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_genre, "k": astraConfig["genreSearchWithEvent"]["k"]})
84
+
85
+ return retriever_genre,retriever_holiday
86
+
87
+ else:
88
+ retriever = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_genre, "k": astraConfig["genreSearchWithoutEvent"]["k"]})
89
+ return retriever, None
90
+
91
+ #This is the function to pull the relevant docs based on the genre and date
92
+ def get_relevant_programmes(self, genre: str, scheduleDate: str)-> pd.DataFrame:
93
+ """Retrieves relevant documents from the vector store based on the genre and date."""
94
+ if not self.vector_store or not self.embeddings:
95
+ raise ValueError("Vector store or embeddings not initialized.")
96
+
97
+ genre, holidayEvents = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
98
+
99
+ retriever_genre, retriever_holiday = self.get_retrievers([genre], holidayEvents)
100
+ documents= []
101
+
102
+ if retriever_holiday:
103
+ documents.extend(retriever_holiday.invoke(holidayEvents))
104
+
105
+ if retriever_genre:
106
+ documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))
107
+
108
+ if not documents:
109
+ raise ValueError("No relevant documents found.")
110
+ program_df = pd.DataFrame([doc.metadata for doc in documents])
111
+
112
+ formatted_entries = []
113
+ for _, row in program_df.iterrows():
114
+ title = row['programme_title']
115
+ duration = row['duration']
116
+ rating = row['ratings']
117
+ synopsis = " ".join(row['synopsis']) if isinstance(row['synopsis'], list) else str(row['synopsis'])
118
+ genre = " ".join(row['genre']) if isinstance(row['genre'], list) else str(row['genre'])
119
+
120
+ formatted_entries.append(
121
+ f"programme_title: {title}, duration: {duration}, ratings: {rating}, synopsis: {synopsis}, genre: {genre}"
122
+ )
123
+
124
+ # Join all formatted strings with newline
125
+ docs = "\n".join(formatted_entries)
126
+ return docs, holidayEvents
127
+
128
+ # if __name__ == "__main__":
129
+ # tool = RetrievalTool()
130
+ # tool2 = tool.get_relevant_documents("comedy", "2023-10-25")
131
+ # print(tool2)
132
+
133
+