ppsingh committed
Commit fcc054f · verified · 1 parent: 1b0f717

Create app.py

Files changed (1): app.py (+231, -0)
app.py ADDED
@@ -0,0 +1,231 @@
import configparser
import io
import logging
import os
from typing import TypedDict

import gradio as gr
from dotenv import load_dotenv
from gradio_client import Client
from langgraph.graph import StateGraph, START, END
from PIL import Image

# OPEN QUESTION: should we pass all params from the orchestrator to the nodes
# instead of setting them in each module?

# Load the local .env file before reading any secrets from the environment.
load_dotenv()

HF_TOKEN = os.environ.get("HF_TOKEN")


def getconfig(configfile_path: str):
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of the .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except FileNotFoundError:
        logging.warning(f"Config file not found: {configfile_path}")
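
# Usage sketch for getconfig (the "params.cfg" path and the section/key names
# are hypothetical; no config file ships with this module):
#
#   config = getconfig("params.cfg")
#   if config is not None:
#       retriever_repo = config.get("retriever", "repo_id")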


def get_auth(provider: str) -> dict:
    """Get the authentication configuration for a given provider."""
    auth_configs = {
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
    }

    provider = provider.lower()  # Normalize to lowercase

    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")

    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")

    if not api_key:
        logging.warning(
            f"No API key found for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )
        auth_config["api_key"] = None  # Normalize empty strings to None

    return auth_config
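
# Usage sketch (get_auth is not called elsewhere in this file; shown only to
# illustrate the contract, and assumes QDRANT_API_KEY is set in the env):
#
#   qdrant_auth = get_auth("qdrant")
#   api_key = qdrant_auth["api_key"]  # None when the env var is missing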


# Define the state schema shared by every node in the graph.
class GraphState(TypedDict):
    query: str
    context: str
    result: str
    # Orchestrator-level filter parameters (see the open question above): kept
    # in the state so the orchestrator passes them down to the nodes.
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
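
# Note: LangGraph nodes return a partial state update (a dict containing only
# the keys they change), which the framework merges into the full state. That
# is why the nodes below return {"context": ...} and {"result": ...} rather
# than a complete GraphState.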

# Node 1: retriever
def retrieve_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_retriever", hf_token=HF_TOKEN)  # HF repo name
    context = client.predict(
        query=state["query"],
        reports_filter=state.get("reports_filter", ""),
        sources_filter=state.get("sources_filter", ""),
        subtype_filter=state.get("subtype_filter", ""),
        year_filter=state.get("year_filter", ""),
        api_name="/retrieve",
    )
    return {"context": context}


# Node 2: generator
def generate_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_generator", hf_token=HF_TOKEN)
    result = client.predict(
        query=state["query"],
        context=state["context"],
        api_name="/generate",
    )
    return {"result": result}


# Build the graph
workflow = StateGraph(GraphState)

# Add nodes
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)

# Add edges: START -> retrieve -> generate -> END
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Compile the graph
graph = workflow.compile()
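
# Debugging sketch (not used by the app): the compiled graph can be streamed
# node by node to inspect intermediate state; the query value is hypothetical.
#
#   for update in graph.stream({"query": "example question"}, stream_mode="updates"):
#       print(update)  # {"retrieve": {"context": ...}}, then {"generate": {"result": ...}}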


# Single tool for processing queries
def process_query(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
) -> str:
    """
    Execute the ChatFed orchestration pipeline to process a user query.

    This function orchestrates a two-step workflow:
    1. Retrieve relevant context using the ChatFed retriever service with optional filters
    2. Generate a response using the ChatFed generator service with the retrieved context

    Args:
        query (str): The user's input query/question to be processed
        reports_filter (str, optional): Filter for specific report types. Defaults to "".
        sources_filter (str, optional): Filter for specific data sources. Defaults to "".
        subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
        year_filter (str, optional): Filter for specific years. Defaults to "".

    Returns:
        str: The generated response from the ChatFed generator service
    """
    initial_state = {
        "query": query,
        "context": "",
        "result": "",
        "reports_filter": reports_filter or "",
        "sources_filter": sources_filter or "",
        "subtype_filter": subtype_filter or "",
        "year_filter": year_filter or "",
    }
    final_state = graph.invoke(initial_state)
    return final_state["result"]
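
# Example call (the query and filter values are hypothetical):
#
#   answer = process_query("What were the key findings?", year_filter="2024")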


# Simple testing interface
ui = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(lines=2, placeholder="Enter query here"),
    outputs="text",
    flagging_mode="never",
)


# Generate the graph visualization for display in the UI.
def get_graph_visualization():
    """Generate and return the LangGraph workflow visualization as a PIL Image."""
    # Render the graph to PNG bytes via its Mermaid representation.
    graph_png_bytes = graph.get_graph().draw_mermaid_png()

    # Convert the PNG bytes to a PIL Image for Gradio display.
    graph_image = Image.open(io.BytesIO(graph_png_bytes))
    return graph_image


# Guidance UI for ChatUI integration (can be removed later). A front end is
# arguably unnecessary for an MCP server, but it is useful for showing the graph.
with gr.Blocks(title="ChatFed Orchestrator") as demo:
    gr.Markdown("# ChatFed Orchestrator")
    gr.Markdown("This LangGraph server exposes MCP endpoints that the ChatUI module calls to trigger the graph.")

    with gr.Row():
        # Left column: graph visualization
        with gr.Column(scale=1):
            gr.Markdown("**Workflow Visualization**")
            graph_display = gr.Image(
                value=get_graph_visualization(),
                label="LangGraph Workflow",
                interactive=False,
                height=300,
            )

            # Refresh button for the graph
            refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
            refresh_graph_btn.click(
                fn=get_graph_visualization,
                outputs=graph_display,
            )

        # Right column: interface and documentation
        with gr.Column(scale=2):
            gr.Markdown("**Available MCP Tools:**")

            with gr.Accordion("MCP Endpoint Information", open=True):
                gr.Markdown("""
                **MCP Server Endpoint:** https://giz-chatfed-orchestrator.hf.space/gradio_api/mcp/sse

                **For ChatUI Integration:**
                ```python
                from gradio_client import Client

                # Connect to the orchestrator
                orchestrator_client = Client("https://giz-chatfed-orchestrator.hf.space")

                # Basic usage (no filters)
                response = orchestrator_client.predict(
                    query="query",
                    api_name="/process_query"
                )

                # Advanced usage with any combination of filters
                response = orchestrator_client.predict(
                    query="query",
                    reports_filter="annual_reports",
                    sources_filter="internal",
                    year_filter="2024",
                    api_name="/process_query"
                )
                ```
                """)

            with gr.Accordion("Quick Testing Interface", open=True):
                ui.render()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        show_error=True,
    )