|
from langchain_astradb import AstraDBVectorStore
|
|
from langchain_huggingface import HuggingFaceEndpointEmbeddings
|
|
from langchain.tools.retriever import create_retriever_tool
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
import os
|
|
import pandas as pd
|
|
import requests
|
|
import yaml
|
|
|
|
# Maps a holiday event name to the keywords used when filtering programme
# synopses for that event. Only "christmas" is configured here.
HOLIDAY_KEYWORDS = {
    "christmas": [
        "christmas",
        "santa",
        "carol",
        "holiday",
    ],
}
|
|
|
|
class RetrievalTool:
    """Fetch programme documents from an Astra DB vector store.

    Results are filtered by genre and, when the Calendarific API reports a
    recognised holiday in the month of the schedule date, supplemented with
    holiday-themed programmes.
    """

    # Base endpoint only: query parameters are handed to ``requests`` so they
    # are URL-encoded instead of being spliced into the string by hand.
    _CALENDARIFIC_URL = "https://calendarific.com/api/v2/holidays"

    # Upper bound (seconds) for the holiday lookup so a stalled API call
    # cannot block the caller forever.
    _REQUEST_TIMEOUT_SECONDS = 10

    # Ordered (lower-case needle, event name) pairs; earlier pairs win.
    _HOLIDAY_EVENT_MARKERS = (
        ("christmas", "christmas"),
        ("new year", "new year"),
    )

    def __init__(self):
        """Build the embeddings, the vector store and load ``config.yaml``.

        Raises:
            KeyError: If one of the required environment variables
                (``ASTRA_DB_API_ENDPOINT``, ``ASTRA_DB_APPLICATION_TOKEN``,
                ``ASTRA_DB_NAMESPACE``, ``CALENDARIFIC_API_KEY``) is unset.
            OSError: If ``config.yaml`` cannot be opened.
        """
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vector_store = AstraDBVectorStore(
            collection_name="program_astra",
            embedding=self.embeddings,
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            namespace=os.environ["ASTRA_DB_NAMESPACE"],
        )
        self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]

        # Context manager so the file handle is closed once parsing is done.
        with open("config.yaml", encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    @classmethod
    def _detect_holiday_event(cls, holidays) -> str:
        """Return the first known event named in the holiday descriptions, else ""."""
        descriptions = [
            str(holiday.get("description") or "").lower()
            for holiday in holidays
            if isinstance(holiday, dict)
        ]
        # Case-insensitive substring match, checked in marker priority order.
        for needle, event_name in cls._HOLIDAY_EVENT_MARKERS:
            if any(needle in text for text in descriptions):
                return event_name
        return ""

    @staticmethod
    def _format_programme(record) -> str:
        """Render one programme metadata record as a single descriptive line."""
        synopsis = record["synopsis"]
        genre_value = record["genre"]
        # List-valued fields are flattened to a space-separated string.
        synopsis_text = " ".join(synopsis) if isinstance(synopsis, list) else str(synopsis)
        genre_text = " ".join(genre_value) if isinstance(genre_value, list) else str(genre_value)
        return (
            f"programme_title: {record['programme_title']}, duration: {record['duration']}, "
            f"ratings: {record['ratings']}, synopsis: {synopsis_text}, genre: {genre_text}"
        )

    def get_HolidayInCalendar(self, scheduleDate: str) -> str:
        """Return the holiday event found in the month of ``scheduleDate``.

        Queries Calendarific for US holidays in the year and month of the
        given date and scans the holiday descriptions.

        Args:
            scheduleDate: Date string parseable by ``pandas.to_datetime``.

        Returns:
            ``"christmas"`` or ``"new year"`` when a matching holiday is
            found; ``""`` when none matches or the API does not answer with
            HTTP 200.

        Raises:
            ValueError: If ``scheduleDate`` cannot be parsed as a date.
        """
        schedule_date = pd.to_datetime(scheduleDate, errors="coerce")
        # pd.isna covers both NaT (unparseable text) and None (None input).
        if pd.isna(schedule_date):
            raise ValueError("Invalid date format. Please use YYYY-MM-DD.")

        response = requests.get(
            self._CALENDARIFIC_URL,
            params={
                "api_key": self.calendarKey,
                "country": "US",
                "year": schedule_date.year,
                "month": schedule_date.month,
            },
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            return ""

        # Tolerate an unexpected payload shape instead of raising.
        body = response.json()
        payload = body.get("response") if isinstance(body, dict) else None
        holidays = payload.get("holidays") if isinstance(payload, dict) else None
        return self._detect_holiday_event(holidays or [])

    def buildDataToIncludeHolidayEvents(self, genre: str, scheduleDate: str):
        """Normalise the genre and look up the holiday event for the date.

        Args:
            genre: Genre name; it is stripped and lower-cased.
            scheduleDate: Date string passed to ``get_HolidayInCalendar``.

        Returns:
            A ``(normalised_genre, holiday_event)`` tuple.

        Raises:
            ValueError: If the genre is empty or whitespace-only, or the
                schedule date is empty.
        """
        genre_data = genre.strip().lower() if genre else ""
        if not genre_data or not scheduleDate:
            raise ValueError("Genre and schedule date are required.")

        holiday_event = self.get_HolidayInCalendar(scheduleDate)
        return genre_data, holiday_event

    def get_retrievers(self, user_genres: list, holiday_event: str = None):
        """Create the genre retriever and, when applicable, a holiday retriever.

        The number of results (``k``) for each retriever is read from the
        ``astra_db`` section of the loaded configuration.

        Args:
            user_genres: Genres a document's ``genre`` field must match.
            holiday_event: Optional holiday event name.

        Returns:
            ``(genre_retriever, holiday_retriever)``. ``holiday_retriever``
            is ``None`` when no holiday event is given or when no keywords
            are configured for it.
        """
        astra_config = self.config["astra_db"]
        genre_filter = {"genre": {"$in": user_genres}}

        if not holiday_event:
            retriever = self.vector_store.as_retriever(
                search_kwargs={
                    "filter": genre_filter,
                    "k": astra_config["genreSearchWithoutEvent"]["k"],
                }
            )
            return retriever, None

        retriever_genre = self.vector_store.as_retriever(
            search_kwargs={
                "filter": genre_filter,
                "k": astra_config["genreSearchWithEvent"]["k"],
            }
        )

        keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), [])
        if not keywords:
            # An empty "$in" list cannot select any document, so skip the
            # holiday retriever rather than issue a query that matches nothing.
            return retriever_genre, None

        holiday_filter = {"$or": [{"synopsis": {"$in": keywords}}]}
        retriever_holiday = self.vector_store.as_retriever(
            search_kwargs={
                "filter": holiday_filter,
                "k": astra_config["holidaySearch"]["k"],
            }
        )
        return retriever_genre, retriever_holiday

    def get_relevant_programmes(self, genre: str, scheduleDate: str) -> "tuple[str, str]":
        """Retrieve programmes for a genre and date as formatted text.

        Args:
            genre: Genre to search for.
            scheduleDate: Schedule date used for the holiday lookup.

        Returns:
            ``(docs, holiday_event)`` where ``docs`` is a newline-separated
            string with one line per programme and ``holiday_event`` is the
            detected event name (``""`` when there is none).

        Raises:
            ValueError: If the vector store or embeddings are falsy, the
                inputs are invalid, or no documents are retrieved.
        """
        if not self.vector_store or not self.embeddings:
            raise ValueError("Vector store or embeddings not initialized.")

        genre_key, holiday_event = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
        retriever_genre, retriever_holiday = self.get_retrievers([genre_key], holiday_event)

        documents = []
        if retriever_holiday is not None:
            documents.extend(retriever_holiday.invoke(holiday_event))
        if retriever_genre is not None:
            documents.extend(retriever_genre.invoke(f"{genre_key} genre based programs"))

        if not documents:
            raise ValueError("No relevant documents found.")

        # The DataFrame aligns the metadata dicts on a common set of columns.
        program_df = pd.DataFrame([doc.metadata for doc in documents])
        formatted_entries = [
            self._format_programme(record) for record in program_df.to_dict("records")
        ]

        # The same programme can come back from both retrievers; keep the
        # first occurrence of each line and preserve retrieval order.
        unique_entries = list(dict.fromkeys(formatted_entries))
        return "\n".join(unique_entries), holiday_event
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|