SyedHasanCronosPMC commited on
Commit
10e4c2d
·
verified ·
1 Parent(s): 88178c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -50
app.py CHANGED
@@ -2,88 +2,84 @@ import gradio as gr
2
  import os
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
- from langgraph.graph import StateGraph, END
6
- from langgraph.prebuilt.tool_executor import ToolExecutor
7
- from langgraph.prebuilt.react import create_react_plan_and_execute
8
- from langgraph.checkpoint.sqlite import SqliteSaver
9
- from langgraph.graph.message import add_messages
10
-
11
- from langchain_core.messages import HumanMessage
12
  from langchain_anthropic import ChatAnthropic
13
 
14
- # Load API Key securely
15
  os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY")
16
-
17
- # Define the LLM (Claude 3.5 Sonnet)
18
  llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
19
 
20
- # System prompt modifier
21
- def make_system_prompt(suffix: str) -> str:
22
  return (
23
- "You are a helpful AI assistant, collaborating with other assistants."
24
- " Use the provided tools to progress toward answering the question."
25
- " If you cannot fully answer, another assistant will continue where you left off."
26
- " If you or the team has a complete answer, prefix it with FINAL ANSWER.\n"
27
- f"{suffix}"
28
  )
29
 
30
- # Node 1: Research assistant logic
31
- def research_node(state: dict) -> dict:
32
- messages = state.get("messages", [])
33
- prompt = make_system_prompt("You can only do research.")
34
- executor = create_react_plan_and_execute(llm=llm, tools=[], system_prompt=prompt)
35
- response = executor.invoke(messages)
36
- messages.append(HumanMessage(content=response.content, name="researcher"))
37
- next_node = "chart_generator" if "FINAL ANSWER" not in response.content else END
38
- return {"messages": messages, "next": next_node}
 
 
 
39
 
40
- # Node 2: Chart assistant logic
41
- def chart_node(state: dict) -> dict:
42
- messages = state.get("messages", [])
43
- prompt = make_system_prompt("You can only generate charts.")
44
- executor = create_react_plan_and_execute(llm=llm, tools=[], system_prompt=prompt)
45
- response = executor.invoke(messages)
46
- messages.append(HumanMessage(content=response.content, name="chart_generator"))
47
- return {"messages": messages, "next": END}
48
 
49
- # Define LangGraph flow
 
 
 
 
 
 
50
  workflow = StateGraph(dict)
51
  workflow.add_node("researcher", research_node)
52
  workflow.add_node("chart_generator", chart_node)
53
  workflow.set_entry_point("researcher")
54
- workflow.set_finish_point(END)
55
- workflow.add_conditional_edges("researcher", lambda x: x["next"])
56
- workflow.add_edge("chart_generator", END)
57
-
58
  graph = workflow.compile()
59
 
60
- # Function to run the graph and optionally return chart
61
  def run_langgraph(input_text):
62
  try:
63
  events = graph.stream({"messages": [HumanMessage(content=input_text)]})
64
- output = list(events)[-1]
65
- final_content = output["messages"][-1].content
 
66
 
67
- if "FINAL ANSWER" in final_content:
68
- # Example static chart
69
  years = [2020, 2021, 2022, 2023, 2024]
70
  gdp = [21.4, 22.0, 23.1, 24.8, 26.2]
71
  plt.figure()
72
  plt.plot(years, gdp, marker="o")
73
  plt.title("USA GDP Over Last 5 Years")
74
  plt.xlabel("Year")
75
- plt.ylabel("GDP in Trillions")
76
  plt.grid(True)
77
  plt.tight_layout()
78
- plt.savefig("gdp_chart.png")
79
- return "Chart generated based on FINAL ANSWER", "gdp_chart.png"
 
80
  else:
81
- return final_content, None
82
-
83
  except Exception as e:
84
  return f"Error: {str(e)}", None
85
 
86
- # Gradio Interface
87
  def process_input(user_input):
88
  return run_langgraph(user_input)
89
 
 
2
  import os
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
+ from langgraph.graph import StateGraph
6
+ from langgraph.types import Command
7
+ from langchain_core.messages import HumanMessage, AIMessage
 
 
 
 
8
  from langchain_anthropic import ChatAnthropic
9
 
10
+ # Set API key from environment variable (Hugging Face secret)
11
  os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY")
 
 
12
  llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
13
 
14
+ # System message function
15
+ def make_system_prompt(task: str) -> str:
16
  return (
17
+ f"You are an AI assistant. Your task is: {task}. "
18
+ "If you have the final answer, prefix it with 'FINAL ANSWER'."
 
 
 
19
  )
20
 
21
+ # Node: research
22
+ def research_node(state: dict) -> Command:
23
+ user_msg = state["messages"][-1]
24
+ task = user_msg.content
25
+ system_prompt = make_system_prompt("Research only")
26
+
27
+ response = llm.invoke([HumanMessage(content=system_prompt), user_msg])
28
+ content = response.content
29
+
30
+ state["messages"].append(AIMessage(content=content, name="researcher"))
31
+ goto = "chart_generator" if "FINAL ANSWER" not in content else "__end__"
32
+ return Command(update={"messages": state["messages"]}, goto=goto)
33
 
34
+ # Node: chart generation
35
+ def chart_node(state: dict) -> Command:
36
+ user_msg = state["messages"][-1]
37
+ task = user_msg.content
38
+ system_prompt = make_system_prompt("Generate chart only")
 
 
 
39
 
40
+ response = llm.invoke([HumanMessage(content=system_prompt), user_msg])
41
+ content = response.content
42
+
43
+ state["messages"].append(AIMessage(content=content, name="chart_generator"))
44
+ return Command(update={"messages": state["messages"]}, goto="__end__")
45
+
46
+ # Build LangGraph
47
  workflow = StateGraph(dict)
48
  workflow.add_node("researcher", research_node)
49
  workflow.add_node("chart_generator", chart_node)
50
  workflow.set_entry_point("researcher")
51
+ workflow.set_finish_point("__end__")
52
+ workflow.add_edge("researcher", "chart_generator")
 
 
53
  graph = workflow.compile()
54
 
55
+ # Stream function
56
  def run_langgraph(input_text):
57
  try:
58
  events = graph.stream({"messages": [HumanMessage(content=input_text)]})
59
+ final_state = list(events)[-1]
60
+ messages = final_state["messages"]
61
+ final_output = messages[-1].content
62
 
63
+ if "FINAL ANSWER" in final_output:
64
+ # Dummy chart
65
  years = [2020, 2021, 2022, 2023, 2024]
66
  gdp = [21.4, 22.0, 23.1, 24.8, 26.2]
67
  plt.figure()
68
  plt.plot(years, gdp, marker="o")
69
  plt.title("USA GDP Over Last 5 Years")
70
  plt.xlabel("Year")
71
+ plt.ylabel("GDP in Trillions USD")
72
  plt.grid(True)
73
  plt.tight_layout()
74
+ chart_path = "gdp_chart.png"
75
+ plt.savefig(chart_path)
76
+ return final_output, chart_path
77
  else:
78
+ return final_output, None
 
79
  except Exception as e:
80
  return f"Error: {str(e)}", None
81
 
82
+ # Gradio UI
83
  def process_input(user_input):
84
  return run_langgraph(user_input)
85