# mcp_groq_gradio.py
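# Gradio chat UI that bridges a local MCP (Model Context Protocol) server,
# launched over STDIO, to Groq chat completions with streaming tool calls.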
import asyncio
import gradio as gr
import os
import json
import httpx
from contextlib import AsyncExitStack
from typing import Optional
from dotenv import load_dotenv
from groq import Groq
# MCP imports
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
# Load environment
load_dotenv()
class MCPGroqClient:
"""Unified client handling MCP server and Groq integration"""
def __init__(self):
self.session: Optional[ClientSession] = None
self.current_model = ""
self.exit_stack = AsyncExitStack()
self.groq = None
self.available_tools = None
async def connect(self):
"""Establish MCP STDIO connection"""
server_params = StdioServerParameters(
command="uv",
args=["run","server.py"]
)
transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
stdio_reader, stdio_writer = transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(stdio_reader, stdio_writer)
)
await self.session.initialize()
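        # Cache the server's tool list so the UI can display it after connect.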
self.available_tools = await self.session.list_tools()
async def stream_response(self, query: str):
"""Handle streaming with failed_generation debugging"""
messages = [{"role": "user", "content": query}]
try:
tools = await self._get_mcp_tools()
# Get sync stream and create async wrapper
sync_stream = self.groq.chat.completions.create(
model=self.current_model,
max_tokens=5500,
messages=messages,
tools=tools,
stream=True
)
            # The Groq SDK returns a synchronous stream; wrap it so it can be
            # consumed with `async for`, yielding to the event loop between
            # chunks so other tasks stay responsive.
            async def async_wrapper():
                for chunk in sync_stream:
                    yield chunk
                    await asyncio.sleep(0)  # hand control back to the loop
                sync_stream.close()
full_response = ""
async for chunk in async_wrapper():
if content := chunk.choices[0].delta.content:
full_response += content
yield content
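                # Assumes Groq delivers each tool call complete in a single
                # chunk; argument fragments are not merged across chunks.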
if tool_calls := chunk.choices[0].delta.tool_calls:
await self._process_tool_calls(tool_calls, messages)
async for tool_chunk in self._stream_tool_response(messages):
full_response += tool_chunk
yield tool_chunk
        except Exception as e:
            # Groq attaches the model's invalid output as "failed_generation"
            # on the error body when a tool call fails to parse.
            if hasattr(e, "body") and e.body and "failed_generation" in e.body:
                failed_generation = e.body["failed_generation"]
                yield f"\n⚠️ Error: Failed to call a function. Invalid generation:\n{failed_generation}"
            else:
                yield f"\n⚠️ Critical Error: {str(e)}"
finally:
if 'sync_stream' in locals():
sync_stream.close()
async def _get_mcp_tools(self):
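        # Convert MCP tool metadata into the OpenAI-style function-calling
        # schema that Groq's chat completions API expects.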
response = await self.session.list_tools()
return [{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.inputSchema
}
} for tool in response.tools]
async def _get_available_models(self):
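        # Prefer the live list of active models; fall back to a static set
        # of known Groq model IDs if the API call fails.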
try:
models = self.groq.models.list()
return sorted([model.id for model in models.data if model.active])
except Exception as e:
print(e)
return sorted([
"llama-3.3-70b-versatile",
"llama-3.1-8b-instant",
"gemma2-9b-it"
])
    async def _process_tool_calls(self, tool_calls, messages):
        # The chat completions API expects the assistant message that
        # requested the tools to precede the "tool" result messages.
        messages.append({
            "role": "assistant",
            "tool_calls": [{
                "id": t.id,
                "type": "function",
                "function": {"name": t.function.name,
                             "arguments": t.function.arguments}
            } for t in tool_calls]
        })
        for tool in tool_calls:
            func = tool.function
            result = await self.session.call_tool(
                func.name,
                json.loads(func.arguments)
            )
            messages.append({
                "role": "tool",
                "content": str(result.content),
                "tool_call_id": tool.id
            })
async def _stream_tool_response(self, messages):
"""Async wrapper for tool response streaming"""
sync_stream = self.groq.chat.completions.create(
model=self.current_model,
max_tokens=5500,
messages=messages,
stream=True
)
        async def tool_async_wrapper():
            for chunk in sync_stream:
                yield chunk
                await asyncio.sleep(0)  # keep the event loop responsive
            sync_stream.close()
async for chunk in tool_async_wrapper():
if content := chunk.choices[0].delta.content:
yield content
def create_interface():
    # The Groq client is created later in connect_client(), once the user
    # supplies an API key.
    client = MCPGroqClient()
with gr.Blocks(theme=gr.themes.Soft(), title="MCP-Groq Client") as interface:
gr.Markdown("## MCP STDIO/Groq Chat Interface")
# Connection Section
with gr.Row():
api_key_input = gr.Textbox(
label="Groq API Key",
placeholder="gsk_...",
type="password",
interactive=True
)
connect_btn = gr.Button("Connect", variant="primary")
connection_status = gr.Textbox(
label="Status",
interactive=False,
value="Disconnected"
)
# Main Chat Interface (initially hidden)
with gr.Row(visible=False) as chat_row:
            with gr.Column(scale=3):  # scale takes integer ratios (3:2 ≈ 0.6:0.4)
chatbot = gr.Chatbot(height=600)
input_box = gr.Textbox(placeholder="Type message...", lines=1)
submit_btn = gr.Button("Send", variant="primary")
            with gr.Column(scale=2):
model_selector = gr.Dropdown(
label="Available Models",
interactive=True,
visible=False
)
available_tools = gr.Textbox(
label="Tools Available",
interactive=False,
visible=False
)
# Connect Button Logic
def connect_client(api_key):
try:
# Initialize Groq client with provided API key
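                # NOTE: verify=False disables TLS certificate verification;
                # convenient behind strict proxies, but unsafe for production.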
client.groq = Groq(api_key=api_key, http_client=httpx.Client(verify=False))
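                # Gradio runs this sync handler in a worker thread, so create
                # a dedicated event loop there to drive the async MCP handshake.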
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(client.connect())
models = client.groq.models.list().data
active_models = sorted([m.id for m in models if m.active])
tools_list = "\n".join([f"• {t.name}: {t.description}" for t in client.available_tools.tools])
return {
connection_status: "Connected ✅",
chat_row: gr.update(visible=True),
model_selector: gr.update(
choices=active_models,
value=active_models[0] if active_models else "",
visible=True
),
available_tools: gr.update(
value=tools_list,
visible=True
),
connect_btn: gr.update(visible=False),
api_key_input: gr.update(interactive=False)
}
except Exception as e:
return {
connection_status: f"Connection failed: {str(e)}",
chat_row: gr.update(visible=False)
}
connect_btn.click(
connect_client,
inputs=api_key_input,
outputs=[connection_status, chat_row, model_selector, available_tools, connect_btn, api_key_input]
)
# Chat Handling
async def chat_stream(query, history, selected_model):
print(f"Received query: {query}") # Debugging log
client.current_model = selected_model
# Initialize fresh client session if not connected
if not client.session:
await client.connect()
            accumulated_response = ""
            # Stream partial chunks into the chat window as they arrive.
            async for chunk in client.stream_response(query):
                accumulated_response += chunk
                yield "", history + [(query, accumulated_response)]
            # Final yield ensures the exchange is recorded even if the
            # stream produced no chunks.
            yield "", history + [(query, accumulated_response)]
submit_btn.click(
chat_stream,
[input_box, chatbot, model_selector],
[input_box, chatbot],
show_progress="hidden"
)
input_box.submit(
chat_stream,
[input_box, chatbot, model_selector],
[input_box, chatbot],
show_progress="hidden"
)
return interface
if __name__ == "__main__":
interface = create_interface()
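    # queue() enables streaming (generator) handlers in Gradio.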
interface.queue().launch(
server_port=7860,
server_name="0.0.0.0",
show_error=True
)