import ast
import configparser
import io
import logging
import os
import re
from typing import TypedDict, Optional

import gradio as gr
from dotenv import load_dotenv
from gradio_client import Client
from langgraph.graph import StateGraph, START, END
from PIL import Image

# Load a local .env file (if present) before reading environment variables
load_dotenv()

# OPEN QUESTION: should we pass all params from the orchestrator to the nodes
# instead of setting them in each module?
HF_TOKEN = os.environ.get("HF_TOKEN")
def getconfig(configfile_path: str):
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as configfile:
            config.read_file(configfile)
        return config
    except (FileNotFoundError, configparser.Error):
        logging.warning("config file not found")
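# Illustrative usage of getconfig (the file name "params.cfg" and the section/key
# names are assumptions; no config file ships with this module):
#   config = getconfig("params.cfg")
#   retriever_repo = config.get("retriever", "REPO_ID")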
def get_auth(provider: str) -> dict:
    """Get authentication configuration for different providers"""
    auth_configs = {
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
    }
    provider = provider.lower()  # Normalize to lowercase
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")
    if not api_key:
        logging.warning(f"No API key found for provider '{provider}'. Please set the appropriate environment variable.")
        auth_config["api_key"] = None
    return auth_config
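# Illustrative return shapes for get_auth (actual values depend on the environment):
#   get_auth("huggingface")  -> {"api_key": "<HF_TOKEN value, or None if unset>"}
#   get_auth("openai")       -> raises ValueError (unsupported provider)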
# Define the state schema
class GraphState(TypedDict):
    query: str
    context: str
    result: str
    # Orchestrator-level filter parameters (these address the open question above:
    # they are set once here and passed from the orchestrator into the nodes)
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
# Retriever node: fetch context from the ChatFed retriever service
def retrieve_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_retriever", hf_token=HF_TOKEN)  # HF repo name
    context = client.predict(
        query=state["query"],
        reports_filter=state.get("reports_filter", ""),
        sources_filter=state.get("sources_filter", ""),
        subtype_filter=state.get("subtype_filter", ""),
        year_filter=state.get("year_filter", ""),
        api_name="/retrieve"
    )
    return {"context": context}
# Generator node: produce the answer from the query and retrieved context
def generate_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_generator", hf_token=HF_TOKEN)
    result = client.predict(
        query=state["query"],
        context=state["context"],
        api_name="/generate"
    )
    return {"result": result}
# build the graph
workflow = StateGraph(GraphState)
# Add nodes
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
# Add edges
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
# Compile the graph
graph = workflow.compile()
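# The compiled graph can also be invoked directly; process_query() below wraps
# exactly this call (the state values here are placeholders):
#   graph.invoke({"query": "...", "context": "", "result": "",
#                 "reports_filter": "", "sources_filter": "",
#                 "subtype_filter": "", "year_filter": ""})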
# Single tool for processing queries
def process_query(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = ""
) -> str:
    """
    Execute the ChatFed orchestration pipeline to process a user query.

    This function orchestrates a two-step workflow:
    1. Retrieve relevant context using the ChatFed retriever service with optional filters
    2. Generate a response using the ChatFed generator service with the retrieved context

    Args:
        query (str): The user's input query/question to be processed
        reports_filter (str, optional): Filter for specific report types. Defaults to "".
        sources_filter (str, optional): Filter for specific data sources. Defaults to "".
        subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
        year_filter (str, optional): Filter for specific years. Defaults to "".

    Returns:
        str: The generated response from the ChatFed generator service
    """
    initial_state = {
        "query": query,
        "context": "",
        "result": "",
        "reports_filter": reports_filter or "",
        "sources_filter": sources_filter or "",
        "subtype_filter": subtype_filter or "",
        "year_filter": year_filter or ""
    }
    final_state = graph.invoke(initial_state)
    return final_state["result"]
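# Illustrative local call (the query and filter values are made up):
#   answer = process_query(
#       "What do the annual reports say about 2024 funding?",
#       reports_filter="annual_reports",
#       year_filter="2024",
#   )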
# Simple testing interface
ui = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(lines=2, placeholder="Enter query here"),
    outputs="text",
    flagging_mode="never"
)
# Generate the graph visualization for display in the UI
def get_graph_visualization():
    """Generate and return the LangGraph workflow visualization as a PIL Image."""
    # Generate the graph as PNG bytes
    graph_png_bytes = graph.get_graph().draw_mermaid_png()
    # Convert bytes to PIL Image for Gradio display
    graph_image = Image.open(io.BytesIO(graph_png_bytes))
    return graph_image
# Guidance page for ChatUI (can be removed later). A front end may not be strictly
# necessary, but it is useful for showing the graph.
with gr.Blocks(title="ChatFed Orchestrator") as demo:
    gr.Markdown("# ChatFed Orchestrator")
    gr.Markdown("This LangGraph server exposes MCP endpoints for the ChatUI module to call (which triggers the graph).")

    with gr.Row():
        # Left column - Graph visualization
        with gr.Column(scale=1):
            gr.Markdown("**Workflow Visualization**")
            graph_display = gr.Image(
                value=get_graph_visualization(),
                label="LangGraph Workflow",
                interactive=False,
                height=300
            )
            # Add a refresh button for the graph
            refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
            refresh_graph_btn.click(
                fn=get_graph_visualization,
                outputs=graph_display
            )

        # Right column - Interface and documentation
        with gr.Column(scale=2):
            gr.Markdown("**Available MCP Tools:**")
            with gr.Accordion("MCP Endpoint Information", open=True):
                gr.Markdown("""
                **MCP Server Endpoint:** https://giz-chatfed-orchestrator.hf.space/gradio_api/mcp/sse

                **For ChatUI Integration:**
                ```python
                from gradio_client import Client

                # Connect to orchestrator
                orchestrator_client = Client("https://giz-chatfed-orchestrator.hf.space")

                # Basic usage (no filters)
                response = orchestrator_client.predict(
                    query="query",
                    api_name="/process_query"
                )

                # Advanced usage with any combination of filters
                response = orchestrator_client.predict(
                    query="query",
                    reports_filter="annual_reports",
                    sources_filter="internal",
                    year_filter="2024",
                    api_name="/process_query"
                )
                ```
                """)

            with gr.Accordion("Quick Testing Interface", open=True):
                ui.render()
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        show_error=True
    )