mtyrrell committed
Commit 9e430dc · Parent: 0e839b6

future proofing for chatUI

Files changed (1): app/main.py (+80 -31)
app/main.py CHANGED
@@ -1,36 +1,30 @@
  import gradio as gr
  from gradio_client import Client
  from langgraph.graph import StateGraph, START, END
- from typing import TypedDict
+ from typing import TypedDict, Optional
+
+ #OPEN QUESTION: SHOULD WE PASS ALL PARAMS FROM THE ORCHESTRATOR TO THE NODES INSTEAD OF SETTING IN EACH MODULE?

  # Define the state schema
  class GraphState(TypedDict):
      query: str
      context: str
      result: str
-
-
- ### OPEN QUESTION: SHOULD WE PASS ALL PARAMS FROM THE ORCHESTRATOR TO THE NODES INSTEAD OF SETTING IN EACH MODULE?
-
- # node 1: chatUI
- # To be finalized
- # def chatui_node(state: GraphState) -> GraphState:
- #     client = Client("giz/chatfed_chatui")
- #     context = client.predict(
- #         query=state["query"],  # not sure if we need to pass the query here
- #         api_name="/chat"
- #     )
- #     return {"query": query}
+     # Add orchestrator-level parameters (addressing your open question)
+     reports_filter: str
+     sources_filter: str
+     subtype_filter: str
+     year_filter: str

  # node 2: retriever
  def retrieve_node(state: GraphState) -> GraphState:
      client = Client("giz/chatfed_retriever")  # HF repo name
      context = client.predict(
          query=state["query"],
-         reports_filter="",
-         sources_filter="",
-         subtype_filter="",
-         year_filter="",
+         reports_filter=state.get("reports_filter", ""),
+         sources_filter=state.get("sources_filter", ""),
+         subtype_filter=state.get("subtype_filter", ""),
+         year_filter=state.get("year_filter", ""),
          api_name="/retrieve"
      )
      return {"context": context}
@@ -49,13 +43,10 @@ def generate_node(state: GraphState) -> GraphState:
  workflow = StateGraph(GraphState)

  # Add nodes
- # workflow.add_node("chatui", chatui_node)
  workflow.add_node("retrieve", retrieve_node)
  workflow.add_node("generate", generate_node)

  # Add edges
- # workflow.add_edge(START, "chatui")
- # workflow.add_edge("chatui", "retrieve")
  workflow.add_edge(START, "retrieve")
  workflow.add_edge("retrieve", "generate")
  workflow.add_edge("generate", END)
@@ -63,35 +54,93 @@ workflow.add_edge("generate", END)
  # Compile the graph
  graph = workflow.compile()

- def pipeline(query: str):
+ # Single tool for processing queries
+ def process_query(
+     query: str,
+     reports_filter: str = "",
+     sources_filter: str = "",
+     subtype_filter: str = "",
+     year_filter: str = ""
+ ) -> str:
      """
      Execute the ChatFed orchestration pipeline to process a user query.

      This function orchestrates a two-step workflow:
-     1. Retrieve relevant context using the ChatFed retriever service
+     1. Retrieve relevant context using the ChatFed retriever service with optional filters
      2. Generate a response using the ChatFed generator service with the retrieved context

      Args:
          query (str): The user's input query/question to be processed
+         reports_filter (str, optional): Filter for specific report types. Defaults to "".
+         sources_filter (str, optional): Filter for specific data sources. Defaults to "".
+         subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
+         year_filter (str, optional): Filter for specific years. Defaults to "".

      Returns:
          str: The generated response from the ChatFed generator service
      """
-     # run the graph with the initial state
-     initial_state = {"query": query, "context": "", "result": ""}
+     initial_state = {
+         "query": query,
+         "context": "",
+         "result": "",
+         "reports_filter": reports_filter or "",
+         "sources_filter": sources_filter or "",
+         "subtype_filter": subtype_filter or "",
+         "year_filter": year_filter or ""
+     }
      final_state = graph.invoke(initial_state)
      return final_state["result"]

+ # Simple testing interface
  ui = gr.Interface(
-     fn=pipeline,
+     fn=process_query,
      inputs=gr.Textbox(lines=2, placeholder="Enter query here"),
      outputs="text",
-     title="ChatFed Orchestrator",
      flagging_mode="never"
  )

+ # Guidance for ChatUI - can be removed later. Questionable whether front end even necessary. Maybe nice to show the graph.
+ with gr.Blocks(title="ChatFed Orchestrator") as demo:
+     gr.Markdown("# ChatFed Orchestrator")
+     gr.Markdown("This LangGraph server exposes MCP endpoints for the ChatUI module to call.")
+     gr.Markdown("**Available MCP Tools:**")
+     gr.Markdown("- `process_query`: accepts query with optional filters")
+
+     with gr.Accordion("MCP Endpoint Information", open=True):
+         gr.Markdown(f"""
+         **MCP Server Endpoint:** https://giz-chatfed-orchestrator.hf.space/gradio_api/mcp/sse
+
+         **For ChatUI Integration:**
+         ```python
+         from gradio_client import Client
+
+         # Connect to orchestrator
+         orchestrator_client = Client("https://giz-chatfed-orchestrator.hf.space")
+
+         # Basic usage (no filters)
+         response = orchestrator_client.predict(
+             query="query",
+             api_name="/process_query"
+         )
+
+         # Advanced usage with any combination of filters
+         response = orchestrator_client.predict(
+             query="query",
+             reports_filter="annual_reports",
+             sources_filter="internal",
+             year_filter="2024",
+             api_name="/process_query"
+         )
+         ```
+         """)
+
+     with gr.Accordion("Quick Testing Interface", open=True):
+         ui.render()
+
  if __name__ == "__main__":
-     ui.launch(server_name="0.0.0.0",
-               server_port=7860,
-               mcp_server=True,
-               show_error=True)
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         mcp_server=True,
+         show_error=True
+     )
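Since the commit message frames this as future proofing for chatUI, it is worth noting how the deleted stub could come back once a front end exists. The sketch below is hypothetical: it reuses the `giz/chatfed_chatui` Space name and `/chat` endpoint from the removed comment block (neither is a confirmed API), fixes the stub's undefined-variable bug, and assumes it would live in app/main.py alongside the existing `GraphState`, `Client`, and `workflow`.

```python
# Hypothetical ChatUI node, reconstructed from the stub this commit removes.
# "giz/chatfed_chatui" and api_name="/chat" are assumptions, not a confirmed API.
def chatui_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_chatui")
    # The stub intended to return the (possibly rewritten) query from the UI service
    query = client.predict(
        query=state["query"],
        api_name="/chat"
    )
    return {"query": query}

# Wiring it in would restore the node and edges this commit deletes:
# workflow.add_node("chatui", chatui_node)
# workflow.add_edge(START, "chatui")
# workflow.add_edge("chatui", "retrieve")
```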