Upload 4 files
Browse files- app.py +158 -0
- config.yaml +8 -0
- requirements.txt +10 -0
- tools.py +130 -0
app.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tools import RetrievalTool
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
from langchain_core.prompts import PromptTemplate
|
4 |
+
from langchain_groq import ChatGroq
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
from fastapi import FastAPI, HTTPException
|
7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
8 |
+
from fastapi.responses import JSONResponse
|
9 |
+
import pandas as pd
|
10 |
+
import uvicorn
|
11 |
+
import re
|
12 |
+
import os
|
13 |
+
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
app = FastAPI()
|
17 |
+
|
18 |
+
app.add_middleware(
|
19 |
+
CORSMiddleware,
|
20 |
+
allow_origins=["*"],
|
21 |
+
allow_credentials=True,
|
22 |
+
allow_methods=["*"],
|
23 |
+
allow_headers=["*"],
|
24 |
+
)
|
25 |
+
|
26 |
+
os.environ["ASTRA_DB_API_ENDPOINT"] = os.getenv("ASTRA_DB_API_ENDPOINT")
|
27 |
+
os.environ["ASTRA_DB_APPLICATION_TOKEN"] = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
|
28 |
+
os.environ["ASTRA_DB_NAMESPACE"] = os.getenv("ASTRA_DB_NAMESPACE")
|
29 |
+
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
|
30 |
+
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
|
31 |
+
os.environ["CALENDARIFIC_API_KEY"] = os.getenv("CALENDARIFIC_API_KEY")
|
32 |
+
|
33 |
+
retrieval_tool = RetrievalTool()
|
34 |
+
|
35 |
+
class ScheduleRecommendationModel(BaseModel):
|
36 |
+
programme_schedule: str = Field(description="The entire list of recommended programs..")
|
37 |
+
reasoning: str = Field(description="The reasoning behind the recommendation.")
|
38 |
+
|
39 |
+
template = """
|
40 |
+
You are a smart TV schedule assistant.
|
41 |
+
|
42 |
+
Your job is to generate a clean, formatted program schedule for a specific day.
|
43 |
+
|
44 |
+
Constraints:
|
45 |
+
- Do NOT include any explanation, notes, or Markdown.
|
46 |
+
- Do NOT repeat any program on the same day.
|
47 |
+
- Format must be exactly: HH:MM - HH:MM : ProgramName
|
48 |
+
- Use the provided channel start time as the beginning of the schedule.
|
49 |
+
- Prime time is from 18:00 to 22:00 — prioritize the highest-rated programs here.
|
50 |
+
- If the date is a holiday (e.g., Christmas), ensure 2 to 3 holiday-themed programs (based on keywords like "Christmas", "Santa", or "Carol" in the synopsis) are included in the schedule.
|
51 |
+
- If it's a weekend, favor family-friendly or entertainment-heavy content.
|
52 |
+
- If it's a weekday, prefer shorter or lighter content during the day and prioritize core genre in prime time.
|
53 |
+
- Do not schedule past 23:59.
|
54 |
+
|
55 |
+
Inputs:
|
56 |
+
- Genre: {genre}
|
57 |
+
- DayType: {day_type} # Either "weekday" or "weekend"
|
58 |
+
- Holiday: {holiday_event} # Either "Christmas", "New year" or None
|
59 |
+
- Start Time: {start_time}
|
60 |
+
- Available Programs:
|
61 |
+
{program_list}
|
62 |
+
|
63 |
+
Now generate the full day schedule starting from {start_time} using the above constraints.
|
64 |
+
"""
|
65 |
+
summary_template = """
|
66 |
+
You are a smart TV reasoning summary assistant.
|
67 |
+
|
68 |
+
Your task is to clearly explain the thought process behind a given TV schedule recommendation.
|
69 |
+
|
70 |
+
The summary should help the user understand why specific programs were selected, why they appear at certain times, and how the genre, ratings, time of day, day type (weekday/weekend), and special events (e.g., holidays like Christmas) influenced the schedule.
|
71 |
+
|
72 |
+
✳️ Instructions:
|
73 |
+
|
74 |
+
Do not add any information that is not already present in the reasoning.Do not hallucinate or make assumptions.
|
75 |
+
|
76 |
+
The summary must reflect the actual reasoning provided by the model.
|
77 |
+
|
78 |
+
Write in a natural, human-readable tone, suitable for a user reading a TV planner explanation.
|
79 |
+
|
80 |
+
Keep it concise but detailed enough to convey scheduling logic (approx. 8-10 lines).
|
81 |
+
|
82 |
+
Highlight how prime-time slots (18:00–22:00) were used for high-rated programs.
|
83 |
+
|
84 |
+
If applicable, explain how holiday content or weekend scheduling influenced the selection.
|
85 |
+
Use the reasoning provided to you and summarize it in a clear and concise manner.
|
86 |
+
{reasoning}
|
87 |
+
"""
|
88 |
+
prompt = PromptTemplate.from_template(template)
|
89 |
+
summary_prompt = PromptTemplate.from_template(summary_template)
|
90 |
+
llm = ChatGroq(model_name = "deepseek-r1-distill-llama-70b", api_key = os.environ["GROQ_API_KEY"])
|
91 |
+
summary_llm = ChatGroq(model_name = "gemma2-9b-it", api_key = os.environ["GROQ_API_KEY"])
|
92 |
+
|
93 |
+
chain = prompt | llm
|
94 |
+
summary_chain = summary_prompt | summary_llm
|
95 |
+
|
96 |
+
def get_dynamic_schedule(program_df:str, genre:str, start_time:str, day_type:str, holiday_event:str):
|
97 |
+
try:
|
98 |
+
response = chain.invoke({"program_list": program_df,
|
99 |
+
"genre": genre,
|
100 |
+
"day_type": day_type,
|
101 |
+
"holiday_event": holiday_event,
|
102 |
+
"start_time": start_time})
|
103 |
+
|
104 |
+
text_data = response.content
|
105 |
+
think_match = re.search(r'<think>(.*?)</think>', text_data, re.DOTALL)
|
106 |
+
if think_match:
|
107 |
+
reasoning = think_match.group(1).strip()
|
108 |
+
reasoning_answer = summarize_reasoning(reasoning)
|
109 |
+
final_answer = text_data.split("</think>")[-1].strip()
|
110 |
+
return ScheduleRecommendationModel(programme_schedule=final_answer, reasoning=reasoning_answer)
|
111 |
+
|
112 |
+
# if text_data and "</think>" in text_data:
|
113 |
+
# result = re.split(r'</think>', text_data, maxsplit=1)[-1].strip()
|
114 |
+
# return result
|
115 |
+
|
116 |
+
return ScheduleRecommendationModel(programme_schedule=response, reasoning="Error while generating reasoning.")
|
117 |
+
|
118 |
+
except Exception as e:
|
119 |
+
return f"Error: {str(e)}"
|
120 |
+
|
121 |
+
def get_weekday_or_weekend(date:str):
|
122 |
+
try:
|
123 |
+
schedule_date = pd.to_datetime(date)
|
124 |
+
if schedule_date.weekday() < 5: # Monday to Friday
|
125 |
+
return "weekday"
|
126 |
+
else: # Saturday and Sunday
|
127 |
+
return "weekend"
|
128 |
+
except ValueError:
|
129 |
+
raise ValueError("Invalid date format. Please use YYYY-MM-DD.")
|
130 |
+
|
131 |
+
def get_schedule_recommendation(genre:str, date:str, start_time:str):
|
132 |
+
program_list, holidayEvent = retrieval_tool.get_relevant_programmes(genre, date)
|
133 |
+
|
134 |
+
day_of_week = get_weekday_or_weekend(date)
|
135 |
+
schedule_recommendation = get_dynamic_schedule(program_list, genre, start_time, day_of_week, holidayEvent)
|
136 |
+
print("Schedule Recommendation:", schedule_recommendation)
|
137 |
+
return schedule_recommendation
|
138 |
+
|
139 |
+
|
140 |
+
def summarize_reasoning(reasoning:str):
|
141 |
+
if reasoning:
|
142 |
+
response = summary_chain.invoke({"reasoning": reasoning})
|
143 |
+
return response.content
|
144 |
+
return "Error while generating reasoning."
|
145 |
+
|
146 |
+
@app.post("/api/v1/getScheduleRecommendation/")
|
147 |
+
async def extract_details(genre:str, date:str, start_time:str):
|
148 |
+
try:
|
149 |
+
return get_schedule_recommendation(genre, date, start_time)
|
150 |
+
except HTTPException as e:
|
151 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
155 |
+
|
156 |
+
# if __name__ == "__main__":
|
157 |
+
# get_schedule_recommendation('comedy', '2023-12-25', '09:00')
|
158 |
+
|
config.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
astra_db:
|
2 |
+
genreSearchWithEvent:
|
3 |
+
k : 35
|
4 |
+
holidaySearch:
|
5 |
+
k : 5
|
6 |
+
genreSearchWithoutEvent:
|
7 |
+
k : 40
|
8 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain
|
2 |
+
langchain-community
|
3 |
+
langchain-core
|
4 |
+
langchain-groq
|
5 |
+
langchain-astradb
|
6 |
+
langchain-huggingface
|
7 |
+
python-dotenv
|
8 |
+
pandas
|
9 |
+
fastapi
|
10 |
+
uvicorn
|
tools.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_astradb import AstraDBVectorStore
|
2 |
+
from langchain_huggingface import HuggingFaceEndpointEmbeddings
|
3 |
+
from langchain.tools.retriever import create_retriever_tool
|
4 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
5 |
+
import os
|
6 |
+
import pandas as pd
|
7 |
+
import requests
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
HOLIDAY_KEYWORDS ={
|
11 |
+
"christmas": ["christmas", "santa", "carol", "holiday"]}
|
12 |
+
|
13 |
+
class RetrievalTool:
|
14 |
+
def __init__(self):
|
15 |
+
# self.embeddings = HuggingFaceEndpointEmbeddings(
|
16 |
+
# model= "sentence-transformers/all-MiniLM-L6-v2",
|
17 |
+
# task="feature-extraction",
|
18 |
+
# huggingfacehub_api_token= os.environ["HF_TOKEN"])
|
19 |
+
|
20 |
+
self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
21 |
+
|
22 |
+
self.vector_store = AstraDBVectorStore(collection_name="program_astra",
|
23 |
+
embedding=self.embeddings,
|
24 |
+
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
25 |
+
token= os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
26 |
+
namespace= os.environ["ASTRA_DB_NAMESPACE"])
|
27 |
+
|
28 |
+
self.calendarKey = os.environ["CALENDARIFIC_API_KEY"]
|
29 |
+
self.config = yaml.safe_load(open("config.yaml"))
|
30 |
+
|
31 |
+
def get_HolidayInCalendar(self,scheduleDate:str):
|
32 |
+
schedule_date = pd.to_datetime(scheduleDate, errors='coerce')
|
33 |
+
if schedule_date is pd.NaT:
|
34 |
+
raise ValueError("Invalid date format. Please use YYYY-MM-DD.")
|
35 |
+
schedule_date = schedule_date.date()
|
36 |
+
|
37 |
+
year = schedule_date.year
|
38 |
+
month = schedule_date.month
|
39 |
+
holidaylist = []
|
40 |
+
|
41 |
+
endpoint_url = f"https://calendarific.com/api/v2/holidays?api_key={self.calendarKey}&country='US'&year={year}&month={month}"
|
42 |
+
response = requests.get(endpoint_url)
|
43 |
+
eventName = ""
|
44 |
+
if response.status_code == 200:
|
45 |
+
holidays = response.json()['response']['holidays']
|
46 |
+
holidaylist = list(map(lambda x: x['description'],holidays))
|
47 |
+
|
48 |
+
if any("Christmas" in sentence for sentence in holidaylist):
|
49 |
+
eventName = "christmas"
|
50 |
+
elif any("New year" in sentence for sentence in holidaylist):
|
51 |
+
eventName = "new year"
|
52 |
+
return eventName
|
53 |
+
else:
|
54 |
+
return ""
|
55 |
+
|
56 |
+
#This is what will be coming from the front end to populate genre and schedule date for astra filtering
|
57 |
+
def buildDataToIncludeHolidayEvents(self, genre:str, scheduleDate:str):
|
58 |
+
if not genre or not scheduleDate:
|
59 |
+
raise ValueError("Genre and schedule date are required.")
|
60 |
+
genre_data = genre.strip().lower()
|
61 |
+
|
62 |
+
holidayEvent = self.get_HolidayInCalendar(scheduleDate)
|
63 |
+
|
64 |
+
return genre_data, holidayEvent
|
65 |
+
|
66 |
+
|
67 |
+
def get_retrievers(self, user_genres: list, holiday_event: str = None):
|
68 |
+
astraConfig = self.config["astra_db"]
|
69 |
+
astra_filter_genre = {"genre": {"$in": user_genres}}
|
70 |
+
|
71 |
+
if holiday_event:
|
72 |
+
keywords = HOLIDAY_KEYWORDS.get(holiday_event.lower(), [])
|
73 |
+
|
74 |
+
astra_filter_holiday = {
|
75 |
+
"$or": [
|
76 |
+
{"synopsis": {"$in": keywords}}
|
77 |
+
]}
|
78 |
+
|
79 |
+
retriever_holiday = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_holiday, "k": astraConfig["holidaySearch"]["k"]})
|
80 |
+
retriever_genre = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_genre, "k": astraConfig["genreSearchWithEvent"]["k"]})
|
81 |
+
|
82 |
+
return retriever_genre,retriever_holiday
|
83 |
+
|
84 |
+
else:
|
85 |
+
retriever = self.vector_store.as_retriever(search_kwargs={"filter": astra_filter_genre, "k": astraConfig["genreSearchWithoutEvent"]["k"]})
|
86 |
+
return retriever, None
|
87 |
+
|
88 |
+
#This is the function to pull the relevant docs based on the genre and date
|
89 |
+
def get_relevant_programmes(self, genre: str, scheduleDate: str)-> pd.DataFrame:
|
90 |
+
"""Retrieves relevant documents from the vector store based on the genre and date."""
|
91 |
+
if not self.vector_store or not self.embeddings:
|
92 |
+
raise ValueError("Vector store or embeddings not initialized.")
|
93 |
+
|
94 |
+
genre, holidayEvents = self.buildDataToIncludeHolidayEvents(genre, scheduleDate)
|
95 |
+
|
96 |
+
retriever_genre, retriever_holiday = self.get_retrievers([genre], holidayEvents)
|
97 |
+
documents= []
|
98 |
+
|
99 |
+
if retriever_holiday:
|
100 |
+
documents.extend(retriever_holiday.invoke(holidayEvents))
|
101 |
+
|
102 |
+
if retriever_genre:
|
103 |
+
documents.extend(retriever_genre.invoke(f"{genre} genre based programs"))
|
104 |
+
|
105 |
+
if not documents:
|
106 |
+
raise ValueError("No relevant documents found.")
|
107 |
+
program_df = pd.DataFrame([doc.metadata for doc in documents])
|
108 |
+
|
109 |
+
formatted_entries = []
|
110 |
+
for _, row in program_df.iterrows():
|
111 |
+
title = row['programme_title']
|
112 |
+
duration = row['duration']
|
113 |
+
rating = row['ratings']
|
114 |
+
synopsis = " ".join(row['synopsis']) if isinstance(row['synopsis'], list) else str(row['synopsis'])
|
115 |
+
genre = " ".join(row['genre']) if isinstance(row['genre'], list) else str(row['genre'])
|
116 |
+
|
117 |
+
formatted_entries.append(
|
118 |
+
f"programme_title: {title}, duration: {duration}, ratings: {rating}, synopsis: {synopsis}, genre: {genre}"
|
119 |
+
)
|
120 |
+
|
121 |
+
# Join all formatted strings with newline
|
122 |
+
docs = "\n".join(formatted_entries)
|
123 |
+
return docs, holidayEvents
|
124 |
+
|
125 |
+
# if __name__ == "__main__":
|
126 |
+
# tool = RetrievalTool()
|
127 |
+
# tool2 = tool.get_relevant_documents("comedy", "2023-10-25")
|
128 |
+
# print(tool2)
|
129 |
+
|
130 |
+
|