SyedHasanCronosPMC commited on
Commit
038f4ad
·
verified ·
1 Parent(s): 10e4c2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -42
app.py CHANGED
@@ -3,47 +3,65 @@ import os
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
  from langgraph.graph import StateGraph
6
- from langgraph.types import Command
7
- from langchain_core.messages import HumanMessage, AIMessage
8
  from langchain_anthropic import ChatAnthropic
9
 
10
- # Set API key from environment variable (Hugging Face secret)
 
 
 
 
 
 
 
11
  os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY")
12
- llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
13
 
14
- # System message function
15
- def make_system_prompt(task: str) -> str:
 
 
 
16
  return (
17
- f"You are an AI assistant. Your task is: {task}. "
18
- "If you have the final answer, prefix it with 'FINAL ANSWER'."
 
19
  )
20
 
21
- # Node: research
22
- def research_node(state: dict) -> Command:
23
- user_msg = state["messages"][-1]
24
- task = user_msg.content
25
- system_prompt = make_system_prompt("Research only")
26
-
27
- response = llm.invoke([HumanMessage(content=system_prompt), user_msg])
28
- content = response.content
29
-
30
- state["messages"].append(AIMessage(content=content, name="researcher"))
31
- goto = "chart_generator" if "FINAL ANSWER" not in content else "__end__"
32
- return Command(update={"messages": state["messages"]}, goto=goto)
33
 
34
- # Node: chart generation
35
- def chart_node(state: dict) -> Command:
36
- user_msg = state["messages"][-1]
37
- task = user_msg.content
38
- system_prompt = make_system_prompt("Generate chart only")
39
 
40
- response = llm.invoke([HumanMessage(content=system_prompt), user_msg])
41
- content = response.content
 
 
 
42
 
43
- state["messages"].append(AIMessage(content=content, name="chart_generator"))
44
- return Command(update={"messages": state["messages"]}, goto="__end__")
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Build LangGraph
47
  workflow = StateGraph(dict)
48
  workflow.add_node("researcher", research_node)
49
  workflow.add_node("chart_generator", chart_node)
@@ -52,18 +70,18 @@ workflow.set_finish_point("__end__")
52
  workflow.add_edge("researcher", "chart_generator")
53
  graph = workflow.compile()
54
 
55
- # Stream function
56
  def run_langgraph(input_text):
57
  try:
58
- events = graph.stream({"messages": [HumanMessage(content=input_text)]})
59
- final_state = list(events)[-1]
60
- messages = final_state["messages"]
61
- final_output = messages[-1].content
62
 
63
- if "FINAL ANSWER" in final_output:
64
- # Dummy chart
65
  years = [2020, 2021, 2022, 2023, 2024]
66
  gdp = [21.4, 22.0, 23.1, 24.8, 26.2]
 
67
  plt.figure()
68
  plt.plot(years, gdp, marker="o")
69
  plt.title("USA GDP Over Last 5 Years")
@@ -71,15 +89,15 @@ def run_langgraph(input_text):
71
  plt.ylabel("GDP in Trillions USD")
72
  plt.grid(True)
73
  plt.tight_layout()
74
- chart_path = "gdp_chart.png"
75
- plt.savefig(chart_path)
76
- return final_output, chart_path
77
  else:
78
- return final_output, None
79
  except Exception as e:
80
  return f"Error: {str(e)}", None
81
 
82
- # Gradio UI
83
  def process_input(user_input):
84
  return run_langgraph(user_input)
85
 
 
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
  from langgraph.graph import StateGraph
6
+ from langgraph.prebuilt import create_agent_executor
7
+ from langchain_core.messages import HumanMessage
8
  from langchain_anthropic import ChatAnthropic
9
 
10
+ # Fallback Command class (since langgraph.types is not available in v0.0.41)
11
+ class Command:
12
+ def __init__(self, update=None, next=None, goto=None):
13
+ self.update = update or {}
14
+ self.next = next
15
+ self.goto = goto
16
+
17
+ # Set the API key from Hugging Face secrets
18
  os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY")
 
19
 
20
+ # Claude 3.5 Sonnet model
21
+ llm = ChatAnthropic(model="claude-3-5-sonnet-20240229")
22
+
23
+ # Utility to build system prompts
24
+ def make_system_prompt(suffix: str) -> str:
25
  return (
26
+ "You are a helpful AI assistant collaborating with others."
27
+ " Use your tools to assist. If you can't complete a task, leave it to another agent."
28
+ " Prefix final output with 'FINAL ANSWER' to signal completion.\n" + suffix
29
  )
30
 
31
+ # Research agent node
32
+ def research_node(state):
33
+ agent = create_agent_executor(
34
+ llm,
35
+ tools=[],
36
+ system_message=make_system_prompt("You can only perform research.")
37
+ )
38
+ result = agent.invoke(state)
39
+ last_msg = result["messages"][-1]
 
 
 
40
 
41
+ # Determine next step
42
+ goto = "chart_generator" if "FINAL ANSWER" not in last_msg.content else "__end__"
 
 
 
43
 
44
+ result["messages"][-1] = HumanMessage(
45
+ content=last_msg.content,
46
+ name="researcher"
47
+ )
48
+ return Command(update={"messages": result["messages"]}, goto=goto)
49
 
50
+ # Chart generation agent node
51
+ def chart_node(state):
52
+ agent = create_agent_executor(
53
+ llm,
54
+ tools=[],
55
+ system_message=make_system_prompt("You can only generate charts.")
56
+ )
57
+ result = agent.invoke(state)
58
+ result["messages"][-1] = HumanMessage(
59
+ content=result["messages"][-1].content,
60
+ name="chart_generator"
61
+ )
62
+ return Command(update={"messages": result["messages"]}, goto="__end__")
63
 
64
+ # LangGraph state setup
65
  workflow = StateGraph(dict)
66
  workflow.add_node("researcher", research_node)
67
  workflow.add_node("chart_generator", chart_node)
 
70
  workflow.add_edge("researcher", "chart_generator")
71
  graph = workflow.compile()
72
 
73
+ # LangGraph runner
74
  def run_langgraph(input_text):
75
  try:
76
+ events = graph.stream({"messages": [("user", input_text)]})
77
+ output = list(events)
78
+ final_response = output[-1]["messages"][-1].content
 
79
 
80
+ if "FINAL ANSWER" in final_response:
81
+ # Dummy chart generation
82
  years = [2020, 2021, 2022, 2023, 2024]
83
  gdp = [21.4, 22.0, 23.1, 24.8, 26.2]
84
+
85
  plt.figure()
86
  plt.plot(years, gdp, marker="o")
87
  plt.title("USA GDP Over Last 5 Years")
 
89
  plt.ylabel("GDP in Trillions USD")
90
  plt.grid(True)
91
  plt.tight_layout()
92
+ plt.savefig("gdp_chart.png")
93
+
94
+ return "Chart generated based on FINAL ANSWER.", "gdp_chart.png"
95
  else:
96
+ return final_response, None
97
  except Exception as e:
98
  return f"Error: {str(e)}", None
99
 
100
+ # Gradio interface
101
  def process_input(user_input):
102
  return run_langgraph(user_input)
103