|
import asyncio |
|
import json |
|
import logging |
|
import os |
|
from typing import Any, Dict, List, Optional |
|
from datetime import timedelta |
|
|
|
from mcp.shared.message import SessionMessage |
|
from mcp.types import ( |
|
JSONRPCMessage, |
|
JSONRPCRequest, |
|
JSONRPCNotification, |
|
JSONRPCResponse, |
|
JSONRPCError, |
|
) |
|
from mcp.client.streamable_http import streamablehttp_client |
|
|
|
# Module-level logger named after this module; the clients below report
# progress and failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
class HuggingFaceMCPClient:
    """Client for interacting with the Hugging Face MCP endpoint.

    Each public call opens a fresh streamable-HTTP connection, performs the
    MCP ``initialize`` handshake, sends a single JSON-RPC request, waits for
    the matching response, and closes the connection again.
    """

    # MCP protocol version advertised in the ``initialize`` request.
    PROTOCOL_VERSION = "2024-11-05"

    # Upper bound on the number of incoming messages inspected while waiting
    # for one response. NOTE(review): this limits message count, not elapsed
    # time; a silent server is bounded only by the transport's own timeouts.
    MAX_RESPONSE_MESSAGES = 100

    # Exception class names that mean "the connection is gone". Per the anyio
    # stream docs, ClosedResourceError is raised when the receiving stream was
    # closed and EndOfStream when the sending side was closed. Matched by name
    # so this module does not need to import anyio directly.
    _CLOSED_STREAM_ERRORS = ("ClosedResourceError", "EndOfStream")

    def __init__(self, hf_token: str, timeout: int = 30):
        """
        Initialize the Hugging Face MCP client.

        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: Timeout in seconds for HTTP requests (passed to the
                streamable-HTTP transport).
        """
        self.hf_token = hf_token
        self.url = "https://huggingface.co/mcp"
        self.headers = {"Authorization": f"Bearer {hf_token}"}
        self.timeout = timedelta(seconds=timeout)
        self.request_id_counter = 0

    def _get_next_request_id(self) -> int:
        """Return the next JSON-RPC request id (1, 2, 3, ... per instance)."""
        self.request_id_counter += 1
        return self.request_id_counter

    @classmethod
    def _is_closed_stream_error(cls, exc: BaseException) -> bool:
        """Return True if *exc*, or any of its base classes, signals a closed stream."""
        return any(
            klass.__name__ in cls._CLOSED_STREAM_ERRORS for klass in type(exc).__mro__
        )

    async def _wait_for_response(
        self, read_stream: Any, request_id: Any, label: str, phase: str
    ) -> Any:
        """
        Read from *read_stream* until the reply to *request_id* arrives.

        Messages that are not the matching response or error (for example
        notifications, or replies to other ids) are skipped.

        Args:
            read_stream: Receive side of the transport. It yields
                ``SessionMessage`` objects or ``Exception`` instances that the
                transport forwards.
            request_id: JSON-RPC id of the request being awaited.
            label: Prefix for failure/timeout messages ("Initialization",
                "Request").
            phase: Wording used in closed-connection messages.

        Returns:
            The ``result`` of the matching ``JSONRPCResponse``.

        Raises:
            Exception: on a matching ``JSONRPCError``, a closed connection
                (chained to the original error), or after
                ``MAX_RESPONSE_MESSAGES`` messages without a match. Any other
                error from the stream is re-raised unchanged.
        """
        for _ in range(self.MAX_RESPONSE_MESSAGES):
            try:
                incoming = await read_stream.receive()
                # The transport forwards its own failures as exception objects.
                if isinstance(incoming, Exception):
                    raise incoming
            except Exception as exc:
                if self._is_closed_stream_error(exc):
                    logger.error("Stream closed during %s", phase)
                    raise Exception(f"Connection closed during {phase}") from exc
                raise

            if not isinstance(incoming, SessionMessage):
                continue
            msg = incoming.message.root
            if isinstance(msg, JSONRPCResponse) and msg.id == request_id:
                return msg.result
            if isinstance(msg, JSONRPCError) and msg.id == request_id:
                raise Exception(f"{label} failed: {msg.error}")

        raise Exception(f"{label} timeout")

    async def _initialize(self, read_stream: Any, write_stream: Any) -> None:
        """
        Perform the MCP initialize handshake on an open connection.

        Sends ``initialize`` and waits for its response, then sends the
        ``notifications/initialized`` notification.

        Raises:
            Exception: if initialization fails, the connection closes, or no
                matching response arrives (see ``_wait_for_response``).
        """
        init_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=self._get_next_request_id(),
            method="initialize",
            params={
                "protocolVersion": self.PROTOCOL_VERSION,
                # NOTE(review): the MCP spec lists "tools" under *server*
                # capabilities; kept as-is for compatibility - confirm intent.
                "capabilities": {"tools": {}},
                "clientInfo": {"name": "hf-mcp-client", "version": "1.0.0"},
            },
        )
        await write_stream.send(SessionMessage(JSONRPCMessage(init_request)))
        await self._wait_for_response(
            read_stream, init_request.id, "Initialization", "initialization"
        )
        logger.info("MCP client initialized successfully")

        initialized_notification = JSONRPCNotification(
            jsonrpc="2.0", method="notifications/initialized"
        )
        await write_stream.send(SessionMessage(JSONRPCMessage(initialized_notification)))

        # Short pause before the real request; presumably it lets the transport
        # deliver the notification first - TODO confirm it is still needed.
        await asyncio.sleep(0.1)

    async def _send_request_and_get_response(
        self,
        method: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """
        Send a JSON-RPC request over a fresh, initialized session and wait for the response.

        Args:
            method: The JSON-RPC method name.
            params: Optional parameters for the method.

        Returns:
            The ``result`` field of the matching JSON-RPC response.

        Raises:
            Exception: on JSON-RPC errors, closed connections, or when no
                matching response arrives. Errors are logged before re-raising.
        """
        # Reserve this request's id before the handshake takes the next one.
        request_id = self._get_next_request_id()
        request_message = SessionMessage(
            JSONRPCMessage(
                JSONRPCRequest(
                    jsonrpc="2.0",
                    id=request_id,
                    method=method,
                    params=params,
                )
            )
        )

        async with streamablehttp_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            terminate_on_close=True,
        ) as (read_stream, write_stream, _get_session_id):
            try:
                await self._initialize(read_stream, write_stream)
                await write_stream.send(request_message)
                return await self._wait_for_response(
                    read_stream, request_id, "Request", "request processing"
                )
            except Exception as e:
                logger.error("Error during MCP communication: %s", e)
                raise
            finally:
                # Best-effort close. Catch only Exception so cancellation and
                # KeyboardInterrupt are never swallowed.
                try:
                    await write_stream.aclose()
                except Exception:
                    logger.debug("Ignoring error while closing MCP write stream", exc_info=True)

    async def get_all_tools(self) -> List[Dict[str, Any]]:
        """
        Get all available tools from the Hugging Face MCP endpoint.

        Returns:
            The ``tools`` list from the ``tools/list`` result, or an empty list
            if the result does not have that shape.

        Raises:
            Exception: propagated from the underlying request after logging.
        """
        try:
            logger.info("Fetching all available tools from Hugging Face MCP")
            result = await self._send_request_and_get_response("tools/list")

            if isinstance(result, dict) and "tools" in result:
                tools = result["tools"]
                logger.info("Found %d available tools", len(tools))
                return tools

            logger.warning("Unexpected response format: %s", result)
            return []

        except Exception as e:
            logger.error("Failed to get tools: %s", e)
            raise

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Call a specific tool with the given arguments.

        Args:
            tool_name: Name of the tool to call.
            args: Arguments to pass to the tool.

        Returns:
            The ``result`` of the ``tools/call`` response.

        Raises:
            Exception: propagated from the underlying request after logging.
        """
        try:
            logger.info("Calling tool '%s' with args: %s", tool_name, args)
            params = {"name": tool_name, "arguments": args}
            result = await self._send_request_and_get_response("tools/call", params)
            logger.info("Tool '%s' executed successfully", tool_name)
            return result

        except Exception as e:
            logger.error("Failed to call tool '%s': %s", tool_name, e)
            raise
|
|
|
|
|
|
|
async def get_hf_tools(hf_token: str) -> List[Dict[str, Any]]:
    """
    Fetch the tool definitions exposed by the Hugging Face MCP endpoint.

    Convenience wrapper: builds a one-off ``HuggingFaceMCPClient`` with its
    default settings and delegates to ``get_all_tools``.

    Args:
        hf_token: Hugging Face API token used for the Authorization header.

    Returns:
        The list of tool definitions reported by the endpoint.
    """
    return await HuggingFaceMCPClient(hf_token).get_all_tools()
|
|
|
|
|
async def call_hf_tool(hf_token: str, tool_name: str, args: Dict[str, Any]) -> Any:
    """
    Invoke a single Hugging Face MCP tool by name.

    Convenience wrapper: builds a one-off ``HuggingFaceMCPClient`` with its
    default settings and delegates to ``call_tool``.

    Args:
        hf_token: Hugging Face API token used for the Authorization header.
        tool_name: Name of the tool to invoke.
        args: Arguments forwarded to the tool.

    Returns:
        Whatever ``HuggingFaceMCPClient.call_tool`` returns for this call.
    """
    mcp_client = HuggingFaceMCPClient(hf_token)
    outcome = await mcp_client.call_tool(tool_name, args)
    return outcome
|
|
|
|
|
|
|
class SimpleHFMCPClient:
    """Simplified version for debugging connection issues."""

    def __init__(self, hf_token: str):
        """
        Store the token and build the request headers.

        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
        """
        self.hf_token = hf_token
        self.url = "https://huggingface.co/mcp"
        self.headers = {"Authorization": f"Bearer {hf_token}"}

    async def test_connection(self, timeout: float = 10.0) -> bool:
        """
        Test basic connection to the HF MCP endpoint.

        NOTE(review): entering ``streamablehttp_client`` appears only to set up
        the transport without contacting the server. This method therefore
        sends an MCP ``initialize`` request and waits for its matching reply,
        so a True result means the endpoint actually answered.

        Args:
            timeout: Seconds used both as the HTTP transport timeout and as the
                limit for waiting on the initialize reply.

        Returns:
            True if a successful initialize response arrived. False on any
            error, which is logged.
        """
        request_id = 1

        async def _await_reply(read_stream: Any) -> None:
            # Skip unrelated traffic until the reply to our request shows up.
            while True:
                incoming = await read_stream.receive()
                if isinstance(incoming, Exception):
                    raise incoming
                if not isinstance(incoming, SessionMessage):
                    continue
                msg = incoming.message.root
                if isinstance(msg, JSONRPCResponse) and msg.id == request_id:
                    return
                if isinstance(msg, JSONRPCError) and msg.id == request_id:
                    raise Exception(f"Initialization failed: {msg.error}")

        try:
            init_request = JSONRPCRequest(
                jsonrpc="2.0",
                id=request_id,
                method="initialize",
                params={
                    "protocolVersion": "2024-11-05",
                    "capabilities": {},
                    "clientInfo": {"name": "hf-mcp-client", "version": "1.0.0"},
                },
            )
            async with streamablehttp_client(
                url=self.url,
                headers=self.headers,
                timeout=timedelta(seconds=timeout),
                terminate_on_close=True,
            ) as (read_stream, write_stream, _get_session_id):
                await write_stream.send(SessionMessage(JSONRPCMessage(init_request)))
                # Bound the wait so a silent server cannot hang the check.
                await asyncio.wait_for(_await_reply(read_stream), timeout)
                logger.info("Connection established successfully")
                return True
        except Exception as e:
            logger.error("Connection test failed: %s", e)
            return False
|
|
|
|