import gradio as gr
import json
import traceback
import httpx
from contextlib import AsyncExitStack
from typing import Optional

from dotenv import load_dotenv
from groq import Groq

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

load_dotenv()
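# load_dotenv() reads a local .env file when present; presumably that is where
# GROQ_API_KEY or similar settings would live, though the UI below also takes
# the key directly (an assumption -- the original does not document .env).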


class MCPGroqClient:
    """Unified client handling MCP server and Groq integration."""

    def __init__(self):
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        self.current_model = ""
        self.groq = None  # set once the user supplies an API key in the UI

    async def connect(self):
        """Establish the MCP STDIO connection."""
        # Launches the MCP server as a subprocess; assumes a `server.py`
        # in the working directory, run through `uv`.
        server_params = StdioServerParameters(
            command="uv",
            args=["run", "server.py"],
        )

        transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        stdio_reader, stdio_writer = transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(stdio_reader, stdio_writer)
        )
        await self.session.initialize()
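
    # The exit stack is opened in connect() but never closed anywhere in this
    # file. A minimal cleanup sketch (an assumed addition, not in the original):
    async def disconnect(self):
        """Close the MCP session and transport opened by connect()."""
        await self.exit_stack.aclose()
        self.session = None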

    async def stream_response(self, query: str):
        """Stream a completion, surfacing Groq's failed_generation debug info."""
        messages = [{"role": "user", "content": query}]
        sync_stream = None

        try:
            tools = await self._get_mcp_tools()

            sync_stream = self.groq.chat.completions.create(
                model=self.current_model,
                max_tokens=5500,
                messages=messages,
                tools=tools,
                stream=True,
            )

            # The Groq SDK stream is synchronous; this wrapper adapts it to
            # `async for`, though each step still blocks the event loop
            # (see the non-blocking sketch after this method).
            async def async_wrapper():
                for chunk in sync_stream:
                    yield chunk

            full_response = ""
            async for chunk in async_wrapper():
                if content := chunk.choices[0].delta.content:
                    full_response += content
                    yield content

                if tool_calls := chunk.choices[0].delta.tool_calls:
                    await self._process_tool_calls(tool_calls, messages)
                    async for tool_chunk in self._stream_tool_response(messages):
                        full_response += tool_chunk
                        yield tool_chunk

        except Exception as e:
            traceback.print_exception(e)
            # Groq attaches the model's raw output to tool-call failures as
            # `failed_generation`; guard the lookup since `body` may be
            # missing or not a dict.
            body = getattr(e, "body", None)
            if isinstance(body, dict) and "failed_generation" in body:
                failed_generation = body["failed_generation"]
                yield f"\n⚠️ Error: Failed to call a function. Invalid generation:\n{failed_generation}"
            else:
                yield f"\n⚠️ Critical Error: {str(e)}"
        finally:
            # Close exactly once here; closing inside the wrapper as well
            # (as the original did) would double-close the stream.
            if sync_stream is not None:
                sync_stream.close()
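
    # Non-blocking alternative (a sketch, not in the original): pull chunks
    # from the synchronous stream in a worker thread so the event loop stays
    # responsive between tokens. Assumes Python 3.9+ for asyncio.to_thread.
    #
    #   import asyncio
    #
    #   async def async_wrapper():
    #       it = iter(sync_stream)
    #       while (chunk := await asyncio.to_thread(next, it, None)) is not None:
    #           yield chunk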

    async def _get_mcp_tools(self):
        """Convert the MCP tool listing into Groq's function-calling format."""
        response = await self.session.list_tools()
        return [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema,
            },
        } for tool in response.tools]
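
    # For reference, each returned entry follows the OpenAI-style tool schema
    # that Groq accepts. A hypothetical `add` tool would serialize as:
    #
    #   {"type": "function",
    #    "function": {"name": "add",
    #                 "description": "Add two integers.",
    #                 "parameters": {"type": "object",
    #                                "properties": {"a": {"type": "integer"},
    #                                               "b": {"type": "integer"}},
    #                                "required": ["a", "b"]}}}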

    async def _process_tool_calls(self, tool_calls, messages):
        # OpenAI-compatible endpoints expect the assistant message that
        # requested the calls to precede the tool results, so record it first.
        messages.append({
            "role": "assistant",
            "tool_calls": [
                {"id": t.id, "type": "function",
                 "function": {"name": t.function.name,
                              "arguments": t.function.arguments}}
                for t in tool_calls
            ],
        })
        for tool in tool_calls:
            func = tool.function
            result = await self.session.call_tool(
                func.name,
                json.loads(func.arguments),
            )
            messages.append({
                "role": "tool",
                "content": str(result.content),
                "tool_call_id": tool.id,
            })
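
    # Caveat: with stream=True, OpenAI-style APIs may deliver a tool call's
    # `arguments` in fragments spread across several chunks; this client
    # assumes each delta carries a complete call. A sketch of fragment
    # accumulation (hypothetical `pending` dict keyed by tool-call index):
    #
    #   for tc in delta.tool_calls:
    #       entry = pending.setdefault(tc.index, {"id": tc.id, "name": "", "args": ""})
    #       entry["name"] = tc.function.name or entry["name"]
    #       entry["args"] += tc.function.arguments or ""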

    async def _stream_tool_response(self, messages):
        """Stream the follow-up completion after tool results are appended."""
        sync_stream = self.groq.chat.completions.create(
            model=self.current_model,
            max_tokens=5500,
            messages=messages,
            stream=True,
        )

        async def tool_async_wrapper():
            # try/finally so the stream is closed even if the consumer
            # abandons iteration early.
            try:
                for chunk in sync_stream:
                    yield chunk
            finally:
                sync_stream.close()

        async for chunk in tool_async_wrapper():
            if content := chunk.choices[0].delta.content:
                yield content


def create_interface():
    client = MCPGroqClient()  # client.groq stays None until the user connects

    with gr.Blocks(theme=gr.themes.Soft(), title="MCP-Groq Client") as interface:
        gr.Markdown("## MCP STDIO/Groq Chat Interface")

        with gr.Row():
            api_key_input = gr.Textbox(
                label="Groq API Key",
                placeholder="gsk_...",
                type="password",
                interactive=True,
            )
            connect_btn = gr.Button("Connect", variant="primary")
            connection_status = gr.Textbox(
                label="Status",
                interactive=False,
                value="Disconnected",
            )

        with gr.Row(visible=False) as chat_row:
            # gr.Column expects an integer `scale`; 3:2 preserves the
            # original 0.6/0.4 proportions.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600)
                input_box = gr.Textbox(placeholder="Type message...")
                submit_btn = gr.Button("Send", variant="primary")

            with gr.Column(scale=2):
                model_selector = gr.Dropdown(
                    label="Available Models",
                    interactive=True,
                    visible=False,
                )

        def connect_client(api_key):
            try:
                # verify=False disables TLS certificate checks -- convenient
                # behind intercepting proxies, but unsafe outside development.
                client.groq = Groq(api_key=api_key, http_client=httpx.Client(verify=False))
                models = client.groq.models.list().data
                available_models = sorted(m.id for m in models if m.active)
                # Hand-picked allowlist; presumably the models expected to
                # handle tool calls.
                compatible_models = [
                    "qwen-qwq-32b",
                    "qwen-2.5-coder-32b",
                    "qwen-2.5-32b",
                    "deepseek-r1-distill-qwen-32b",
                    "deepseek-r1-distill-llama-70b",
                    "llama-3.3-70b-versatile",
                    "llama-3.1-8b-instant",
                    "mixtral-8x7b-32768",
                    "gemma2-9b-it",
                ]
                active_models = sorted(m for m in available_models if m in compatible_models)

                return {
                    connection_status: "Connected ✅",
                    chat_row: gr.update(visible=True),
                    model_selector: gr.update(
                        choices=active_models,
                        value=active_models[0] if active_models else "",
                        visible=True,
                    ),
                    connect_btn: gr.update(visible=False),
                    api_key_input: gr.update(interactive=False),
                }
            except Exception as e:
                return {
                    connection_status: f"Connection failed: {str(e)}",
                    chat_row: gr.update(visible=False),
                }

        connect_btn.click(
            connect_client,
            inputs=api_key_input,
            outputs=[connection_status, chat_row, model_selector, connect_btn, api_key_input],
        )

        async def chat_stream(query, history, selected_model):
            client.current_model = selected_model

            # Lazily start the MCP subprocess on the first message.
            if not client.session:
                await client.connect()

            accumulated_response = ""
            async for chunk in client.stream_response(query):
                accumulated_response += chunk
                yield "", history + [(query, accumulated_response)]

            # Final yield so the exchange lands in history even if the
            # stream produced no chunks.
            yield "", history + [(query, accumulated_response)]

        submit_btn.click(
            chat_stream,
            [input_box, chatbot, model_selector],
            [input_box, chatbot],
            show_progress="hidden",
        )
        input_box.submit(
            chat_stream,
            [input_box, chatbot, model_selector],
            [input_box, chatbot],
            show_progress="hidden",
        )

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.queue().launch(
        server_port=7860,
        server_name="0.0.0.0",
        show_error=True,
    )
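
# Usage note (assumptions about the environment, not from the original file):
# run this script with uv so the spawned `server.py` resolves from the working
# directory, then open http://localhost:7860 in a browser.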