{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "from typing import Dict, List, Optional, TypedDict, Union\n",
    "from langgraph.graph import Graph\n",
    "from langchain_core.messages import HumanMessage, AIMessage\n",
    "from pydantic import BaseModel, Field\n",
    "import json\n",
    "from typing import Annotated\n",
    "from langchain_google_genai import ChatGoogleGenerativeAI\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Google API key from the environment, prompting securely if absent.\n",
    "# SECURITY: never hardcode secrets in a notebook — a previous revision of this\n",
    "# cell contained a literal key; that key must be revoked/rotated.\n",
    "if \"GOOGLE_API_KEY\" not in os.environ:\n",
    "    os.environ[\"GOOGLE_API_KEY\"] = getpass.getpass(\"Google API key: \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatGoogleGenerativeAI(model=\"gemini-1.5-flash-latest\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define our data model\n",
    "class Data(BaseModel):\n",
    "    name: str = Field(description=\"name\")\n",
    "    age: int = Field(description=\"age\")\n",
    "    # default=None so validation succeeds even if no hobbies were mentioned\n",
    "    hobby: Optional[List[str]] = Field(default=None, description=\"A list of hobbies.\")\n",
    "\n",
    "# Define state for the graph\n",
    "class GraphState(TypedDict):\n",
    "    messages: List[Union[HumanMessage, AIMessage]]\n",
    "    collected_data: Dict\n",
    "    current_field: Optional[str]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to analyze messages and extract information\n",
    "def extract_info(state: GraphState) -> Dict:\n",
    "    \"\"\"Extract profile fields from the latest user message into collected_data.\"\"\"\n",
    "    # Find the most recent *user* message. The old code read messages[-2],\n",
    "    # but run_chat appends the user turn last, so [-2] was the prior AI reply.\n",
    "    last_message = next(\n",
    "        (m.content for m in reversed(state[\"messages\"]) if isinstance(m, HumanMessage)),\n",
    "        \"\",\n",
    "    )\n",
    "\n",
    "    # Prepare prompt for information extraction\n",
    "    extraction_prompt = f\"\"\"\n",
    "    Extract the following information from the user's message: '{last_message}'\n",
    "    If the information is present, format it as JSON matching these fields:\n",
    "    - name: person's name\n",
    "    - age: person's age (as integer)\n",
    "    - hobby: list of hobbies\n",
    "\n",
    "    Only include fields that are clearly mentioned. Return 'null' for missing fields.\n",
    "    \"\"\"\n",
    "\n",
    "    # Extract information using LLM\n",
    "    response = llm.invoke([HumanMessage(content=extraction_prompt)])\n",
    "\n",
    "    # Models often wrap JSON in markdown code fences; strip them before parsing.\n",
    "    raw = response.content.strip()\n",
    "    if raw.startswith(\"```\"):\n",
    "        raw = raw.strip(\"`\").strip()\n",
    "        if raw.lower().startswith(\"json\"):\n",
    "            raw = raw[4:]\n",
    "\n",
    "    try:\n",
    "        extracted_info = json.loads(raw)\n",
    "        # Update collected data with any new information; guard against the\n",
    "        # model returning a bare 'null' (json.loads -> None, not a dict).\n",
    "        if isinstance(extracted_info, dict):\n",
    "            for key, value in extracted_info.items():\n",
    "                if value is not None:\n",
    "                    state[\"collected_data\"][key] = value\n",
    "    except json.JSONDecodeError:\n",
    "        # Best-effort extraction: malformed model output leaves state unchanged.\n",
    "        pass\n",
    "\n",
    "    return {\n",
    "        \"messages\": state[\"messages\"],\n",
    "        \"collected_data\": state[\"collected_data\"],\n",
    "        \"current_field\": state[\"current_field\"],\n",
    "    }\n",
    "\n",
    "# Function to determine next field\n",
    "def determine_next_field(state: Dict) -> Dict:\n",
    "    \"\"\"Choose the next field to ask for: required fields first, then hobby.\"\"\"\n",
    "    required_fields = {\"name\", \"age\"}\n",
    "    missing_fields = required_fields - set(state[\"collected_data\"].keys())\n",
    "\n",
    "    if missing_fields:\n",
    "        state[\"current_field\"] = next(iter(missing_fields))\n",
    "    else:\n",
    "        state[\"current_field\"] = \"hobby\" if \"hobby\" not in state[\"collected_data\"] else None\n",
    "\n",
    "    return state\n",
    "\n",
    "# Function to generate response\n",
    "def generate_response(state: Dict) -> Dict:\n",
    "    \"\"\"Append an assistant message: a summary when done, else the next question.\"\"\"\n",
    "    if state[\"current_field\"] is None:\n",
    "        # All information collected\n",
    "        data = Data(**state[\"collected_data\"])\n",
    "        response = f\"Thank you! I've collected all the information:\\n{data.model_dump_json(indent=2)}\"\n",
    "    else:\n",
    "        # Ask for specific field\n",
    "        field_descriptions = {\n",
    "            \"name\": \"your name\",\n",
    "            \"age\": \"your age\",\n",
    "            \"hobby\": \"any hobbies you have\"\n",
    "        }\n",
    "        response = f\"Could you please tell me {field_descriptions[state['current_field']]}?\"\n",
    "\n",
    "    state[\"messages\"].append(AIMessage(content=response))\n",
    "    return state\n",
    "\n",
    "# Create the graph\n",
    "def create_chat_graph() -> Graph:\n",
    "    \"\"\"Build the linear extract -> determine -> respond workflow.\"\"\"\n",
    "    workflow = Graph()\n",
    "\n",
    "    # Define the graph nodes and edges\n",
    "    workflow.add_node(\"extract_info\", extract_info)\n",
    "    workflow.add_node(\"determine_next_field\", determine_next_field)\n",
    "    workflow.add_node(\"generate_response\", generate_response)\n",
    "\n",
    "    workflow.add_edge(\"extract_info\", \"determine_next_field\")\n",
    "    workflow.add_edge(\"determine_next_field\", \"generate_response\")\n",
    "\n",
    "    # Set the entry point\n",
    "    workflow.set_entry_point(\"extract_info\")\n",
    "    workflow.set_finish_point(\"generate_response\")\n",
    "\n",
    "    return workflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Assistant: Hi! 
I'd like to collect some information from you.\n"
     ]
    }
   ],
   "source": [
    "# Example usage\n",
    "def run_chat():\n",
    "    \"\"\"Interactive loop: alternate user input with one pass through the graph.\"\"\"\n",
    "    graph = create_chat_graph()\n",
    "    app = graph.compile()\n",
    "\n",
    "    # Initialize state\n",
    "    state = {\n",
    "        \"messages\": [AIMessage(content=\"Hi! I'd like to collect some information from you.\")],\n",
    "        \"collected_data\": {},\n",
    "        \"current_field\": None,\n",
    "    }\n",
    "\n",
    "    while True:\n",
    "        # Print last message\n",
    "        print(\"Assistant:\", state[\"messages\"][-1].content)\n",
    "\n",
    "        # Get user input\n",
    "        user_input = input(\"User: \")\n",
    "        if user_input.lower() in (\"quit\", \"exit\"):\n",
    "            break\n",
    "\n",
    "        # Update state with user message\n",
    "        state[\"messages\"].append(HumanMessage(content=user_input))\n",
    "\n",
    "        # Graph.compile().invoke() returns the finish-point node's output\n",
    "        # directly, NOT a dict keyed by node name — indexing the result with\n",
    "        # [\"generate_response\"] raised the KeyError seen previously.\n",
    "        state = app.invoke(state)\n",
    "\n",
    "        # Stop once required fields (name, age) are collected and nothing is pending\n",
    "        if state[\"current_field\"] is None and len(state[\"collected_data\"]) >= 2:\n",
    "            break\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    run_chat()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "paintrekbot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}