MCPGroqClient / app.py
Omar ID EL MOUMEN
Big fix
6b639db
raw
history blame
8.63 kB
# mcp_groq_gradio.py
import gradio as gr
import json
import traceback
import httpx
from contextlib import AsyncExitStack
from typing import Optional
from dotenv import load_dotenv
from groq import Groq
# MCP imports
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
# Load environment
load_dotenv()
class MCPGroqClient:
    """Unified client handling the MCP server connection and Groq chat calls.

    The MCP server is reached over STDIO (see ``connect``); its tools are
    advertised to the Groq model, tool calls requested by the model are
    executed through the MCP session, and the answer is streamed back as
    plain text pieces.

    Attributes are set by the caller before streaming:
        groq: a synchronous Groq client (``None`` until one is assigned).
        current_model: model id passed to every completion request.
    """

    def __init__(self):
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        self.current_model = ""
        self.groq = None

    async def connect(self):
        """Establish the MCP STDIO connection by running ``uv run server.py``.

        ``self.session`` is only published after ``initialize()`` succeeds, so
        a failed handshake does not leave a half-initialised session behind
        that callers checking ``if not client.session`` would mistake for a
        working one.
        """
        server_params = StdioServerParameters(
            command="uv",
            args=["run", "server.py"],
        )
        stdio_reader, stdio_writer = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        session = await self.exit_stack.enter_async_context(
            ClientSession(stdio_reader, stdio_writer)
        )
        await session.initialize()
        self.session = session

    async def stream_response(self, query: str):
        """Stream the model's answer to ``query`` as text pieces.

        Yields every non-empty ``delta.content`` of the completion stream.
        When a chunk carries tool calls they are executed through the MCP
        session and the follow-up completion is streamed as well.

        Errors are not raised to the caller: they are printed with a
        traceback and reported as a final yielded message (including the
        ``failed_generation`` payload when the error body provides one).
        """
        messages = [{"role": "user", "content": query}]
        sync_stream = None
        try:
            tools = await self._get_mcp_tools()
            # NOTE: ``self.groq`` is used as a synchronous client, so iterating
            # the stream below waits for each chunk without yielding control.
            sync_stream = self.groq.chat.completions.create(
                model=self.current_model,
                max_tokens=5500,
                messages=messages,
                tools=tools,
                stream=True,
            )
            for chunk in sync_stream:
                if not chunk.choices:
                    # Nothing to render in a chunk without choices.
                    continue
                delta = chunk.choices[0].delta
                if content := delta.content:
                    yield content
                if tool_calls := delta.tool_calls:
                    await self._process_tool_calls(tool_calls, messages)
                    async for tool_chunk in self._stream_tool_response(messages):
                        yield tool_chunk
        except Exception as e:
            # Top-level boundary: report the problem in the chat instead of
            # breaking the UI stream.
            traceback.print_exc()
            failed_generation = self._failed_generation(e)
            if failed_generation is not None:
                yield f"\n⚠️ Error: Failed to call a function. Invalid generation:\n{failed_generation}"
            else:
                yield f"\n⚠️ Critical Error: {str(e)}"
        finally:
            if sync_stream is not None:
                sync_stream.close()

    @staticmethod
    def _failed_generation(error):
        """Return the ``failed_generation`` payload of ``error`` or ``None``.

        The error ``body`` may be missing, ``None`` or not a mapping; the
        payload is looked up at the top level and under an ``"error"`` key.
        """
        body = getattr(error, "body", None)
        if not isinstance(body, dict):
            return None
        if "failed_generation" in body:
            return body["failed_generation"]
        nested = body.get("error")
        if isinstance(nested, dict):
            return nested.get("failed_generation")
        return None

    @staticmethod
    def _tool_result_text(result):
        """Flatten an MCP tool result into the text sent back to the model.

        Items exposing a string ``text`` attribute contribute that text;
        anything else falls back to ``str(item)``.
        """
        content = result.content
        if isinstance(content, (list, tuple)):
            return "\n".join(
                item.text if isinstance(getattr(item, "text", None), str) else str(item)
                for item in content
            )
        return str(content)

    async def _get_mcp_tools(self):
        """Return the MCP server's tools in the chat-completions tool format."""
        response = await self.session.list_tools()
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.inputSchema,
                },
            }
            for tool in response.tools
        ]

    async def _process_tool_calls(self, tool_calls, messages):
        """Run ``tool_calls`` via MCP and append the exchange to ``messages``.

        ``messages`` is mutated in place: first the assistant message that
        requested the calls, then one ``tool`` message per call.
        """
        # Assumes each streamed delta carries complete tool calls (id, name and
        # full argument string) -- TODO confirm against the provider's
        # streaming behaviour.
        messages.append({
            "role": "assistant",
            "tool_calls": [
                {
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.function.name,
                        "arguments": call.function.arguments or "{}",
                    },
                }
                for call in tool_calls
            ],
        })
        for call in tool_calls:
            func = call.function
            # A call without parameters may arrive with empty arguments.
            arguments = json.loads(func.arguments) if func.arguments else {}
            result = await self.session.call_tool(func.name, arguments)
            messages.append({
                "role": "tool",
                "content": self._tool_result_text(result),
                "tool_call_id": call.id,
            })

    async def _stream_tool_response(self, messages):
        """Stream the model's follow-up answer once tool results are in ``messages``."""
        sync_stream = self.groq.chat.completions.create(
            model=self.current_model,
            max_tokens=5500,
            messages=messages,
            stream=True,
        )
        try:
            for chunk in sync_stream:
                if not chunk.choices:
                    continue
                if content := chunk.choices[0].delta.content:
                    yield content
        finally:
            # Also runs when the consumer stops early or an error interrupts.
            sync_stream.close()
def create_interface():
    """Build and return the Gradio Blocks UI for the MCP/Groq chat.

    The chat area stays hidden until the user connects with a Groq API key;
    connecting lists the account's active models and offers those that are
    in the known-compatible set. Sending a message streams the reply into
    the chatbot.

    Returns:
        The ``gr.Blocks`` interface (not yet launched).
    """
    # NOTE(review): this single client (API key, selected model, MCP session)
    # is captured by the handlers below, so it appears to be shared by every
    # browser session of the app -- confirm whether per-session state is needed.
    client = MCPGroqClient()

    # Models offered in the dropdown when the account reports them as active.
    compatible_models = {
        "qwen-qwq-32b",
        "qwen-2.5-coder-32b",
        "qwen-2.5-32b",
        "deepseek-r1-distill-qwen-32b",
        "deepseek-r1-distill-llama-70b",
        "llama-3.3-70b-versatile",
        "llama-3.1-8b-instant",
        "mixtral-8x7b-32768",
        "gemma2-9b-it",
    }

    with gr.Blocks(theme=gr.themes.Soft(), title="MCP-Groq Client") as interface:
        gr.Markdown("## MCP STDIO/Groq Chat Interface")

        # Connection section
        with gr.Row():
            api_key_input = gr.Textbox(
                label="Groq API Key",
                placeholder="gsk_...",
                type="password",
                interactive=True,
            )
            connect_btn = gr.Button("Connect", variant="primary")
            connection_status = gr.Textbox(
                label="Status",
                interactive=False,
                value="Disconnected",
            )

        # Main chat interface (initially hidden)
        with gr.Row(visible=False) as chat_row:
            # Integer scales keep the original 60/40 split.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600)
                input_box = gr.Textbox(placeholder="Type message...")
                submit_btn = gr.Button("Send", variant="primary")
            with gr.Column(scale=2):
                model_selector = gr.Dropdown(
                    label="Available Models",
                    interactive=True,
                    visible=False,
                )

        def connect_client(api_key):
            """Validate the API key by listing models and reveal the chat UI."""
            try:
                # TLS certificate verification is left at the client default:
                # the API key travels over this connection.
                groq_client = Groq(api_key=(api_key or "").strip())
                models = groq_client.models.list().data
                active_models = sorted(
                    m.id for m in models if m.active and m.id in compatible_models
                )
                # Only keep the client once the key has been proven to work.
                client.groq = groq_client
                return {
                    connection_status: "Connected ✅",
                    chat_row: gr.update(visible=True),
                    model_selector: gr.update(
                        choices=active_models,
                        value=active_models[0] if active_models else None,
                        visible=True,
                    ),
                    connect_btn: gr.update(visible=False),
                    api_key_input: gr.update(interactive=False),
                }
            except Exception as e:
                # UI boundary: surface the failure in the status box.
                return {
                    connection_status: f"Connection failed: {str(e)}",
                    chat_row: gr.update(visible=False),
                }

        connect_btn.click(
            connect_client,
            inputs=api_key_input,
            outputs=[connection_status, chat_row, model_selector, connect_btn, api_key_input],
        )

        async def chat_stream(query, history, selected_model):
            """Stream the reply to ``query`` into the chatbot, clearing the input box."""
            history = history or []
            if not query or not query.strip():
                # Nothing to send: leave the conversation untouched.
                yield "", history
                return

            client.current_model = selected_model
            # The MCP session is opened lazily on the first message.
            if not client.session:
                await client.connect()

            accumulated_response = ""
            async for chunk in client.stream_response(query):
                accumulated_response += chunk
                yield "", history + [(query, accumulated_response)]
            # Final update also covers a stream that produced no chunks.
            yield "", history + [(query, accumulated_response)]

        # The Send button and pressing Enter trigger the same handler.
        for register in (submit_btn.click, input_box.submit):
            register(
                chat_stream,
                [input_box, chatbot, model_selector],
                [input_box, chatbot],
                show_progress="hidden",
            )

    return interface
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue and serve it
    # on all network interfaces at port 7860 with errors shown in the browser.
    launch_options = {
        "server_port": 7860,
        "server_name": "0.0.0.0",
        "show_error": True,
    }
    interface = create_interface()
    interface.queue().launch(**launch_options)