# midterm/server/workflow.py
# Author: Nagesh Muralidhar (midterm-submission, commit b66dccc)
from typing import Dict, Any, List, Annotated, TypedDict, Union, Optional
from langgraph.graph import Graph, END
from agents import create_agents
from utils import save_transcript
import os
from dotenv import load_dotenv
import logging
# Configure logging
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
class AgentState(TypedDict):
    """Shared state threaded through the workflow nodes (extractor -> debate -> podcast)."""
    # Chat history; messages[-1]["content"] is the current user query.
    messages: List[Dict[str, Any]]
    # Node that should act next ("extractor", "debate", "podcast", or END).
    current_agent: str
    # Completed debate rounds; the debate phase ends once this reaches 2.
    debate_turns: int
    # Output of the extractor agent for the current query.
    extractor_data: Dict[str, Any]
    # Ordered {"speaker": ..., "content": ...} entries from the skeptic/believer agents.
    debate_history: List[Dict[str, Any]]
    # One supervisor summary ("content") appended per supervisor analysis pass.
    supervisor_notes: List[str]
    # "chunks" payload returned alongside each supervisor note (may be empty).
    supervisor_chunks: List[Dict[str, List[str]]]
    # Result of the podcast producer agent.
    final_podcast: Dict[str, Any]
    # Preferred debate speaker ("believer" or "skeptic"); any other value alternates speakers.
    agent_type: str
    # Optional retrieval context; may carry an "agent_chunks" list for the debate agents.
    context: Optional[Dict[str, Any]]
def _latest_entry_for(history: List[Dict[str, Any]], speaker: str) -> Dict[str, Any]:
    """Return the most recent debate message for *speaker*, or a 'Not started' stub."""
    for entry in reversed(history):
        if entry["speaker"] == speaker:
            return {"content": entry["content"]}
    return {"content": "Not started"}


def create_workflow(tavily_api_key: str):
    """Build and compile the extractor -> debate -> podcast workflow graph.

    Args:
        tavily_api_key: API key forwarded to ``create_agents`` for the
            search-backed extractor agent.

    Returns:
        The compiled LangGraph graph, ready for ``ainvoke``.

    Raises:
        Exception: re-raised (with the original cause chained) when any node fails.
    """
    # Initialize all agents
    agents = create_agents(tavily_api_key)
    # Create the graph
    workflow = Graph()

    async def run_extractor(state: AgentState) -> Dict[str, Any]:
        """Run the extractor on the latest user message, then seed the supervisor."""
        query = state["messages"][-1]["content"]
        logger.info(f"Extractor processing query: {query}")
        try:
            response = await agents["extractor"](query)
            logger.info(f"Extractor response: {response}")
            state["extractor_data"] = response
            # Initial supervisor pass: the debate has not produced anything yet.
            supervisor_analysis = await agents["supervisor"]({
                "extractor": response,
                "skeptic": {"content": "Not started"},
                "believer": {"content": "Not started"}
            })
            logger.info(f"Initial supervisor analysis: {supervisor_analysis}")
            state["supervisor_notes"].append(supervisor_analysis["content"])
            state["supervisor_chunks"].append(supervisor_analysis.get("chunks", {}))
            # Move to debate phase
            state["current_agent"] = "debate"
            return state
        except Exception as e:
            logger.exception("Error in extractor")
            # Preserve the original traceback for upstream handlers.
            raise Exception(f"Error in extractor: {str(e)}") from e

    async def run_debate(state: AgentState) -> Dict[str, Any]:
        """Run one debate round (both speakers on turn 0, one speaker afterwards)."""
        logger.info(f"Debate turn {state['debate_turns']}")
        try:
            # Optional retrieval context shared by both speakers.
            context = state.get("context", {})
            agent_chunks = context.get("agent_chunks", []) if context else []
            if state["debate_turns"] == 0:
                # First turn: both agents respond to the extractor output.
                logger.info("Starting first debate turn")
                context_input = {
                    "content": state["extractor_data"]["content"],
                    "chunks": agent_chunks
                }
                skeptic_response = await agents["skeptic"](context_input)
                believer_response = await agents["believer"](context_input)
                state["debate_history"].extend([
                    {"speaker": "skeptic", "content": skeptic_response["content"]},
                    {"speaker": "believer", "content": believer_response["content"]}
                ])
                logger.info(f"First turn responses added: {state['debate_history'][-2:]}")
            else:
                # Pin the speaker when a specific agent_type was requested,
                # otherwise alternate with whoever spoke last.
                if state["agent_type"] in ("believer", "skeptic"):
                    current_speaker = state["agent_type"]
                else:
                    last_speaker = state["debate_history"][-1]["speaker"]
                    current_speaker = "believer" if last_speaker == "skeptic" else "skeptic"
                logger.info(f"Processing response for {current_speaker}")
                context_input = {
                    "content": state["debate_history"][-1]["content"],
                    "chunks": agent_chunks
                }
                response = await agents[current_speaker](context_input)
                state["debate_history"].append({
                    "speaker": current_speaker,
                    "content": response["content"]
                })
                logger.info(f"Added response: {state['debate_history'][-1]}")
            # Supervisor review of the round. NOTE: fixed a label swap here —
            # the original passed debate_history[-1] as "skeptic" and [-2] as
            # "believer" regardless of who actually spoke; we now look up the
            # latest message per speaker.
            supervisor_analysis = await agents["supervisor"]({
                "extractor": state["extractor_data"],
                "skeptic": _latest_entry_for(state["debate_history"], "skeptic"),
                "believer": _latest_entry_for(state["debate_history"], "believer")
            })
            logger.info(f"Supervisor analysis: {supervisor_analysis}")
            state["supervisor_notes"].append(supervisor_analysis["content"])
            state["supervisor_chunks"].append(supervisor_analysis.get("chunks", {}))
            state["debate_turns"] += 1
            logger.info(f"Debate turn {state['debate_turns']} completed")
            # End the workflow after 2 debate turns
            if state["debate_turns"] >= 2:
                state["current_agent"] = "podcast"
                logger.info("Moving to podcast production")
            return state
        except Exception as e:
            logger.exception("Error in debate")
            raise Exception(f"Error in debate: {str(e)}") from e

    async def run_podcast_producer(state: AgentState) -> Dict[str, Any]:
        """Turn the finished debate into a podcast script and persist the transcript."""
        logger.info("Starting podcast production")
        try:
            podcast_result = await agents["podcast_producer"](
                state["debate_history"],
                state["supervisor_notes"],
                state["messages"][-1]["content"],  # Pass the original user query
                state["supervisor_chunks"],
                {}  # Empty quadrant analysis since we removed storage manager
            )
            logger.info(f"Podcast production result: {podcast_result}")
            # Save transcript to JSON file
            save_transcript(
                podcast_script=podcast_result["content"],
                user_query=state["messages"][-1]["content"]
            )
            state["final_podcast"] = podcast_result
            # End the workflow
            state["current_agent"] = END
            return state
        except Exception as e:
            logger.exception("Error in podcast production")
            raise Exception(f"Error in podcast production: {str(e)}") from e

    # Wire up the graph: extractor -> debate (looping) -> podcast -> END.
    workflow.add_node("extractor", run_extractor)
    workflow.add_node("debate", run_debate)
    workflow.add_node("podcast", run_podcast_producer)
    workflow.set_entry_point("extractor")
    workflow.add_edge("extractor", "debate")
    # Debate loops back onto itself until two turns are done.
    workflow.add_conditional_edges(
        "debate",
        lambda x: "podcast" if x["debate_turns"] >= 2 else "debate"
    )
    workflow.add_edge("podcast", END)
    return workflow.compile()
async def run_workflow(
    graph: Graph,
    query: str,
    agent_type: str = "believer",
    context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Drive a compiled workflow graph for *query* and return its key outputs.

    Args:
        graph: Compiled graph exposing ``ainvoke``.
        query: The user question to process.
        agent_type: Preferred debate speaker ("believer" by default).
        context: Optional retrieval context forwarded into the state.

    Returns:
        Dict with the debate history, supervisor notes/chunks, extractor
        output, and final podcast result from the finished run.
    """
    # Fresh state: the query becomes the sole user message, counters at zero.
    state: Dict[str, Any] = {
        "messages": [{"role": "user", "content": query}],
        "current_agent": "extractor",
        "debate_turns": 0,
        "extractor_data": {},
        "debate_history": [],
        "supervisor_notes": [],
        "supervisor_chunks": [],
        "final_podcast": {},
        "agent_type": agent_type,
        "context": context,
    }
    final_state = await graph.ainvoke(state)
    # Surface only the outputs callers consume, not the whole internal state.
    wanted = (
        "debate_history",
        "supervisor_notes",
        "supervisor_chunks",
        "extractor_data",
        "final_podcast",
    )
    return {key: final_state[key] for key in wanted}