om4r932 commited on
Commit
c61d3cd
·
verified ·
1 Parent(s): 265ee2a

Deploy Gradio first time

Browse files
Files changed (3) hide show
  1. app.py +261 -0
  2. requirements.txt +6 -0
  3. server.py +118 -0
app.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mcp_groq_gradio.py
2
+ import asyncio
3
+ import gradio as gr
4
+ import os
5
+ import json
6
+ import httpx
7
+ from contextlib import AsyncExitStack
8
+ from typing import Optional
9
+ from dotenv import load_dotenv
10
+ from groq import Groq
11
+
12
+ # MCP imports
13
+ from mcp import ClientSession, StdioServerParameters
14
+ from mcp.client.stdio import stdio_client
15
+
16
+ # Load environment
17
+ load_dotenv()
18
+
19
+ class MCPGroqClient:
20
+ """Unified client handling MCP server and Groq integration"""
21
+
22
+ def __init__(self):
23
+ self.session: Optional[ClientSession] = None
24
+ self.current_model = ""
25
+ self.exit_stack = AsyncExitStack()
26
+ self.groq = None
27
+ self.available_tools = None
28
+
29
+ async def connect(self):
30
+ """Establish MCP STDIO connection"""
31
+ server_params = StdioServerParameters(
32
+ command="uv",
33
+ args=["run","server.py"]
34
+ )
35
+
36
+ transport = await self.exit_stack.enter_async_context(
37
+ stdio_client(server_params)
38
+ )
39
+ stdio_reader, stdio_writer = transport
40
+ self.session = await self.exit_stack.enter_async_context(
41
+ ClientSession(stdio_reader, stdio_writer)
42
+ )
43
+ await self.session.initialize()
44
+
45
+ self.available_tools = await self.session.list_tools()
46
+
47
+ async def stream_response(self, query: str):
48
+ """Handle streaming with failed_generation debugging"""
49
+ messages = [{"role": "user", "content": query}]
50
+
51
+ try:
52
+ tools = await self._get_mcp_tools()
53
+
54
+ # Get sync stream and create async wrapper
55
+ sync_stream = self.groq.chat.completions.create(
56
+ model=self.current_model,
57
+ max_tokens=5500,
58
+ messages=messages,
59
+ tools=tools,
60
+ stream=True
61
+ )
62
+
63
+ async def async_wrapper():
64
+ for chunk in sync_stream:
65
+ yield chunk
66
+ sync_stream.close()
67
+
68
+ full_response = ""
69
+ async for chunk in async_wrapper():
70
+ if content := chunk.choices[0].delta.content:
71
+ full_response += content
72
+ yield content
73
+
74
+ if tool_calls := chunk.choices[0].delta.tool_calls:
75
+ await self._process_tool_calls(tool_calls, messages)
76
+ async for tool_chunk in self._stream_tool_response(messages):
77
+ full_response += tool_chunk
78
+ yield tool_chunk
79
+
80
+ except Exception as e:
81
+ # Handle Groq-specific errors
82
+ if hasattr(e, "body") and "failed_generation" in e.body:
83
+ failed_generation = e.response_body["failed_generation"]
84
+ yield f"\n⚠️ Error: Failed to call a function. Invalid generation:\n{failed_generation}"
85
+ else:
86
+ yield f"\n⚠️ Critical Error: {str(e)}"
87
+ finally:
88
+ if 'sync_stream' in locals():
89
+ sync_stream.close()
90
+
91
+ async def _get_mcp_tools(self):
92
+ response = await self.session.list_tools()
93
+ return [{
94
+ "type": "function",
95
+ "function": {
96
+ "name": tool.name,
97
+ "description": tool.description,
98
+ "parameters": tool.inputSchema
99
+ }
100
+ } for tool in response.tools]
101
+
102
+ async def _get_available_models(self):
103
+ try:
104
+ models = self.groq.models.list()
105
+ return sorted([model.id for model in models.data if model.active])
106
+ except Exception as e:
107
+ print(e)
108
+ return sorted([
109
+ "llama-3.3-70b-versatile",
110
+ "llama-3.1-8b-instant",
111
+ "gemma2-9b-it"
112
+ ])
113
+
114
+ async def _process_tool_calls(self, tool_calls, messages):
115
+ for tool in tool_calls:
116
+ func = tool.function
117
+ result = await self.session.call_tool(
118
+ func.name,
119
+ json.loads(func.arguments)
120
+ )
121
+ messages.append({
122
+ "role": "tool",
123
+ "content": str(result.content),
124
+ "tool_call_id": tool.id
125
+ })
126
+
127
+ async def _stream_tool_response(self, messages):
128
+ """Async wrapper for tool response streaming"""
129
+ sync_stream = self.groq.chat.completions.create(
130
+ model=self.current_model,
131
+ max_tokens=5500,
132
+ messages=messages,
133
+ stream=True
134
+ )
135
+
136
+ async def tool_async_wrapper():
137
+ for chunk in sync_stream:
138
+ yield chunk
139
+ sync_stream.close()
140
+
141
+ async for chunk in tool_async_wrapper():
142
+ if content := chunk.choices[0].delta.content:
143
+ yield content
144
+
145
+ def create_interface():
146
+ # Initialize client without API key
147
+ client = MCPGroqClient()
148
+ client.groq = None # Remove initial Groq client
149
+
150
+ with gr.Blocks(theme=gr.themes.Soft(), title="MCP-Groq Client") as interface:
151
+ gr.Markdown("## MCP STDIO/Groq Chat Interface")
152
+
153
+ # Connection Section
154
+ with gr.Row():
155
+ api_key_input = gr.Textbox(
156
+ label="Groq API Key",
157
+ placeholder="gsk_...",
158
+ type="password",
159
+ interactive=True
160
+ )
161
+ connect_btn = gr.Button("Connect", variant="primary")
162
+ connection_status = gr.Textbox(
163
+ label="Status",
164
+ interactive=False,
165
+ value="Disconnected"
166
+ )
167
+
168
+ # Main Chat Interface (initially hidden)
169
+ with gr.Row(visible=False) as chat_row:
170
+ with gr.Column(scale=0.6):
171
+ chatbot = gr.Chatbot(height=600)
172
+ input_box = gr.Textbox(placeholder="Type message...")
173
+ submit_btn = gr.Button("Send", variant="primary")
174
+
175
+ with gr.Column(scale=0.4):
176
+ model_selector = gr.Dropdown(
177
+ label="Available Models",
178
+ interactive=True,
179
+ visible=False
180
+ )
181
+
182
+ available_tools = gr.Textbox(
183
+ label="Tools Available",
184
+ interactive=False,
185
+ visible=False
186
+ )
187
+
188
+ # Connect Button Logic
189
+ def connect_client(api_key):
190
+ try:
191
+ # Initialize Groq client with provided API key
192
+ client.groq = Groq(api_key=api_key, http_client=httpx.Client(verify=False))
193
+
194
+ loop = asyncio.new_event_loop()
195
+ asyncio.set_event_loop(loop)
196
+ loop.run_until_complete(client.connect())
197
+
198
+ models = client.groq.models.list().data
199
+ active_models = sorted([m.id for m in models if m.active])
200
+ tools_list = "\n".join([f"• {t.name}: {t.description}" for t in client.available_tools.tools])
201
+
202
+ return {
203
+ connection_status: "Connected ✅",
204
+ chat_row: gr.update(visible=True),
205
+ model_selector: gr.update(
206
+ choices=active_models,
207
+ value=active_models[0] if active_models else "",
208
+ visible=True
209
+ ),
210
+ available_tools: gr.update(
211
+ value=tools_list,
212
+ visible=True
213
+ ),
214
+ connect_btn: gr.update(visible=False),
215
+ api_key_input: gr.update(interactive=False)
216
+ }
217
+ except Exception as e:
218
+ return {
219
+ connection_status: f"Connection failed: {str(e)}",
220
+ chat_row: gr.update(visible=False)
221
+ }
222
+
223
+ connect_btn.click(
224
+ connect_client,
225
+ inputs=api_key_input,
226
+ outputs=[connection_status, chat_row, model_selector, connect_btn, api_key_input, available_tools]
227
+ )
228
+
229
+ # Chat Handling
230
+ async def chat_stream(query, history, selected_model):
231
+ client.current_model = selected_model
232
+
233
+ # Initialize fresh client session
234
+ if not client.session:
235
+ await client.connect()
236
+
237
+ accumulated_response = ""
238
+ async for chunk in client.stream_response(query):
239
+ accumulated_response += chunk
240
+ yield "", history + [(query, accumulated_response)]
241
+
242
+ yield "", history + [(query, accumulated_response)]
243
+
244
+ submit_btn.click(
245
+ chat_stream,
246
+ [input_box, chatbot, model_selector],
247
+ [input_box, chatbot],
248
+ show_progress="hidden"
249
+ )
250
+ input_box.submit(
251
+ chat_stream,
252
+ [input_box, chatbot, model_selector],
253
+ [input_box, chatbot],
254
+ show_progress="hidden"
255
+ )
256
+
257
+ return interface
258
+
259
+ if __name__ == "__main__":
260
+ interface = create_interface()
261
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ python-dotenv
2
+ mcp[cli]
3
+ uv
4
+ httpx
5
+ groq
6
+ gradio
server.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Literal
2
+ import httpx
3
+ from mcp.server.fastmcp import FastMCP
4
+
5
+ # Initialize FastMCP server
6
+ mcp = FastMCP("arxiv-omar")
7
+
8
+ # Constants
9
+ CUSTOM_ARXIV_API_BASE = "https://om4r932-arxiv.hf.space"
10
+ DDG_API_BASE = "https://ychkhan-ptt-endpoints.hf.space"
11
+
12
+ # Helpers
13
+ async def make_request(url: str, data: dict = None) -> dict[str, Any] | None:
14
+ if data is None:
15
+ return None
16
+ headers = {
17
+ "Accept": "application/json"
18
+ }
19
+ async with httpx.AsyncClient(verify=False) as client:
20
+ try:
21
+ response = await client.post(url, headers=headers, json=data)
22
+ print(response)
23
+ response.raise_for_status()
24
+ return response.json()
25
+ except Exception as e:
26
+ return None
27
+
28
+ def format_search(pub_id: str, content: dict) -> str:
29
+ return f"""
30
+ arXiv publication ID : {pub_id}
31
+ Title : {content["title"]}
32
+ Authors : {content["authors"]}
33
+ Release Date : {content["date"]}
34
+ Abstract : {content["abstract"]}
35
+ PDF link : {content["pdf"]}
36
+ """
37
+
38
+ def format_extract(message: dict) -> str:
39
+ return f"""
40
+ Title of PDF : {message.get("title", "No title has been found")}
41
+ Text : {message.get("text", "No text !")}
42
+ """
43
+
44
+ def format_result_search(page: dict):
45
+ return f"""
46
+ Title : {page.get("title", "No titles found !")}
47
+ Little description : {page.get("body", "No description")}
48
+ PDF url : {page.get("url", None)}
49
+ """
50
+
51
+ # Tools
52
+ @mcp.tool()
53
+ async def get_publications(keyword: str, limit: int = 15) -> str:
54
+ """
55
+ Get arXiv publications based on keywords and limit of documents
56
+
57
+ Args:
58
+ keyword: Keywords separated by spaces
59
+ limit: Numbers of maximum publications returned (by default, 15)
60
+ """
61
+ url = f"{CUSTOM_ARXIV_API_BASE}/search"
62
+ data = await make_request(url, data={'keyword': keyword, 'limit': limit})
63
+ if data["error"]:
64
+ return data["message"]
65
+ if not data:
66
+ return "Unable to fetch publications"
67
+ if len(data["message"].keys()) == 0:
68
+ return "No publications found"
69
+
70
+ publications = [format_search(pub_id, content) for (pub_id, content) in data["message"].items()]
71
+ return "\n--\n".join(publications)
72
+
73
+ @mcp.tool()
74
+ async def web_search(query: str) -> str:
75
+ """
76
+ Search the Web (thanks to DuckDuckGo) for all PDF files based on the keywords
77
+
78
+ Args:
79
+ query: Keywords to search documents on the Web
80
+ """
81
+
82
+ url = f"{DDG_API_BASE}/search"
83
+ data = await make_request(url, data={"query": query})
84
+ if not data:
85
+ return "Unable to fetch results"
86
+ if len(data["results"]) == 0:
87
+ return "No results found"
88
+
89
+ results = [format_result_search(result) for result in data["results"]]
90
+ return "\n--\n".join(results)
91
+
92
+
93
+ @mcp.tool()
94
+ async def get_pdf_text(pdf_url: str, limit_page: int = -1) -> str:
95
+ """
96
+ Extract the text from the URL pointing to a PDF file
97
+
98
+ Args:
99
+ pdf_url: URL to a PDF document
100
+ limit_page: How many pages the user wants to extract the content (default: -1 for all pages)
101
+ """
102
+
103
+ url = f"{CUSTOM_ARXIV_API_BASE}/extract_pdf/url"
104
+ data = {"url": pdf_url}
105
+ if limit_page != -1:
106
+ data["page_num"] = limit_page
107
+ data = await make_request(url, data=data)
108
+ if data["error"]:
109
+ return data["message"]
110
+ if not data:
111
+ return "Unable to extract PDF text"
112
+ if len(data["message"].keys()) == 0:
113
+ return "No text can be extracted from this PDF"
114
+
115
+ return format_extract(data["message"])
116
+
117
+ if __name__ == "__main__":
118
+ mcp.run(transport="stdio")