File size: 5,784 Bytes
b3ecf94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from langchain_astradb import AstraDBVectorStore
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain.tools.retriever import create_retriever_tool
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
import requests
import yaml
# Synopsis keywords to filter on for each holiday event, keyed by the
# lower-case event name.
HOLIDAY_KEYWORDS = {
    "christmas": [
        "christmas",
        "santa",
        "carol",
        "holiday",
    ],
}
class RetrievalTool:
    """Retrieve programme documents from the Astra DB vector store.

    Programmes are selected by genre and, when the schedule date falls in a
    month with a recognised holiday, by holiday keywords as well.
    """

    # Calendarific holidays endpoint and the number of seconds to wait for it,
    # so a stalled API call cannot hang the caller indefinitely.
    _CALENDAR_URL = "https://calendarific.com/api/v2/holidays"
    _CALENDAR_TIMEOUT = 10

    # Ordered (lower-case needle, event name) pairs; the first needle found in
    # any holiday description wins, so "christmas" outranks "new year".
    _HOLIDAY_EVENTS = (
        ("christmas", "christmas"),
        ("new year", "new year"),
    )

    def __init__(self):
        """Build the embeddings, the vector store and load ``config.yaml``.

        Raises:
            KeyError: If any of ``ASTRA_DB_API_ENDPOINT``,
                ``ASTRA_DB_APPLICATION_TOKEN``, ``ASTRA_DB_NAMESPACE`` or
                ``CALENDARIFIC_API_KEY`` is missing from the environment.
        """
        # Alternative (hosted inference endpoint), kept for reference:
        # self.embeddings = HuggingFaceEndpointEmbeddings(
        #     model="sentence-transformers/all-MiniLM-L6-v2",
        #     task="feature-extraction",
        #     huggingfacehub_api_token=os.environ["HF_TOKEN"])
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vector_store = AstraDBVectorStore(
            collection_name="program_astra",
            embedding=self.embeddings,
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            namespace=os.environ["ASTRA_DB_NAMESPACE"],
        )
        self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]
        # Use a context manager so the config file handle is always closed.
        with open("config.yaml", encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    @classmethod
    def _detect_holiday_event(cls, holidays) -> str:
        """Return the first known event mentioned in the holiday descriptions.

        Matching is case-insensitive; entries with a missing or ``None``
        description are ignored. Returns ``""`` when nothing matches.
        """
        descriptions = [
            str(holiday.get("description") or "").lower()
            for holiday in (holidays or [])
        ]
        for needle, event_name in cls._HOLIDAY_EVENTS:
            if any(needle in text for text in descriptions):
                return event_name
        return ""

    @staticmethod
    def _format_programme(metadata) -> str:
        """Render one document's metadata as a single descriptive line."""

        def as_text(value):
            # List-valued fields are flattened into a space-separated string.
            if isinstance(value, list):
                return " ".join(map(str, value))
            return str(value)

        title = metadata.get("programme_title", "")
        duration = metadata.get("duration", "")
        rating = metadata.get("ratings", "")
        synopsis = as_text(metadata.get("synopsis", ""))
        genre_text = as_text(metadata.get("genre", ""))
        return (
            f"programme_title: {title}, duration: {duration}, ratings: {rating}, "
            f"synopsis: {synopsis}, genre: {genre_text}"
        )

    def get_HolidayInCalendar(self, scheduleDate: str) -> str:
        """Look up the holiday event for the month of ``scheduleDate``.

        Queries Calendarific for US holidays in the date's year and month.

        Args:
            scheduleDate: Date string parseable by ``pandas.to_datetime``.

        Returns:
            ``"christmas"``, ``"new year"``, or ``""`` when no known event is
            found or the API call does not succeed.

        Raises:
            ValueError: If ``scheduleDate`` cannot be parsed as a date.
        """
        schedule_date = pd.to_datetime(scheduleDate, errors="coerce")
        if pd.isna(schedule_date):
            raise ValueError("Invalid date format. Please use YYYY-MM-DD.")

        # Let requests encode the query string instead of interpolating it.
        params = {
            "api_key": self.calendarKey,
            "country": "US",
            "year": schedule_date.year,
            "month": schedule_date.month,
        }
        try:
            response = requests.get(
                self._CALENDAR_URL, params=params, timeout=self._CALENDAR_TIMEOUT
            )
        except requests.RequestException:
            # Treat an unreachable API like a non-200 answer: no holiday.
            return ""
        if response.status_code != 200:
            return ""
        try:
            holidays = response.json()["response"]["holidays"]
        except (ValueError, KeyError, TypeError):
            # Body was not JSON or did not have the expected structure.
            return ""
        return self._detect_holiday_event(holidays)

    # This is what will be coming from the front end to populate genre and
    # schedule date for astra filtering.
    def buildDataToIncludeHolidayEvents(self, genre: str, scheduleDate: str):
        """Normalise the genre and resolve the holiday event for the date.

        Returns:
            A ``(genre, holiday_event)`` tuple where ``genre`` is stripped and
            lower-cased and ``holiday_event`` is the result of
            ``get_HolidayInCalendar``.

        Raises:
            ValueError: If either argument is empty, or the date is invalid.
        """
        if not genre or not scheduleDate:
            raise ValueError("Genre and schedule date are required.")
        genre_data = genre.strip().lower()
        holidayEvent = self.get_HolidayInCalendar(scheduleDate)
        return genre_data, holidayEvent

    def get_retrievers(self, user_genres: list, holiday_event: str = None):
        """Build the retrievers for the given genres and optional holiday.

        Returns:
            ``(genre_retriever, holiday_retriever)``. The second element is
            ``None`` when there is no holiday event or no keywords are
            configured for it in ``HOLIDAY_KEYWORDS``.
        """
        astraConfig = self.config["astra_db"]
        astra_filter_genre = {"genre": {"$in": user_genres}}

        keywords = []
        if holiday_event:
            keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), [])

        if not keywords:
            # Without keywords a holiday retriever would filter on an empty
            # "$in" list, so fall back to the plain genre search.
            retriever = self.vector_store.as_retriever(
                search_kwargs={
                    "filter": astra_filter_genre,
                    "k": astraConfig["genreSearchWithoutEvent"]["k"],
                }
            )
            return retriever, None

        astra_filter_holiday = {"$or": [{"synopsis": {"$in": keywords}}]}
        retriever_holiday = self.vector_store.as_retriever(
            search_kwargs={
                "filter": astra_filter_holiday,
                "k": astraConfig["holidaySearch"]["k"],
            }
        )
        retriever_genre = self.vector_store.as_retriever(
            search_kwargs={
                "filter": astra_filter_genre,
                "k": astraConfig["genreSearchWithEvent"]["k"],
            }
        )
        return retriever_genre, retriever_holiday

    # This is the function to pull the relevant docs based on the genre and date.
    def get_relevant_programmes(self, genre: str, scheduleDate: str) -> "tuple[str, str]":
        """Retrieve programmes matching the genre and the date's holiday.

        Returns:
            ``(docs, holiday_event)`` where ``docs`` is a newline-separated
            string with one formatted line per distinct programme and
            ``holiday_event`` is the detected event name (``""`` if none).

        Raises:
            ValueError: If the store is not initialised, the inputs are
                invalid, or no documents are found.
        """
        if not self.vector_store or not self.embeddings:
            raise ValueError("Vector store or embeddings not initialized.")

        genre, holidayEvents = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
        retriever_genre, retriever_holiday = self.get_retrievers([genre], holidayEvents)

        documents = []
        if retriever_holiday is not None:
            documents.extend(retriever_holiday.invoke(holidayEvents))
        if retriever_genre is not None:
            documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))
        if not documents:
            raise ValueError("No relevant documents found.")

        formatted_entries = [self._format_programme(doc.metadata) for doc in documents]
        # A programme can be returned by both retrievers; keep the first
        # occurrence of each line and preserve the retrieval order.
        unique_entries = list(dict.fromkeys(formatted_entries))
        docs = "\n".join(unique_entries)
        return docs, holidayEvents
# if __name__ == "__main__":
#     tool = RetrievalTool()
#     tool2 = tool.get_relevant_programmes("comedy", "2023-10-25")
#     print(tool2)
|