"""Retrieve programme documents from an Astra DB vector store by genre and holiday."""

import logging
import os
from typing import List, Optional, Tuple

import pandas as pd
import requests
import yaml
from langchain.tools.retriever import create_retriever_tool
from langchain_astradb import AstraDBVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpointEmbeddings

logger = logging.getLogger(__name__)

# Maps a detected holiday event name to the keywords used in the synopsis filter.
# NOTE(review): there is no "new year" entry, so that event yields an empty
# keyword list in get_retrievers — confirm whether that is intended.
HOLIDAY_KEYWORDS = {
    "christmas": ["christmas", "santa", "carol", "holiday"],
}

CALENDARIFIC_HOLIDAYS_URL = "https://calendarific.com/api/v2/holidays"
# Upper bound (seconds) for the holiday lookup so a stalled request cannot hang the caller.
CALENDARIFIC_TIMEOUT_SECONDS = 10


class RetrievalTool:
    """Build filtered retrievers over the ``program_astra`` collection and format their results."""

    def __init__(self):
        """Create the embeddings and vector store, and load the API key and ``config.yaml``.

        Raises:
            KeyError: If one of the required environment variables is not set.
            OSError: If ``config.yaml`` cannot be opened.
        """
        # Hosted-endpoint alternative kept for reference:
        # self.embeddings = HuggingFaceEndpointEmbeddings(
        #     model="sentence-transformers/all-MiniLM-L6-v2",
        #     task="feature-extraction",
        #     huggingfacehub_api_token=os.environ["HF_TOKEN"])
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vector_store = AstraDBVectorStore(
            collection_name="program_astra",
            embedding=self.embeddings,
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            namespace=os.environ["ASTRA_DB_NAMESPACE"],
        )
        self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]
        # Use a context manager so the config file handle is closed deterministically.
        with open("config.yaml", encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    def get_HolidayInCalendar(self, scheduleDate: str, country: str = "US") -> str:
        """Return the holiday event found in the month of ``scheduleDate``.

        Queries the Calendarific holidays endpoint for the given country, year
        and month, then scans the returned holiday descriptions.

        Args:
            scheduleDate: Date string parseable by ``pandas.to_datetime``.
            country: Country code sent to the API. Defaults to ``"US"``.

        Returns:
            ``"christmas"``, ``"new year"``, or ``""`` when neither is found or
            the lookup fails (non-200 status, network error, malformed body).

        Raises:
            ValueError: If ``scheduleDate`` cannot be parsed as a date.
        """
        parsed = pd.to_datetime(scheduleDate, errors='coerce')
        # pd.isna covers both NaT (unparseable string) and None (None input).
        if pd.isna(parsed):
            raise ValueError("Invalid date format. Please use YYYY-MM-DD.")
        schedule_date = parsed.date()

        query = {
            "api_key": self.calendarKey,
            "country": country,
            "year": schedule_date.year,
            "month": schedule_date.month,
        }
        try:
            response = requests.get(
                CALENDARIFIC_HOLIDAYS_URL,
                params=query,
                timeout=CALENDARIFIC_TIMEOUT_SECONDS,
            )
            if response.status_code != 200:
                return ""
            payload = response.json()
        except (requests.RequestException, ValueError) as err:
            # Treat transport/decoding failures like a non-200 reply: no event.
            logger.warning("Holiday lookup failed: %s", err)
            return ""

        body = payload.get("response") if isinstance(payload, dict) else None
        holidays = body.get("holidays", []) if isinstance(body, dict) else []
        # A missing or null description is treated as empty text.
        descriptions = [
            (holiday.get("description") or "") if isinstance(holiday, dict) else ""
            for holiday in holidays
        ]

        # NOTE(review): both matches are case-sensitive; "New year" will not
        # match text written as "New Year" — confirm against real API output.
        if any("Christmas" in text for text in descriptions):
            return "christmas"
        if any("New year" in text for text in descriptions):
            return "new year"
        return ""

    # This is what will be coming from the front end to populate genre and
    # schedule date for astra filtering
    def buildDataToIncludeHolidayEvents(self, genre: str, scheduleDate: str) -> Tuple[str, str]:
        """Normalise ``genre`` and look up the holiday event for ``scheduleDate``.

        Args:
            genre: Genre name; it is stripped and lower-cased.
            scheduleDate: Date string passed to ``get_HolidayInCalendar``.

        Returns:
            A ``(normalised_genre, holiday_event)`` tuple.

        Raises:
            ValueError: If either argument is empty, or the date is invalid.
        """
        if not genre or not scheduleDate:
            raise ValueError("Genre and schedule date are required.")
        genre_data = genre.strip().lower()
        holiday_event = self.get_HolidayInCalendar(scheduleDate)
        return genre_data, holiday_event

    def get_retrievers(self, user_genres: List[str], holiday_event: Optional[str] = None):
        """Build the genre retriever and, when an event is given, a holiday retriever.

        The ``k`` values are read from the ``astra_db`` section of the loaded config.

        Args:
            user_genres: Genres used in the ``genre`` ``$in`` filter.
            holiday_event: Event name looked up (lower-cased) in
                ``HOLIDAY_KEYWORDS``; a falsy value disables the holiday retriever.

        Returns:
            ``(retriever_genre, retriever_holiday)``; the second item is
            ``None`` when no holiday event is given.
        """
        astra_config = self.config["astra_db"]
        genre_filter = {"genre": {"$in": user_genres}}

        if not holiday_event:
            retriever = self.vector_store.as_retriever(
                search_kwargs={
                    "filter": genre_filter,
                    "k": astra_config["genreSearchWithoutEvent"]["k"],
                }
            )
            return retriever, None

        keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), [])
        holiday_filter = {"$or": [{"synopsis": {"$in": keywords}}]}
        retriever_holiday = self.vector_store.as_retriever(
            search_kwargs={
                "filter": holiday_filter,
                "k": astra_config["holidaySearch"]["k"],
            }
        )
        retriever_genre = self.vector_store.as_retriever(
            search_kwargs={
                "filter": genre_filter,
                "k": astra_config["genreSearchWithEvent"]["k"],
            }
        )
        return retriever_genre, retriever_holiday

    @staticmethod
    def _as_text(value) -> str:
        """Join a list with spaces; convert any other value with ``str``."""
        return " ".join(value) if isinstance(value, list) else str(value)

    # This is the function to pull the relevant docs based on the genre and date
    def get_relevant_programmes(self, genre: str, scheduleDate: str) -> Tuple[str, str]:
        """Retrieve programmes for the genre and date and format them as text.

        Holiday documents (if an event is detected) come first, followed by
        genre documents. Each document's metadata becomes one line.

        Args:
            genre: Genre requested by the caller.
            scheduleDate: Date string used to detect a holiday event.

        Returns:
            A ``(docs, holiday_event)`` tuple where ``docs`` is the
            newline-joined formatted entries.

        Raises:
            ValueError: If the store/embeddings are unset, the inputs are
                invalid, or no documents are retrieved.
        """
        if not self.vector_store or not self.embeddings:
            raise ValueError("Vector store or embeddings not initialized.")

        genre, holiday_event = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
        retriever_genre, retriever_holiday = self.get_retrievers([genre], holiday_event)

        documents = []
        if retriever_holiday:
            documents.extend(retriever_holiday.invoke(holiday_event))
        if retriever_genre:
            documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))
        if not documents:
            raise ValueError("No relevant documents found.")

        program_df = pd.DataFrame([doc.metadata for doc in documents])
        formatted_entries = []
        for _, row in program_df.iterrows():
            title = row['programme_title']
            duration = row['duration']
            rating = row['ratings']
            synopsis = self._as_text(row['synopsis'])
            # Separate local name so the ``genre`` argument is not overwritten.
            row_genre = self._as_text(row['genre'])
            formatted_entries.append(
                f"programme_title: {title}, duration: {duration}, ratings: {rating}, synopsis: {synopsis}, genre: {row_genre}"
            )

        # Join all formatted strings with newline
        docs = "\n".join(formatted_entries)
        return docs, holiday_event


# if __name__ == "__main__":
#     tool = RetrievalTool()
#     tool2 = tool.get_relevant_programmes("comedy", "2023-10-25")
#     print(tool2)