|
from langchain_astradb import AstraDBVectorStore |
|
from langchain_huggingface import HuggingFaceEndpointEmbeddings |
|
from langchain.tools.retriever import create_retriever_tool |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
import os |
|
import pandas as pd |
|
import requests |
|
import yaml |
|
|
|
# Synopsis keywords used to build the holiday filter, keyed by the lowercase
# event name (looked up via ``holiday_event.lower()`` in get_retrievers).
# NOTE(review): get_HolidayInCalendar can also return "new year", which has no
# entry here, so that event currently resolves to an empty keyword list.
HOLIDAY_KEYWORDS = {
    "christmas": [
        "christmas",
        "santa",
        "carol",
        "holiday",
    ],
}
|
|
|
|
|
|
|
class RetrievalTool:
    """Retrieve programme documents from an Astra DB vector store.

    Retrieval is filtered by genre and, when the schedule date falls in a month
    containing a recognised holiday, additionally by holiday keywords matched
    against the ``synopsis`` field.
    """

    # Seconds to wait for the holiday API; without a timeout ``requests`` can
    # block indefinitely on an unresponsive server.
    _HOLIDAY_API_TIMEOUT = 10

    def __init__(self):
        """Build the embeddings, vector store and configuration.

        Reads ``ASTRA_DB_API_ENDPOINT``, ``ASTRA_DB_APPLICATION_TOKEN``,
        ``ASTRA_DB_NAMESPACE`` and ``CALENDARIFIC_API_KEY`` from the
        environment and loads ``config.yaml`` from the working directory.

        Raises:
            KeyError: If a required environment variable is missing.
            OSError: If ``config.yaml`` cannot be opened.
        """
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

        self.vector_store = AstraDBVectorStore(
            collection_name="program_astra",
            embedding=self.embeddings,
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            namespace=os.environ["ASTRA_DB_NAMESPACE"],
        )

        self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]

        # Context manager so the file handle is closed deterministically.
        with open("config.yaml", encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    @staticmethod
    def _match_holiday_event(descriptions) -> str:
        """Map holiday description texts to a supported event name.

        Matching is case-insensitive so that e.g. "New Year's Day" is
        recognised. Christmas takes precedence over New Year when both occur.

        Args:
            descriptions: Iterable of description strings; ``None`` entries
                are treated as empty text.

        Returns:
            ``"christmas"``, ``"new year"`` or ``""`` when nothing matches.
        """
        lowered = [str(text or "").lower() for text in descriptions]
        if any("christmas" in text for text in lowered):
            return "christmas"
        if any("new year" in text for text in lowered):
            return "new year"
        return ""

    @staticmethod
    def _format_programme(metadata: dict) -> str:
        """Render one document's metadata as a single descriptive line.

        List values are joined with spaces; missing keys render as empty text.
        """
        def as_text(value) -> str:
            if isinstance(value, list):
                return " ".join(map(str, value))
            return str(value)

        title = metadata.get("programme_title", "")
        duration = metadata.get("duration", "")
        rating = metadata.get("ratings", "")
        synopsis = as_text(metadata.get("synopsis", ""))
        genre_text = as_text(metadata.get("genre", ""))
        return (
            f"programme_title: {title}, duration: {duration}, ratings: {rating}, "
            f"synopsis: {synopsis}, genre: {genre_text}"
        )

    def get_HolidayInCalendar(self, scheduleDate: str):
        """Return the holiday event found in the month of ``scheduleDate``.

        Queries the Calendarific holidays endpoint for US holidays in the
        date's year and month (the day is not sent) and scans the returned
        holiday descriptions.

        Args:
            scheduleDate: Date string parseable by ``pandas.to_datetime``.

        Returns:
            ``"christmas"``, ``"new year"``, or ``""`` when no supported
            holiday is found or the API does not answer with HTTP 200.

        Raises:
            ValueError: If ``scheduleDate`` cannot be parsed as a date.
        """
        schedule_date = pd.to_datetime(scheduleDate, errors="coerce")
        # pd.isna covers both NaT (unparseable text) and a None result.
        if pd.isna(schedule_date):
            raise ValueError("Invalid date format. Please use YYYY-MM-DD.")

        # Let requests build and URL-encode the query string instead of
        # interpolating the API key into the URL by hand.
        response = requests.get(
            "https://calendarific.com/api/v2/holidays",
            params={
                "api_key": self.calendarKey,
                "country": "US",
                "year": schedule_date.year,
                "month": schedule_date.month,
            },
            timeout=self._HOLIDAY_API_TIMEOUT,
        )
        if response.status_code != 200:
            # Best effort: an unavailable holiday service means "no holiday".
            return ""

        holidays = response.json()["response"]["holidays"]
        descriptions = [holiday["description"] for holiday in holidays]
        return self._match_holiday_event(descriptions)

    def buildDataToIncludeHolidayEvents(self, genre: str, scheduleDate: str):
        """Normalise the genre and resolve the holiday event for the date.

        Args:
            genre: Genre name; surrounding whitespace is stripped and the
                value is lower-cased.
            scheduleDate: Date string passed to ``get_HolidayInCalendar``.

        Returns:
            Tuple ``(normalised_genre, holiday_event)``.

        Raises:
            ValueError: If either argument is empty (a whitespace-only genre
                counts as empty) or the date cannot be parsed.
        """
        if not genre or not scheduleDate:
            raise ValueError("Genre and schedule date are required.")

        genre_data = genre.strip().lower()
        if not genre_data:
            # A blank genre would otherwise become an empty filter value.
            raise ValueError("Genre and schedule date are required.")

        holiday_event = self.get_HolidayInCalendar(scheduleDate)
        return genre_data, holiday_event

    def get_retrievers(self, user_genres: list, holiday_event: str = None):
        """Create the genre retriever and, if applicable, a holiday retriever.

        Result sizes (``k``) come from the ``astra_db`` section of the loaded
        configuration.

        Args:
            user_genres: Genres to match with an ``$in`` filter on ``genre``.
            holiday_event: Optional event name used to look up keywords in
                ``HOLIDAY_KEYWORDS`` (case-insensitive).

        Returns:
            Tuple ``(retriever_genre, retriever_holiday)``. The second item is
            ``None`` when no event is given or the event has no keywords.
        """
        astra_config = self.config["astra_db"]
        genre_filter = {"genre": {"$in": user_genres}}

        if not holiday_event:
            retriever = self.vector_store.as_retriever(
                search_kwargs={"filter": genre_filter, "k": astra_config["genreSearchWithoutEvent"]["k"]}
            )
            return retriever, None

        retriever_genre = self.vector_store.as_retriever(
            search_kwargs={"filter": genre_filter, "k": astra_config["genreSearchWithEvent"]["k"]}
        )

        keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), [])
        if not keywords:
            # An event without keywords would yield a filter on an empty
            # ``$in`` list; skip that retriever instead of issuing the query.
            # The genre retriever keeps the "with event" size, as before.
            return retriever_genre, None

        holiday_filter = {
            "$or": [
                {"synopsis": {"$in": keywords}}
            ]}
        retriever_holiday = self.vector_store.as_retriever(
            search_kwargs={"filter": holiday_filter, "k": astra_config["holidaySearch"]["k"]}
        )
        return retriever_genre, retriever_holiday

    def get_relevant_programmes(self, genre: str, scheduleDate: str) -> tuple[str, str]:
        """Retrieve programmes for a genre and date as formatted text.

        Holiday-themed documents (if any) are listed first, followed by the
        genre matches.

        Args:
            genre: Genre to search for.
            scheduleDate: Date used to detect a holiday event.

        Returns:
            Tuple ``(docs, holiday_event)`` where ``docs`` holds one line per
            retrieved document, joined with newlines.

        Raises:
            ValueError: If the store or embeddings are not initialised, the
                inputs are invalid, or no documents are found.
        """
        # getattr so a partially constructed instance reports the intended
        # error instead of an AttributeError.
        if not getattr(self, "vector_store", None) or not getattr(self, "embeddings", None):
            raise ValueError("Vector store or embeddings not initialized.")

        genre, holiday_event = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
        retriever_genre, retriever_holiday = self.get_retrievers([genre], holiday_event)

        documents = []
        if retriever_holiday is not None:
            documents.extend(retriever_holiday.invoke(holiday_event))
        if retriever_genre is not None:
            documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))

        if not documents:
            raise ValueError("No relevant documents found.")

        formatted_entries = [self._format_programme(doc.metadata) for doc in documents]
        return "\n".join(formatted_entries), holiday_event
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|