# mcp_groq_gradio.py
"""Gradio chat UI that lets a Groq model call tools served over MCP STDIO.

The MCP server is launched as ``uv run server.py``. Its tools are offered to
the selected Groq model, any tool calls the model makes are executed on the
MCP server, and the answer is streamed back into a Gradio chatbot.
"""
import asyncio
import json
import os
import traceback
from contextlib import AsyncExitStack
from typing import Optional

import gradio as gr
import httpx
from dotenv import load_dotenv
from groq import Groq

# MCP imports
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Load environment
load_dotenv()

# Completion budget shared by the initial request and the post-tool follow-up.
MAX_TOKENS = 5500

# Allow-list of model ids offered in the dropdown; intersected with the models
# Groq reports as active when the user connects.
COMPATIBLE_MODELS = frozenset(
    {
        "qwen-qwq-32b",
        "qwen-2.5-coder-32b",
        "qwen-2.5-32b",
        "deepseek-r1-distill-qwen-32b",
        "deepseek-r1-distill-llama-70b",
        "llama-3.3-70b-versatile",
        "llama-3.1-8b-instant",
        "mixtral-8x7b-32768",
        "gemma2-9b-it",
    }
)


def _ssl_verify_enabled() -> bool:
    """Return whether TLS certificates are verified for Groq API calls.

    Verification is on by default. Set ``GROQ_SSL_VERIFY`` to ``0``/``false``/
    ``no``/``off`` only when a TLS-intercepting proxy makes verification
    impossible -- doing so exposes the API key to man-in-the-middle attacks.
    """
    value = os.getenv("GROQ_SSL_VERIFY", "true").strip().lower()
    return value not in {"0", "false", "no", "off"}


async def _iterate_in_thread(sync_stream):
    """Yield chunks from a blocking Groq stream without blocking the event loop.

    Each ``next()`` call runs in a worker thread, so other Gradio events keep
    being served while this one waits on the network. The stream is closed in
    a ``finally`` when iteration ends, fails, or the generator is closed.
    """
    done = object()  # sentinel: StopIteration cannot cross a thread future
    iterator = iter(sync_stream)
    try:
        while True:
            chunk = await asyncio.to_thread(next, iterator, done)
            if chunk is done:
                return
            yield chunk
    finally:
        sync_stream.close()


def _merge_tool_call_deltas(pending, deltas):
    """Fold streamed tool-call fragments into complete calls keyed by index.

    A call may arrive whole in one chunk or with its ``arguments`` split over
    several chunks; either way it ends up as one entry shaped like an
    OpenAI-style ``tool_calls`` item (``id``, ``type``, ``function``).
    """
    for position, delta in enumerate(deltas):
        index = getattr(delta, "index", None)
        if index is None:
            index = position
        call = pending.setdefault(
            index,
            {"id": "", "type": "function", "function": {"name": "", "arguments": ""}},
        )
        call_id = getattr(delta, "id", None)
        if call_id:
            call["id"] = call_id
        function = getattr(delta, "function", None)
        if function is not None:
            if function.name:
                call["function"]["name"] = function.name
            if function.arguments:
                call["function"]["arguments"] += function.arguments


def _format_tool_result(content) -> str:
    """Flatten MCP tool-result content items into text for a ``tool`` message.

    Items exposing a string ``text`` attribute contribute that text; anything
    else falls back to ``str(item)``. Items are joined with newlines.
    """
    parts = []
    for item in content or []:
        text = getattr(item, "text", None)
        parts.append(text if isinstance(text, str) else str(item))
    return "\n".join(parts)


def _failed_generation(error):
    """Return the ``failed_generation`` payload from a Groq API error, or None.

    Tolerates a missing/non-dict ``body`` and an ``{"error": {...}}`` wrapper.
    """
    body = getattr(error, "body", None)
    if isinstance(body, dict) and isinstance(body.get("error"), dict):
        body = body["error"]
    if isinstance(body, dict) and "failed_generation" in body:
        return body["failed_generation"]
    return None


class MCPGroqClient:
    """Unified client handling MCP server and Groq integration."""

    def __init__(self):
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        # Default model for requests that do not pass one explicitly.
        self.current_model = ""
        # Set by the UI once the user supplies an API key.
        self.groq: Optional[Groq] = None

    async def connect(self):
        """Launch the MCP server over STDIO and initialize a client session.

        No-op if a session already exists. On failure every context entered so
        far is closed and ``session`` stays ``None``, so a later call retries
        from scratch instead of reusing a half-initialized session.
        """
        if self.session is not None:
            return
        server_params = StdioServerParameters(command="uv", args=["run", "server.py"])
        try:
            stdio_reader, stdio_writer = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            session = await self.exit_stack.enter_async_context(
                ClientSession(stdio_reader, stdio_writer)
            )
            await session.initialize()
        except Exception:
            await self.exit_stack.aclose()
            self.exit_stack = AsyncExitStack()
            raise
        self.session = session

    async def stream_response(self, query: str, model: Optional[str] = None):
        """Stream a model answer to ``query``, running any MCP tools it calls.

        Yields text fragments: first the model's direct output, then -- if the
        model requested tools -- the follow-up answer built from the tool
        results. Errors are yielded as a "⚠️" message instead of raised.

        Args:
            query: The user's message.
            model: Groq model id; defaults to ``self.current_model``.
        """
        model = model or self.current_model
        messages = [{"role": "user", "content": query}]
        stream = None
        try:
            tools = await self._get_mcp_tools()
            request = {
                "model": model,
                "max_tokens": MAX_TOKENS,
                "messages": messages,
                "stream": True,
            }
            if tools:  # only offer tools when the MCP server exposes some
                request["tools"] = tools
            stream = await asyncio.to_thread(self.groq.chat.completions.create, **request)

            pending_calls = {}
            text_parts = []
            async for chunk in _iterate_in_thread(stream):
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    text_parts.append(delta.content)
                    yield delta.content
                if delta.tool_calls:
                    # Collect fragments; calls are executed once the stream ends.
                    _merge_tool_call_deltas(pending_calls, delta.tool_calls)

            if pending_calls:
                tool_calls = [pending_calls[i] for i in sorted(pending_calls)]
                await self._process_tool_calls(tool_calls, messages, "".join(text_parts))
                async for tool_chunk in self._stream_tool_response(messages, model):
                    yield tool_chunk
        except Exception as e:
            # Handle Groq-specific errors
            traceback.print_exception(e)
            failed_generation = _failed_generation(e)
            if failed_generation is not None:
                yield f"\n⚠️ Error: Failed to call a function. Invalid generation:\n{failed_generation} {messages[-1]}"
            else:
                yield f"\n⚠️ Critical Error: {str(e)}"
        finally:
            if stream is not None:
                stream.close()

    async def _get_mcp_tools(self):
        """Describe the MCP server's tools in Groq's function-tool format."""
        response = await self.session.list_tools()
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": tool.inputSchema,
                },
            }
            for tool in response.tools
        ]

    async def _process_tool_calls(self, tool_calls, messages, assistant_content=""):
        """Run each requested tool on the MCP server; append results to ``messages``.

        The assistant turn that issued the calls is appended first, because
        ``tool`` messages must answer a preceding assistant message carrying
        the matching ``tool_calls``.

        Args:
            tool_calls: Complete calls as built by ``_merge_tool_call_deltas``.
            messages: Conversation list, mutated in place.
            assistant_content: Any text the model streamed alongside the calls.
        """
        assistant_message = {"role": "assistant", "tool_calls": tool_calls}
        if assistant_content:
            assistant_message["content"] = assistant_content
        messages.append(assistant_message)

        for call in tool_calls:
            name = call["function"]["name"]
            # Parameterless calls may arrive with an empty argument string.
            arguments = json.loads(call["function"]["arguments"] or "{}")
            result = await self.session.call_tool(name, arguments)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": call["id"],
                    "name": name,
                    "content": _format_tool_result(result.content),
                }
            )

    async def _stream_tool_response(self, messages, model: Optional[str] = None):
        """Stream the model's follow-up answer once tool results are in ``messages``.

        This request does not pass ``tools``, so at most one round of tool
        calls is handled per query.
        """
        stream = await asyncio.to_thread(
            self.groq.chat.completions.create,
            model=model or self.current_model,
            max_tokens=MAX_TOKENS,
            messages=messages,
            stream=True,
        )
        async for chunk in _iterate_in_thread(stream):
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content


def create_interface():
    """Build the Gradio app: API-key connect panel, chat area, and model picker."""
    # The Groq client is only created after the user supplies an API key.
    client = MCPGroqClient()

    with gr.Blocks(theme=gr.themes.Soft(), title="MCP-Groq Client") as interface:
        gr.Markdown("## MCP STDIO/Groq Chat Interface")

        # Connection Section
        with gr.Row():
            api_key_input = gr.Textbox(
                label="Groq API Key",
                placeholder="gsk_...",
                type="password",
                interactive=True,
            )
            connect_btn = gr.Button("Connect", variant="primary")
            connection_status = gr.Textbox(
                label="Status", interactive=False, value="Disconnected"
            )

        # Main Chat Interface (initially hidden)
        with gr.Row(visible=False) as chat_row:
            # Gradio column scales are integers; 3:2 keeps a 60/40 split.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600)
                input_box = gr.Textbox(placeholder="Type message...")
                submit_btn = gr.Button("Send", variant="primary")
            with gr.Column(scale=2):
                model_selector = gr.Dropdown(
                    label="Available Models", interactive=True, visible=False
                )

        # Connect Button Logic
        def connect_client(api_key):
            """Create the Groq client and reveal the chat UI on success."""
            try:
                # A blank key becomes None so the Groq SDK can fall back to the
                # GROQ_API_KEY environment variable (e.g. loaded from .env).
                groq_client = Groq(
                    api_key=(api_key or "").strip() or None,
                    http_client=httpx.Client(verify=_ssl_verify_enabled()),
                )
                models = groq_client.models.list().data
                active_models = sorted(
                    m.id for m in models if m.active and m.id in COMPATIBLE_MODELS
                )
            except Exception as e:
                # Never keep a client whose key/connection just failed.
                client.groq = None
                return {
                    connection_status: f"Connection failed: {str(e)}",
                    chat_row: gr.update(visible=False),
                }

            client.groq = groq_client
            return {
                connection_status: "Connected ✅",
                chat_row: gr.update(visible=True),
                model_selector: gr.update(
                    choices=active_models,
                    value=active_models[0] if active_models else "",
                    visible=True,
                ),
                connect_btn: gr.update(visible=False),
                api_key_input: gr.update(interactive=False),
            }

        connect_btn.click(
            connect_client,
            inputs=api_key_input,
            outputs=[connection_status, chat_row, model_selector, connect_btn, api_key_input],
        )

        # Chat Handling
        async def chat_stream(query, history, selected_model):
            """Stream the assistant reply for ``query`` into the chat history."""
            history = history or []
            if not (query or "").strip():
                # Nothing to send; leave the conversation unchanged.
                yield "", history
                return

            client.current_model = selected_model
            # Launch the MCP server lazily on the first message.
            if client.session is None:
                try:
                    await client.connect()
                except Exception as e:
                    traceback.print_exception(e)
                    yield "", history + [(query, f"\n⚠️ Critical Error: {str(e)}")]
                    return

            accumulated_response = ""
            # Pass the model explicitly: current_model is shared across events.
            async for chunk in client.stream_response(query, selected_model):
                accumulated_response += chunk
                yield "", history + [(query, accumulated_response)]
            yield "", history + [(query, accumulated_response)]

        submit_btn.click(
            chat_stream,
            [input_box, chatbot, model_selector],
            [input_box, chatbot],
            show_progress="hidden",
        )
        input_box.submit(
            chat_stream,
            [input_box, chatbot, model_selector],
            [input_box, chatbot],
            show_progress="hidden",
        )

    return interface


if __name__ == "__main__":
    interface = create_interface()
    # server_name="0.0.0.0" listens on all interfaces, not just localhost.
    interface.queue().launch(
        server_port=7860,
        server_name="0.0.0.0",
        show_error=True,
    )