File size: 3,700 Bytes
de9f6ca
 
d1d9178
1effecd
10e4c2d
038f4ad
 
de9f6ca
 
038f4ad
 
 
 
 
 
 
 
de9f6ca
 
038f4ad
 
 
 
 
de9f6ca
038f4ad
 
 
de9f6ca
 
038f4ad
 
 
 
 
 
 
 
 
de9f6ca
038f4ad
 
de9f6ca
038f4ad
 
 
 
 
10e4c2d
038f4ad
 
 
 
 
 
 
 
 
 
 
 
 
10e4c2d
038f4ad
1effecd
de9f6ca
 
d85a20b
10e4c2d
 
de9f6ca
 
038f4ad
1effecd
d1d9178
038f4ad
 
 
d85a20b
038f4ad
 
d85a20b
 
038f4ad
d85a20b
1effecd
d85a20b
 
10e4c2d
d85a20b
 
038f4ad
 
 
d85a20b
038f4ad
d1d9178
d85a20b
 
038f4ad
d85a20b
 
de9f6ca
 
d85a20b
 
 
de9f6ca
d1d9178
de9f6ca
 
 
d85a20b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import os
import matplotlib.pyplot as plt
import pandas as pd
from langgraph.graph import StateGraph
from langgraph.prebuilt import create_agent_executor
from langchain_core.messages import HumanMessage
from langchain_anthropic import ChatAnthropic

# Minimal stand-in for langgraph's `Command` type (langgraph.types is not
# available in v0.0.41). It only stores its arguments; because it is a local
# class, nothing in langgraph recognizes it on its own.
class Command:
    """Plain container for a node's state update and routing target.

    Args:
        update: State update mapping. Any falsy value (including ``None``)
            is replaced by a fresh empty dict.
        next: Stored as given.
        goto: Name of the node to route to; stored as given.
    """

    def __init__(self, update=None, next=None, goto=None):
        # Falsy inputs become a brand-new dict so instances never share one.
        self.update = update if update else {}
        self.next, self.goto = next, goto

# Hugging Face Spaces exposes secrets as environment variables, and
# ChatAnthropic falls back to ANTHROPIC_API_KEY from the environment when no
# key is passed, so nothing needs to be copied. The previous
# `os.environ[...] = os.getenv(...)` did nothing when the key existed and
# raised an opaque TypeError (None is not a str) when it did not. Fail fast
# with an actionable message instead.
if not os.environ.get("ANTHROPIC_API_KEY"):
    raise RuntimeError(
        "ANTHROPIC_API_KEY is not set. Add it as a secret / environment "
        "variable before starting the app."
    )

# Claude 3.5 Sonnet model. The dated ID must match a published model:
# "claude-3-5-sonnet-20240229" does not exist (20240229 is the Claude 3
# release date), so requests using it are rejected by the API.
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")

# Shared preamble prepended to every agent's system prompt.
_PROMPT_PREAMBLE = (
    "You are a helpful AI assistant collaborating with others."
    " Use your tools to assist. If you can't complete a task, leave it to another agent."
    " Prefix final output with 'FINAL ANSWER' to signal completion.\n"
)

def make_system_prompt(suffix: str) -> str:
    """Return the shared collaboration preamble followed by ``suffix``.

    The preamble ends with a newline, so ``suffix`` starts on its own line.
    It tells agents to prefix their final output with ``'FINAL ANSWER'``,
    the marker this module checks for when deciding what to do next.
    """
    return _PROMPT_PREAMBLE + suffix

# Research agent node
def research_node(state):
    """Run the research agent on ``state`` and choose the next step.

    A fresh agent is built on every call, with no tools and a research-only
    system prompt. Its last message is re-labelled as a ``HumanMessage``
    named ``"researcher"`` (presumably so the next agent treats it as input
    rather than its own output; confirm).

    Returns:
        A ``Command`` whose update carries the agent's full message list.
        ``goto`` is ``"__end__"`` when the reply contains ``"FINAL ANSWER"``,
        otherwise ``"chart_generator"``.
    """
    executor = create_agent_executor(
        llm,
        tools=[],
        system_message=make_system_prompt("You can only perform research."),
    )
    outcome = executor.invoke(state)
    history = outcome["messages"]
    reply = history[-1]

    # NOTE(review): `Command` is this module's local fallback class, so
    # `goto` only affects routing if the graph wiring translates it.
    finished = "FINAL ANSWER" in reply.content
    next_node = "__end__" if finished else "chart_generator"

    # Overwrite the last entry in place (same list object as outcome["messages"]).
    history[-1] = HumanMessage(content=reply.content, name="researcher")
    return Command(update={"messages": history}, goto=next_node)

# Chart generation agent node
def chart_node(state):
    """Run the chart-generation agent on ``state`` and finish the workflow.

    A fresh agent is built on every call, with no tools and a chart-only
    system prompt. Its last message is re-labelled as a ``HumanMessage``
    named ``"chart_generator"``.

    Returns:
        A ``Command`` whose update carries the agent's full message list,
        always with ``goto="__end__"``.
    """
    prompt = make_system_prompt("You can only generate charts.")
    executor = create_agent_executor(llm, tools=[], system_message=prompt)
    messages = executor.invoke(state)["messages"]
    # Overwrite the last entry in place with the re-labelled message.
    messages[-1] = HumanMessage(content=messages[-1].content, name="chart_generator")
    return Command(update={"messages": messages}, goto="__end__")

# LangGraph state setup
#
# The node functions return this module's fallback `Command` objects, which
# langgraph does not understand. `StateGraph(dict)` has no annotated fields,
# so presumably the whole state is one root value that each node's return
# replaces (verify against the pinned langgraph version). The wrapper below
# converts a `Command` into a plain state dict and records its `goto` so a
# conditional edge can honour it.
_ROUTE_KEY = "goto"


def _command_to_state(node_fn):
    """Wrap ``node_fn`` so a returned ``Command`` becomes a full state dict.

    The routing key is removed from the incoming state before the node sees
    it, and ``Command.goto`` is stored under that key in the outgoing state.
    Results that are not ``Command`` instances are passed through unchanged.
    """
    def wrapped(state):
        clean = {k: v for k, v in state.items() if k != _ROUTE_KEY}
        result = node_fn(clean)
        if not isinstance(result, Command):
            return result
        return {**clean, **result.update, _ROUTE_KEY: result.goto}

    return wrapped


def _route_after_researcher(state):
    """Return ``"__end__"`` if the researcher asked to stop, else ``"chart_generator"``."""
    return "__end__" if state.get(_ROUTE_KEY) == "__end__" else "chart_generator"


workflow = StateGraph(dict)
workflow.add_node("researcher", _command_to_state(research_node))
workflow.add_node("chart_generator", _command_to_state(chart_node))
workflow.set_entry_point("researcher")
# The researcher may stop early with FINAL ANSWER; otherwise it hands off to
# the chart agent. A plain static edge would ignore the node's `goto`.
workflow.add_conditional_edges(
    "researcher",
    _route_after_researcher,
    {"chart_generator": "chart_generator", "__end__": "__end__"},
)
# The chart agent is the last step. `set_finish_point("__end__")` would add
# an edge *from* END (rejected by langgraph) and leave "chart_generator"
# without any outgoing edge.
workflow.set_finish_point("chart_generator")
graph = workflow.compile()

def _save_placeholder_gdp_chart(path):
    """Write the hard-coded (dummy) USA GDP line chart to ``path``.

    Uses the object-oriented ``matplotlib.figure.Figure`` API instead of
    ``pyplot``. The figure is never registered with pyplot's global figure
    manager, so nothing accumulates across requests; the old ``plt.figure()``
    was never closed. No GUI backend is involved either, which matters when
    Gradio calls this from a worker thread.
    """
    from matplotlib.figure import Figure  # local import: only needed here

    # Dummy data, unchanged from the original placeholder.
    years = [2020, 2021, 2022, 2023, 2024]
    gdp = [21.4, 22.0, 23.1, 24.8, 26.2]

    fig = Figure()
    ax = fig.subplots()
    ax.plot(years, gdp, marker="o")
    ax.set_title("USA GDP Over Last 5 Years")
    ax.set_xlabel("Year")
    ax.set_ylabel("GDP in Trillions USD")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(path)


# LangGraph runner
def run_langgraph(input_text):
    """Run the agent graph on ``input_text`` and return the Gradio outputs.

    Returns:
        A ``(text, image_path)`` pair with three possible outcomes:

        - If the final message contains ``"FINAL ANSWER"``, a placeholder GDP
          chart is written to ``gdp_chart.png`` and the result is
          ``("Chart generated based on FINAL ANSWER.", "gdp_chart.png")``.
        - Otherwise, the final message content and ``None``.
        - On any exception, ``("Error: <message>", None)``, so the UI shows
          the error instead of failing.
    """
    try:
        # `invoke` returns the final graph state. By default `stream` yields
        # per-step chunks keyed by node name rather than the state itself, so
        # indexing its last chunk with "messages" failed.
        final_state = graph.invoke({"messages": [("user", input_text)]})
        final_response = final_state["messages"][-1].content

        if "FINAL ANSWER" not in final_response:
            return final_response, None

        chart_path = "gdp_chart.png"
        _save_placeholder_gdp_chart(chart_path)
        return "Chart generated based on FINAL ANSWER.", chart_path
    except Exception as e:
        # Top-level UI boundary: surface the error text in the Output box.
        return f"Error: {e}", None

# Gradio interface
def process_input(user_input):
    """Gradio callback: forward the textbox value to ``run_langgraph``.

    Returns whatever ``run_langgraph`` returns: the ``(text, image_path)``
    pair that feeds the two output components.
    """
    outputs = run_langgraph(user_input)
    return outputs

# Components are built in the same order as before: the task textbox, then
# the text output, then the chart image.
_task_input = gr.Textbox(label="Enter your research task")
_result_outputs = [gr.Textbox(label="Output"), gr.Image(type="filepath", label="Chart")]

interface = gr.Interface(
    fn=process_input,
    inputs=_task_input,
    outputs=_result_outputs,
    title="LangGraph Research Automation",
    description="Enter a research prompt and view chart output when applicable.",
)

# Only start the server when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()