# SoumyaJ's picture
# Update tools.py
# b3ecf94 verified
from langchain_astradb import AstraDBVectorStore
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain.tools.retriever import create_retriever_tool
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
import requests
import yaml
# Maps a holiday event name (lower-case) to the keywords used when filtering
# programme synopses for that event.
HOLIDAY_KEYWORDS = dict(
    christmas=["christmas", "santa", "carol", "holiday"],
)
class RetrievalTool:
    """Retrieve programmes from the Astra DB vector store by genre and holiday.

    The tool looks up US holidays for the month of a schedule date via the
    Calendarific API, derives a holiday event name from them, and builds
    genre / holiday retrievers whose result sizes (``k``) come from the
    ``astra_db`` section of ``config.yaml``.
    """

    # Calendarific endpoint queried for the holidays of a given month.
    _HOLIDAY_API_URL = "https://calendarific.com/api/v2/holidays"
    # Seconds to wait for Calendarific; without a timeout requests can block forever.
    _HOLIDAY_API_TIMEOUT = 10
    # Event names recognised in holiday descriptions, in priority order.
    _HOLIDAY_EVENTS = ("christmas", "new year")

    def __init__(self):
        """Create the embeddings, the vector store and load the configuration.

        Raises:
            KeyError: if one of the required environment variables
                (``ASTRA_DB_API_ENDPOINT``, ``ASTRA_DB_APPLICATION_TOKEN``,
                ``ASTRA_DB_NAMESPACE``, ``CALENDARIFIC_API_KEY``) is not set.
            OSError: if ``config.yaml`` cannot be opened.
        """
        # Alternative kept for reference: hosted embeddings through the
        # Hugging Face inference endpoint.
        # self.embeddings = HuggingFaceEndpointEmbeddings(
        #     model= "sentence-transformers/all-MiniLM-L6-v2",
        #     task="feature-extraction",
        #     huggingfacehub_api_token= os.environ["HF_TOKEN"])
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vector_store = AstraDBVectorStore(
            collection_name="program_astra",
            embedding=self.embeddings,
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            namespace=os.environ["ASTRA_DB_NAMESPACE"],
        )
        self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]
        # Use a context manager so the file handle is closed deterministically.
        with open("config.yaml", encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    @staticmethod
    def _detect_event(descriptions):
        """Return the first known holiday event mentioned in *descriptions*, or ""."""
        # Compare case-insensitively: a marker such as "New year" would
        # otherwise miss text like "New Year's Day".
        lowered = [str(text or "").lower() for text in descriptions]
        for event in RetrievalTool._HOLIDAY_EVENTS:
            if any(event in text for text in lowered):
                return event
        return ""

    @staticmethod
    def _format_programme(metadata):
        """Render one document's metadata as a single descriptive line."""

        def as_text(value):
            # List-valued fields are flattened into a space-separated string.
            return " ".join(value) if isinstance(value, list) else str(value)

        return (
            f"programme_title: {metadata.get('programme_title', '')}, "
            f"duration: {metadata.get('duration', '')}, "
            f"ratings: {metadata.get('ratings', '')}, "
            f"synopsis: {as_text(metadata.get('synopsis', ''))}, "
            f"genre: {as_text(metadata.get('genre', ''))}"
        )

    def get_HolidayInCalendar(self, scheduleDate: str):
        """Return the holiday event for the month of *scheduleDate*.

        Args:
            scheduleDate: schedule date, expected as ``YYYY-MM-DD``.

        Returns:
            ``"christmas"`` or ``"new year"`` when a holiday description of
            that month mentions it, otherwise ``""`` (also when the API does
            not answer with HTTP 200).

        Raises:
            ValueError: if *scheduleDate* cannot be parsed as a date.
        """
        schedule_date = pd.to_datetime(scheduleDate, errors="coerce")
        # pd.isna covers both NaT and a None result, so a missing date raises
        # the documented ValueError instead of failing later on attribute access.
        if pd.isna(schedule_date):
            raise ValueError("Invalid date format. Please use YYYY-MM-DD.")

        # Let requests build and URL-encode the query string.
        response = requests.get(
            self._HOLIDAY_API_URL,
            params={
                "api_key": self.calendarKey,
                "country": "US",
                "year": schedule_date.year,
                "month": schedule_date.month,
            },
            timeout=self._HOLIDAY_API_TIMEOUT,
        )
        if response.status_code != 200:
            return ""

        payload = response.json()
        body = payload.get("response") if isinstance(payload, dict) else None
        holidays = body.get("holidays") if isinstance(body, dict) else None
        descriptions = [holiday.get("description") for holiday in holidays or []]
        return self._detect_event(descriptions)

    # This is what will be coming from the front end to populate genre and
    # schedule date for astra filtering
    def buildDataToIncludeHolidayEvents(self, genre: str, scheduleDate: str):
        """Normalise the genre and resolve the holiday event of the date.

        Returns:
            A ``(genre, holiday_event)`` tuple: the genre stripped and
            lower-cased, and the result of ``get_HolidayInCalendar``.

        Raises:
            ValueError: if *genre* or *scheduleDate* is empty, or the date
                cannot be parsed.
        """
        if not genre or not scheduleDate:
            raise ValueError("Genre and schedule date are required.")
        genre_data = genre.strip().lower()
        holidayEvent = self.get_HolidayInCalendar(scheduleDate)
        return genre_data, holidayEvent

    def get_retrievers(self, user_genres: list, holiday_event: str = None):
        """Build the retrievers used to fetch programmes.

        Args:
            user_genres: genres a document's ``genre`` field must match.
            holiday_event: optional event name looked up in ``HOLIDAY_KEYWORDS``.

        Returns:
            ``(genre_retriever, holiday_retriever)``. The holiday retriever is
            ``None`` when no event is given or the event has no keywords.
        """
        astraConfig = self.config["astra_db"]
        astra_filter_genre = {"genre": {"$in": user_genres}}

        keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), []) if holiday_event else []
        if not keywords:
            # An event without keywords would yield an empty "$in" filter, so
            # fall back to a plain genre search.
            retriever = self.vector_store.as_retriever(
                search_kwargs={
                    "filter": astra_filter_genre,
                    "k": astraConfig["genreSearchWithoutEvent"]["k"],
                }
            )
            return retriever, None

        astra_filter_holiday = {"$or": [{"synopsis": {"$in": keywords}}]}
        retriever_holiday = self.vector_store.as_retriever(
            search_kwargs={
                "filter": astra_filter_holiday,
                "k": astraConfig["holidaySearch"]["k"],
            }
        )
        retriever_genre = self.vector_store.as_retriever(
            search_kwargs={
                "filter": astra_filter_genre,
                "k": astraConfig["genreSearchWithEvent"]["k"],
            }
        )
        return retriever_genre, retriever_holiday

    # This is the function to pull the relevant docs based on the genre and date
    def get_relevant_programmes(self, genre: str, scheduleDate: str) -> tuple[str, str]:
        """Retrieve relevant programmes from the vector store for a genre and date.

        Returns:
            A ``(docs, holiday_event)`` tuple: ``docs`` holds one formatted
            line per retrieved document, joined with newlines; holiday matches
            (if any) come before genre matches.

        Raises:
            ValueError: if the store or embeddings are not initialised, the
                inputs are invalid, or no document is found.
        """
        if not self.vector_store or not self.embeddings:
            raise ValueError("Vector store or embeddings not initialized.")

        genre, holidayEvents = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
        retriever_genre, retriever_holiday = self.get_retrievers([genre], holidayEvents)

        documents = []
        if retriever_holiday:
            documents.extend(retriever_holiday.invoke(holidayEvents))
        if retriever_genre:
            documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))
        if not documents:
            raise ValueError("No relevant documents found.")

        # Join all formatted strings with newline
        docs = "\n".join(self._format_programme(doc.metadata) for doc in documents)
        return docs, holidayEvents
# if __name__ == "__main__":
#     tool = RetrievalTool()
#     docs, holiday_event = tool.get_relevant_programmes("comedy", "2023-10-25")
#     print(docs)