File size: 3,310 Bytes
de9f6ca
 
d1d9178
1effecd
10e4c2d
 
 
de9f6ca
 
10e4c2d
de9f6ca
 
 
10e4c2d
 
de9f6ca
10e4c2d
 
de9f6ca
 
10e4c2d
 
 
 
 
 
 
 
 
 
 
 
de9f6ca
10e4c2d
 
 
 
 
de9f6ca
10e4c2d
 
 
 
 
 
 
1effecd
de9f6ca
 
d85a20b
10e4c2d
 
de9f6ca
 
10e4c2d
1effecd
d1d9178
1effecd
10e4c2d
 
 
d85a20b
10e4c2d
 
d85a20b
 
 
1effecd
d85a20b
 
10e4c2d
d85a20b
 
10e4c2d
 
 
d85a20b
10e4c2d
d1d9178
d85a20b
 
10e4c2d
d85a20b
 
de9f6ca
 
d85a20b
 
 
de9f6ca
d1d9178
de9f6ca
 
 
d85a20b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.graph import MessagesState, StateGraph
from langgraph.types import Command
from matplotlib.figure import Figure

# The Anthropic client reads ANTHROPIC_API_KEY from the environment (e.g. a
# Hugging Face Space secret). Check it explicitly: copying os.getenv(...) back
# into os.environ is a no-op when the key is set and raises an opaque
# TypeError (environ values must be str, not None) when it is missing.
if not os.getenv("ANTHROPIC_API_KEY"):
    raise RuntimeError(
        "ANTHROPIC_API_KEY is not set; define it as an environment variable "
        "(or a Hugging Face Space secret) before starting the app."
    )

# Shared chat model used by every graph node.
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")

# System prompt builder shared by the graph nodes
def make_system_prompt(task: str) -> str:
    """Return the system prompt that assigns *task* to the assistant.

    The prompt also asks the model to mark a finished answer with the literal
    prefix ``FINAL ANSWER``; the research node checks for that marker when
    choosing the next step.
    """
    role_line = f"You are an AI assistant. Your task is: {task}. "
    finish_rule = "If you have the final answer, prefix it with 'FINAL ANSWER'."
    return role_line + finish_rule

# Node: research
def research_node(state: dict) -> Command:
    """Run the research step and decide where the graph goes next.

    Sends the research-only system prompt plus the latest message of the
    conversation to the model and appends the reply, tagged
    ``name="researcher"``.

    Args:
        state: Graph state; must contain a ``"messages"`` list of LangChain
            messages whose last entry is the input for this step.

    Returns:
        A ``Command`` whose update carries the conversation with the new reply
        appended. It routes to ``"__end__"`` when the reply contains
        ``FINAL ANSWER`` and to ``"chart_generator"`` otherwise.
    """
    history = state["messages"]
    system_prompt = make_system_prompt("Research only")

    # Send the instructions as a real system message rather than as an extra
    # user turn.
    response = llm.invoke([SystemMessage(content=system_prompt), history[-1]])
    content = response.content

    reply = AIMessage(content=content, name="researcher")
    goto = "__end__" if "FINAL ANSWER" in content else "chart_generator"
    # Build a new list instead of appending to the graph-owned state in place.
    return Command(update={"messages": [*history, reply]}, goto=goto)

# Node: chart generation
def chart_node(state: dict) -> Command:
    """Run the chart-generation step and finish the graph.

    Args:
        state: Graph state; must contain a ``"messages"`` list whose last
            entry is the previous step's output.

    Returns:
        A ``Command`` that appends the model's reply (tagged
        ``name="chart_generator"``) to the conversation and routes to
        ``"__end__"``.
    """
    history = state["messages"]
    system_prompt = make_system_prompt("Generate chart only")

    # The previous step's output is normally the researcher's AIMessage.
    # Re-wrap its text as a human turn so the request ends with a user
    # message. Anthropic treats a trailing assistant message as a prefill to
    # continue, not as input to respond to.
    handoff = HumanMessage(content=history[-1].content)

    response = llm.invoke([SystemMessage(content=system_prompt), handoff])
    reply = AIMessage(content=response.content, name="chart_generator")
    # Build a new list instead of appending to the graph-owned state in place.
    return Command(update={"messages": [*history, reply]}, goto="__end__")

# Build LangGraph.
# MessagesState declares the "messages" channel the nodes update.
# Routing after the entry point comes entirely from the Command objects the
# nodes return (researcher -> chart_generator or END, chart_generator -> END).
# No static edges are added:
#   - a static researcher -> chart_generator edge would run the chart step
#     even when the researcher routes to END;
#   - set_finish_point("__end__") would add an edge *from* END, which
#     LangGraph rejects.
workflow = StateGraph(MessagesState)
workflow.add_node("researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.set_entry_point("researcher")
graph = workflow.compile()

# Graph runner and chart rendering
def _render_gdp_chart(chart_path: str) -> str:
    """Draw the placeholder US GDP line chart, save it to *chart_path*, return the path.

    The data points are hard-coded placeholders, not values produced by the
    agents.
    """
    years = [2020, 2021, 2022, 2023, 2024]
    gdp = [21.4, 22.0, 23.1, 24.8, 26.2]
    # Use a standalone Figure rather than pyplot. pyplot keeps every figure
    # in global state until closed, which leaks per request in a
    # long-running server and is not thread-safe.
    fig = Figure()
    ax = fig.subplots()
    ax.plot(years, gdp, marker="o")
    ax.set_title("USA GDP Over Last 5 Years")
    ax.set_xlabel("Year")
    ax.set_ylabel("GDP in Trillions USD")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(chart_path)
    return chart_path


def run_langgraph(input_text):
    """Run the agent graph on *input_text* and return ``(text, chart_path)``.

    ``chart_path`` is the saved placeholder chart when the final message
    contains ``FINAL ANSWER``, otherwise ``None``. Any exception is turned
    into an ``"Error: ..."`` message so the UI always gets a displayable
    result.
    """
    try:
        # invoke() returns the final graph state. stream() on a compiled
        # StateGraph yields per-node update dicts keyed by node name, so its
        # last event has no top-level "messages" key.
        final_state = graph.invoke({"messages": [HumanMessage(content=input_text)]})
        final_output = final_state["messages"][-1].content

        if "FINAL ANSWER" in final_output:
            return final_output, _render_gdp_chart("gdp_chart.png")
        return final_output, None
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        return f"Error: {str(e)}", None

# Gradio UI callback
def process_input(user_input):
    """Handle one Gradio submission by running the agent graph.

    Passes *user_input* to ``run_langgraph`` and returns its
    ``(text, chart_path_or_None)`` result unchanged, one value per output
    component.
    """
    outcome = run_langgraph(user_input)
    return outcome

# Input: a free-text research task.
# Outputs: the final text, plus the chart image path when one is produced.
task_input = gr.Textbox(label="Enter your research task")
result_outputs = [
    gr.Textbox(label="Output"),
    gr.Image(type="filepath", label="Chart"),
]

interface = gr.Interface(
    fn=process_input,
    inputs=task_input,
    outputs=result_outputs,
    title="LangGraph Research Automation",
    description="Enter a research prompt and view chart output when applicable.",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    interface.launch()