In [21]:
import os
from typing import Dict, List, Optional, TypedDict, Union
from langgraph.graph import Graph
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel, Field
import json
from typing import Annotated
from langchain_google_genai import ChatGoogleGenerativeAI


In [16]:
GOOGLE_API_KEY="AIzaSyA8eIxHBqeBWEP1g3t8bpvLxNaH5Lquemo"
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

In [17]:
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest")

In [22]:
# Define our data model
class Data(BaseModel):
 name: str = Field(description="name")
 age: int = Field(description="age")
 hobby: Optional[List[str]] = Field(description="A list of hobbies.")

# Define state for the graph
class GraphState(TypedDict):
 messages: List[Union[HumanMessage, AIMessage]]
 collected_data: Dict
 current_field: Optional[str]

In [26]:
# Function to analyze messages and extract information
def extract_info(state: GraphState) -> Dict:
 # Get the last user message
 last_message = state["messages"][-2].content if len(state["messages"]) > 1 else ""
 
 # Prepare prompt for information extraction
 extraction_prompt = f"""
 Extract the following information from the user's message: '{last_message}'
 If the information is present, format it as JSON matching these fields:
 - name: person's name
 - age: person's age (as integer)
 - hobby: list of hobbies
 
 Only include fields that are clearly mentioned. Return 'null' for missing fields.
 """
 
 # Extract information using LLM
 response = llm.invoke([HumanMessage(content=extraction_prompt)])
 try:
 extracted_info = json.loads(response.content)
 # Update collected data with any new information
 for key, value in extracted_info.items():
 if value is not None:
 state["collected_data"][key] = value
 except:
 pass
 
 return {"messages": state["messages"], "collected_data": state["collected_data"], "current_field": state["current_field"]}

# Function to determine next field
def determine_next_field(state: Dict) -> Dict:
 required_fields = {"name", "age"}
 collected_fields = set(state["collected_data"].keys())
 missing_fields = required_fields - collected_fields
 
 if missing_fields:
 state["current_field"] = next(iter(missing_fields))
 else:
 state["current_field"] = "hobby" if "hobby" not in state["collected_data"] else None
 
 return state

# Function to generate response
def generate_response(state: Dict) -> Dict:
 if state["current_field"] is None:
 # All information collected
 data = Data(**state["collected_data"])
 response = f"Thank you! I've collected all the information:\n{data.model_dump_json(indent=2)}"
 else:
 # Ask for specific field
 field_descriptions = {
 "name": "your name",
 "age": "your age",
 "hobby": "any hobbies you have"
 }
 response = f"Could you please tell me {field_descriptions[state['current_field']]}?"
 
 state["messages"].append(AIMessage(content=response))
 return state

# Create the graph
def create_chat_graph() -> Graph:
 workflow = Graph()
 
 # Define the graph edges
 workflow.add_node("extract_info", extract_info)
 workflow.add_node("determine_next_field", determine_next_field)
 workflow.add_node("generate_response", generate_response)
 
 workflow.add_edge("extract_info", "determine_next_field")
 workflow.add_edge("determine_next_field", "generate_response")
 
 # Set the entry point
 workflow.set_entry_point("extract_info")
 workflow.set_finish_point("generate_response")
 
 return workflow

In [27]:
# Example usage
def run_chat():
 graph = create_chat_graph()
 app = graph.compile()
 
 # Initialize state
 state = {
 "messages": [AIMessage(content="Hi! I'd like to collect some information from you.")],
 "collected_data": {},
 "current_field": None
 }
 
 while True:
 # Print last message
 print("Assistant:", state["messages"][-1].content)
 
 # Get user input
 user_input = input("User: ")
 if user_input.lower() in ['quit', 'exit']:
 break
 
 # Update state with user message
 state["messages"].append(HumanMessage(content=user_input))
 
 # Run the graph and update state
 new_state = app.invoke(state)
 state = new_state["generate_response"] # Get the final state from the last node
 
 # Check if all information is collected
 if state["current_field"] is None and len(state["collected_data"]) >= 2: # At least name and age collected
 break

if __name__ == "__main__":
 run_chat()

Assistant: Hi! I'd like to collect some information from you.


KeyError: 'generate_response'