import gradio as gr
import os
import matplotlib.pyplot as plt
import pandas as pd
from langgraph.graph import StateGraph, MessagesState, END
from langgraph.types import Command
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_anthropic import ChatAnthropic

# Ensure the Anthropic API key is available (set as a Hugging Face secret).
# ChatAnthropic reads ANTHROPIC_API_KEY from the environment automatically.
if not os.getenv("ANTHROPIC_API_KEY"):
    raise RuntimeError("ANTHROPIC_API_KEY is not set; add it as a secret for this Space.")

llm = ChatAnthropic(model="claude-3-5-sonnet-latest")

# System message function
def make_system_prompt(task: str) -> str:
    return (
        f"You are an AI assistant. Your task is: {task}. "
        "If you have the final answer, prefix it with 'FINAL ANSWER'."
    )
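
# The "FINAL ANSWER" prefix acts as a simple sentinel: a node that includes it in
# its reply signals that the graph can stop instead of handing work to the next node.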

# Node: research
def research_node(state: MessagesState) -> Command:
    user_msg = state["messages"][-1]
    system_prompt = make_system_prompt("Research only")
    response = llm.invoke([SystemMessage(content=system_prompt), user_msg])
    content = response.content
    # Hand off to the chart node unless the researcher already has the final answer.
    goto = END if "FINAL ANSWER" in content else "chart_generator"
    # MessagesState appends returned messages to the running history, so only the
    # new message is returned here.
    return Command(
        update={"messages": [AIMessage(content=content, name="researcher")]},
        goto=goto,
    )

# Node: chart generation
def chart_node(state: MessagesState) -> Command:
    last_msg = state["messages"][-1]
    system_prompt = make_system_prompt("Generate chart only")
    response = llm.invoke([SystemMessage(content=system_prompt), last_msg])
    content = response.content
    return Command(
        update={"messages": [AIMessage(content=content, name="chart_generator")]},
        goto=END,
    )

# Build LangGraph
# Each node returns a Command(goto=...), so routing is dynamic and no static
# edges between the nodes are needed.
workflow = StateGraph(MessagesState)
workflow.add_node("researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.set_entry_point("researcher")
graph = workflow.compile()
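
# Resulting flow: researcher -> chart_generator -> END, with an early exit straight
# to END when the researcher's reply already contains "FINAL ANSWER".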

# Run the graph and return the final text plus an optional chart image
def run_langgraph(input_text):
    try:
        # stream_mode="values" yields the full state after each step; keep the last one.
        events = graph.stream(
            {"messages": [HumanMessage(content=input_text)]},
            stream_mode="values",
        )
        final_state = list(events)[-1]
        final_output = final_state["messages"][-1].content
        if "FINAL ANSWER" in final_output:
            # Placeholder chart with hard-coded illustrative GDP figures
            # (not derived from the model output).
            years = [2020, 2021, 2022, 2023, 2024]
            gdp = [21.4, 22.0, 23.1, 24.8, 26.2]
            plt.figure()
            plt.plot(years, gdp, marker="o")
            plt.title("USA GDP Over Last 5 Years")
            plt.xlabel("Year")
            plt.ylabel("GDP in Trillions USD")
            plt.grid(True)
            plt.tight_layout()
            chart_path = "gdp_chart.png"
            plt.savefig(chart_path)
            plt.close()
            return final_output, chart_path
        else:
            return final_output, None
    except Exception as e:
        return f"Error: {str(e)}", None

# Gradio UI
def process_input(user_input):
    return run_langgraph(user_input)

interface = gr.Interface(
    fn=process_input,
    inputs=gr.Textbox(label="Enter your research task"),
    outputs=[gr.Textbox(label="Output"), gr.Image(type="filepath", label="Chart")],
    title="LangGraph Research Automation",
    description="Enter a research prompt and view chart output when applicable.",
)
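
# On Hugging Face Spaces the Gradio SDK launches app.py automatically; the guard
# below also lets the app be run locally with `python app.py`.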
if __name__ == "__main__":
    interface.launch()