File size: 5,784 Bytes
b3ecf94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from langchain_astradb import AstraDBVectorStore
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain.tools.retriever import create_retriever_tool
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
import requests
import yaml

# Maps a holiday event name (as produced by RetrievalTool) to the synopsis
# keywords used to filter the vector store for that event.
HOLIDAY_KEYWORDS = {
    "christmas": [
        "christmas",
        "santa",
        "carol",
        "holiday",
    ],
}



class RetrievalTool:
    """Retrieve programme metadata from Astra DB, filtered by genre and US holidays.

    Construction reads credentials from the environment variables
    ``ASTRA_DB_API_ENDPOINT``, ``ASTRA_DB_APPLICATION_TOKEN``,
    ``ASTRA_DB_NAMESPACE`` and ``CALENDARIFIC_API_KEY``, and retrieval
    settings (the ``astra_db`` section) from ``config.yaml`` in the current
    working directory.
    """

    # Seconds to wait for the Calendarific API; without a timeout
    # ``requests`` can block forever on an unresponsive server.
    CALENDAR_TIMEOUT_SECONDS = 10

    def __init__(self):
        # Embeddings are computed with a local sentence-transformers model.
        # HuggingFaceEndpointEmbeddings (imported at file top) is a hosted
        # alternative if HF_TOKEN is available.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        self.vector_store = AstraDBVectorStore(
            collection_name="program_astra",
            embedding=self.embeddings,
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            namespace=os.environ["ASTRA_DB_NAMESPACE"],
        )

        self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]
        # Use a context manager so the config file handle is closed promptly.
        with open("config.yaml", encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    @staticmethod
    def _event_from_descriptions(descriptions: list) -> str:
        """Map holiday descriptions to a known event name, or "" if none match.

        Matching is case-insensitive, so "New Year's Day" maps to
        ``"new year"``. Christmas takes precedence over New Year when both
        appear. Non-string entries are ignored.
        """
        lowered = [text.lower() for text in descriptions if isinstance(text, str)]
        if any("christmas" in text for text in lowered):
            return "christmas"
        if any("new year" in text for text in lowered):
            return "new year"
        return ""

    def get_HolidayInCalendar(self, scheduleDate: str) -> str:
        """Return the holiday event found in the month of *scheduleDate*.

        Queries the Calendarific API for US holidays in that month and maps
        their descriptions to an event name.

        Args:
            scheduleDate: Date string parseable by ``pd.to_datetime``
                (e.g. ``"2023-12-25"``).

        Returns:
            ``"christmas"``, ``"new year"``, or ``""``. The empty string is
            returned when no known holiday is found or when the lookup fails
            (network error, non-200 status, unexpected payload).

        Raises:
            ValueError: If *scheduleDate* cannot be parsed as a date.
        """
        schedule_date = pd.to_datetime(scheduleDate, errors="coerce")
        if pd.isna(schedule_date):
            raise ValueError("Invalid date format. Please use YYYY-MM-DD.")

        # Let requests encode the query string instead of hand-building the URL.
        params = {
            "api_key": self.calendarKey,
            "country": "US",
            "year": schedule_date.year,
            "month": schedule_date.month,
        }
        try:
            response = requests.get(
                "https://calendarific.com/api/v2/holidays",
                params=params,
                timeout=self.CALENDAR_TIMEOUT_SECONDS,
            )
        except requests.RequestException:
            # The holiday lookup is best-effort (a non-200 status already
            # yields ""), so a network failure falls back to "no event".
            return ""

        if response.status_code != 200:
            return ""

        try:
            holidays = response.json()["response"]["holidays"]
            descriptions = [holiday["description"] for holiday in holidays]
        except (ValueError, KeyError, TypeError):
            # Non-JSON body or unexpected payload shape: treat as "no event".
            return ""

        return self._event_from_descriptions(descriptions)

    # This is what will be coming from the front end to populate genre and
    # schedule date for Astra filtering.
    def buildDataToIncludeHolidayEvents(self, genre: str, scheduleDate: str):
        """Normalise the genre and resolve the holiday event for the date.

        Args:
            genre: Genre name; surrounding whitespace is stripped and the
                result is lower-cased.
            scheduleDate: Date string passed to :meth:`get_HolidayInCalendar`.

        Returns:
            Tuple ``(genre_data, holidayEvent)``.

        Raises:
            ValueError: If the genre is empty or whitespace-only, if
                *scheduleDate* is empty, or if the date cannot be parsed.
        """
        genre_data = genre.strip().lower() if genre else ""
        # Validate after stripping so "   " is rejected instead of becoming "".
        if not genre_data or not scheduleDate:
            raise ValueError("Genre and schedule date are required.")

        holidayEvent = self.get_HolidayInCalendar(scheduleDate)
        return genre_data, holidayEvent

    def get_retrievers(self, user_genres: list, holiday_event: str = None):
        """Build the genre retriever and, when applicable, a holiday retriever.

        Args:
            user_genres: Genre values matched against the ``genre`` metadata
                field with ``$in``.
            holiday_event: Optional event name looked up in
                ``HOLIDAY_KEYWORDS``.

        Returns:
            Tuple ``(retriever_genre, retriever_holiday)``.
            ``retriever_holiday`` is ``None`` when no event is given or the
            event has no configured keywords; in that case the genre
            retriever uses the ``genreSearchWithoutEvent`` ``k``.
        """
        astraConfig = self.config["astra_db"]
        astra_filter_genre = {"genre": {"$in": user_genres}}

        keywords = (
            HOLIDAY_KEYWORDS.get(holiday_event.lower(), []) if holiday_event else []
        )

        if not keywords:
            # No event, or an event (e.g. "new year") with no keywords. An
            # empty $in list cannot match any programme, so use the plain
            # genre search rather than a holiday filter that yields nothing.
            retriever = self.vector_store.as_retriever(
                search_kwargs={
                    "filter": astra_filter_genre,
                    "k": astraConfig["genreSearchWithoutEvent"]["k"],
                }
            )
            return retriever, None

        astra_filter_holiday = {"synopsis": {"$in": keywords}}
        retriever_holiday = self.vector_store.as_retriever(
            search_kwargs={
                "filter": astra_filter_holiday,
                "k": astraConfig["holidaySearch"]["k"],
            }
        )
        retriever_genre = self.vector_store.as_retriever(
            search_kwargs={
                "filter": astra_filter_genre,
                "k": astraConfig["genreSearchWithEvent"]["k"],
            }
        )
        return retriever_genre, retriever_holiday

    @staticmethod
    def _format_programmes(metadatas: list) -> str:
        """Render programme metadata dicts as newline-separated description lines.

        Each line has the form ``programme_title: ..., duration: ...,
        ratings: ..., synopsis: ..., genre: ...``. List-valued ``synopsis``
        and ``genre`` fields are joined with spaces. Identical lines are
        emitted once, in first-seen order; this happens when both the holiday
        and genre retrievers return the same programme.
        """
        program_df = pd.DataFrame(metadatas)

        def _as_text(value) -> str:
            return " ".join(value) if isinstance(value, list) else str(value)

        formatted_entries = [
            f"programme_title: {row['programme_title']}, "
            f"duration: {row['duration']}, "
            f"ratings: {row['ratings']}, "
            f"synopsis: {_as_text(row['synopsis'])}, "
            f"genre: {_as_text(row['genre'])}"
            for _, row in program_df.iterrows()
        ]
        # dict.fromkeys de-duplicates while preserving insertion order.
        return "\n".join(dict.fromkeys(formatted_entries))

    # This is the function to pull the relevant docs based on the genre and date.
    def get_relevant_programmes(self, genre: str, scheduleDate: str) -> tuple[str, str]:
        """Retrieve programmes for a genre (and holiday, if any) on a date.

        Args:
            genre: Requested genre (normalised to stripped lower case).
            scheduleDate: Schedule date used to detect a holiday event.

        Returns:
            Tuple ``(docs, holidayEvents)``. ``docs`` is a newline-joined
            string with one line per unique retrieved programme.
            ``holidayEvents`` is the detected event name ("" if none).

        Raises:
            ValueError: If the store is not initialised, inputs are invalid,
                or no documents are retrieved.
        """
        if not self.vector_store or not self.embeddings:
            raise ValueError("Vector store or embeddings not initialized.")

        genre, holidayEvents = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
        retriever_genre, retriever_holiday = self.get_retrievers([genre], holidayEvents)

        documents = []
        if retriever_holiday:
            documents.extend(retriever_holiday.invoke(holidayEvents))
        if retriever_genre:
            documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))

        if not documents:
            raise ValueError("No relevant documents found.")

        docs = self._format_programmes([doc.metadata for doc in documents])
        return docs, holidayEvents

# if __name__ == "__main__":
#     tool = RetrievalTool()
#     tool2 = tool.get_relevant_programmes("comedy", "2023-10-25")
#     print(tool2)