File size: 10,547 Bytes
5515a5b
e9b0c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5515a5b
e9b0c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional
from datetime import timedelta

from mcp.shared.message import SessionMessage
from mcp.types import (
    JSONRPCMessage,
    JSONRPCRequest,
    JSONRPCNotification,
    JSONRPCResponse,
    JSONRPCError,
)
from mcp.client.streamable_http import streamablehttp_client

logger = logging.getLogger(__name__)


class HuggingFaceMCPClient:
    """Client for interacting with the Hugging Face MCP endpoint.

    Each call to :meth:`get_all_tools` or :meth:`call_tool` opens a new
    streamable-HTTP connection, performs the MCP ``initialize`` handshake,
    sends exactly one JSON-RPC request and returns that request's result.
    """

    # Maximum number of messages read while waiting for one specific reply.
    # This guards against looping forever on unrelated traffic; it counts
    # messages and is NOT a wall-clock timeout.
    _MAX_MESSAGES_PER_RESPONSE = 100

    # Class names of exceptions that mean the transport stream has gone away.
    # Matched by name so this module does not need to import the streaming
    # library that defines them.
    _CONNECTION_CLOSED_ERRORS = ("ClosedResourceError", "EndOfStream")

    def __init__(self, hf_token: str, timeout: int = 30):
        """
        Initialize the Hugging Face MCP client.

        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: Timeout in seconds passed to the HTTP transport.
        """
        self.hf_token = hf_token
        self.url = "https://huggingface.co/mcp"
        self.headers = {"Authorization": f"Bearer {hf_token}"}
        self.timeout = timedelta(seconds=timeout)
        self.request_id_counter = 0

    def _get_next_request_id(self) -> int:
        """Return a new JSON-RPC request id (increments a per-instance counter)."""
        self.request_id_counter += 1
        return self.request_id_counter

    @classmethod
    def _is_connection_closed(cls, exc: BaseException) -> bool:
        """Return True if ``exc`` signals that the transport stream was closed."""
        return type(exc).__name__ in cls._CONNECTION_CLOSED_ERRORS

    @staticmethod
    def _wrap(rpc: Any) -> "SessionMessage":
        """Wrap a JSON-RPC request/notification for the transport's write stream."""
        return SessionMessage(JSONRPCMessage(rpc))

    async def _wait_for_response(
        self,
        read_stream: Any,
        request_id: int,
        *,
        label: str,
        phase: str,
    ) -> Any:
        """
        Read from ``read_stream`` until the reply to ``request_id`` arrives.

        Items that are not ``SessionMessage`` instances, and messages whose id
        does not match, are skipped.

        Args:
            read_stream: Receive side of the transport; yields
                ``SessionMessage`` objects or ``Exception`` instances.
            request_id: JSON-RPC id of the request being awaited.
            label: Prefix for error texts ("Initialization" / "Request").
            phase: Wording used when the connection closes while waiting.

        Returns:
            The ``result`` field of the matching JSON-RPC response.

        Raises:
            Exception: on a JSON-RPC error reply, a closed connection, an
                exception delivered through the stream, or when
                ``_MAX_MESSAGES_PER_RESPONSE`` messages pass without a match.
        """
        for _ in range(self._MAX_MESSAGES_PER_RESPONSE):
            try:
                item = await read_stream.receive()
                if isinstance(item, Exception):
                    # The stream may deliver exceptions as items; surface them.
                    raise item
            except Exception as exc:
                if self._is_connection_closed(exc):
                    logger.error("Stream closed during %s", phase)
                    raise Exception(f"Connection closed during {phase}") from exc
                raise

            if not isinstance(item, SessionMessage):
                continue
            msg = item.message.root
            if isinstance(msg, JSONRPCResponse) and msg.id == request_id:
                return msg.result
            if isinstance(msg, JSONRPCError) and msg.id == request_id:
                raise Exception(f"{label} failed: {msg.error}")

        raise Exception(f"{label} timeout")

    async def _send_request_and_get_response(
        self,
        method: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """
        Send a JSON-RPC request over a fresh MCP session and wait for its reply.

        The session is initialized first (``initialize`` request followed by
        the ``notifications/initialized`` notification); only then is the
        actual request sent.

        Args:
            method: The JSON-RPC method name.
            params: Optional parameters for the method.

        Returns:
            The ``result`` of the JSON-RPC response.

        Raises:
            Exception: if initialization or the request fails, the connection
                closes, or no matching reply is seen (see ``_wait_for_response``).
        """
        request_id = self._get_next_request_id()
        request_message = self._wrap(
            JSONRPCRequest(
                jsonrpc="2.0",
                id=request_id,
                method=method,
                params=params,
            )
        )

        async with streamablehttp_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            terminate_on_close=True,
        ) as (read_stream, write_stream, _get_session_id):
            try:
                # MCP handshake: initialize request ...
                init_id = self._get_next_request_id()
                await write_stream.send(
                    self._wrap(
                        JSONRPCRequest(
                            jsonrpc="2.0",
                            id=init_id,
                            method="initialize",
                            params={
                                "protocolVersion": "2024-11-05",
                                "capabilities": {
                                    "tools": {}
                                },
                                "clientInfo": {
                                    "name": "hf-mcp-client",
                                    "version": "1.0.0"
                                }
                            },
                        )
                    )
                )
                await self._wait_for_response(
                    read_stream,
                    init_id,
                    label="Initialization",
                    phase="initialization",
                )
                logger.info("MCP client initialized successfully")

                # ... followed by the "initialized" notification.
                await write_stream.send(
                    self._wrap(
                        JSONRPCNotification(
                            jsonrpc="2.0",
                            method="notifications/initialized",
                        )
                    )
                )
                # Small delay to let the notification be processed before the
                # real request is sent.
                await asyncio.sleep(0.1)

                await write_stream.send(request_message)
                return await self._wait_for_response(
                    read_stream,
                    request_id,
                    label="Request",
                    phase="request processing",
                )
            except Exception as e:
                logger.error("Error during MCP communication: %s", e)
                raise
            finally:
                # Best-effort close; unlike a bare ``except:`` this still lets
                # cancellation and KeyboardInterrupt propagate.
                try:
                    await write_stream.aclose()
                except Exception as close_error:
                    logger.debug(
                        "Ignoring error while closing MCP write stream: %s",
                        close_error,
                    )

    async def get_all_tools(self) -> List[Dict[str, Any]]:
        """
        Get all available tools from the Hugging Face MCP endpoint.

        Returns:
            List of tool definitions, or an empty list if the reply does not
            contain a ``tools`` key.

        Raises:
            Exception: propagated from the underlying MCP request.
        """
        try:
            logger.info("Fetching all available tools from Hugging Face MCP")
            result = await self._send_request_and_get_response("tools/list")
        except Exception as e:
            logger.error("Failed to get tools: %s", e)
            raise

        if isinstance(result, dict) and "tools" in result:
            tools = result["tools"]
            logger.info("Found %d available tools", len(tools))
            return tools

        logger.warning("Unexpected response format: %s", result)
        return []

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Call a specific tool with the given arguments.

        Args:
            tool_name: Name of the tool to call.
            args: Arguments to pass to the tool.

        Returns:
            The ``result`` of the ``tools/call`` response, unmodified.

        Raises:
            Exception: propagated from the underlying MCP request.
        """
        params = {
            "name": tool_name,
            "arguments": args
        }
        try:
            logger.info("Calling tool '%s' with args: %s", tool_name, args)
            result = await self._send_request_and_get_response("tools/call", params)
        except Exception as e:
            logger.error("Failed to call tool '%s': %s", tool_name, e)
            raise

        logger.info("Tool '%s' executed successfully", tool_name)
        return result


# Convenience functions for easier usage
async def get_hf_tools(hf_token: str) -> List[Dict[str, Any]]:
    """
    Return every tool advertised by the Hugging Face MCP endpoint.

    Convenience shortcut: builds a one-off ``HuggingFaceMCPClient`` with its
    default settings and delegates to ``get_all_tools``.

    Args:
        hf_token: Hugging Face API token

    Returns:
        List of tool definitions
    """
    mcp_client = HuggingFaceMCPClient(hf_token)
    tool_list = await mcp_client.get_all_tools()
    return tool_list


async def call_hf_tool(hf_token: str, tool_name: str, args: Dict[str, Any]) -> Any:
    """
    Invoke a single Hugging Face MCP tool and return whatever it produces.

    Convenience shortcut: builds a one-off ``HuggingFaceMCPClient`` with its
    default settings and delegates to ``call_tool``.

    Args:
        hf_token: Hugging Face API token
        tool_name: Name of the tool to call
        args: Arguments to pass to the tool

    Returns:
        The tool's response
    """
    return await HuggingFaceMCPClient(hf_token).call_tool(tool_name, args)


# Alternative simpler implementation for debugging
class SimpleHFMCPClient:
    """Simplified client for debugging connection issues.

    Intentionally independent of ``HuggingFaceMCPClient`` so a connectivity
    problem can be diagnosed without the full request/response machinery.
    """

    def __init__(self, hf_token: str, timeout: int = 10):
        """
        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: Seconds used both as the HTTP transport timeout and as
                the wall-clock bound on the initialize round-trip performed
                by :meth:`test_connection`.
        """
        self.hf_token = hf_token
        self.url = "https://huggingface.co/mcp"
        self.headers = {"Authorization": f"Bearer {hf_token}"}
        self.timeout = timedelta(seconds=timeout)

    @staticmethod
    async def _handshake(read_stream: Any, write_stream: Any, request: Any) -> None:
        """Send ``request`` and wait until the reply carrying the same id arrives.

        Raises:
            Exception: if the stream yields an exception or the server replies
                with a JSON-RPC error.
        """
        await write_stream.send(SessionMessage(JSONRPCMessage(request)))
        while True:
            item = await read_stream.receive()
            if isinstance(item, Exception):
                raise item
            if not isinstance(item, SessionMessage):
                continue
            msg = item.message.root
            if isinstance(msg, JSONRPCResponse) and msg.id == request.id:
                return
            if isinstance(msg, JSONRPCError) and msg.id == request.id:
                raise Exception(f"Initialization failed: {msg.error}")

    async def test_connection(self) -> bool:
        """Test basic connection to the HF MCP endpoint.

        Merely entering the ``streamablehttp_client`` context does not send
        anything to the server, so it cannot prove connectivity or that the
        token is accepted. This method sends an MCP ``initialize`` request and
        waits, bounded by ``self.timeout``, for the matching reply.

        Returns:
            True if the endpoint answered the initialize request successfully;
            False on any transport error, JSON-RPC error reply, or timeout.
        """
        init_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=1,
            method="initialize",
            params={
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "clientInfo": {"name": "hf-mcp-client", "version": "1.0.0"},
            },
        )
        try:
            async with streamablehttp_client(
                url=self.url,
                headers=self.headers,
                timeout=self.timeout,
                terminate_on_close=True
            ) as (read_stream, write_stream, _get_session_id):
                try:
                    # Wall-clock bound: if the transport never delivers a reply
                    # (or an error), receive() would otherwise wait indefinitely.
                    await asyncio.wait_for(
                        self._handshake(read_stream, write_stream, init_request),
                        timeout=self.timeout.total_seconds(),
                    )
                except asyncio.TimeoutError:
                    logger.error(
                        "Connection test failed: no initialize response within %s",
                        self.timeout,
                    )
                    return False
                logger.info("Connection established successfully")
                return True
        except Exception as e:
            logger.error(f"Connection test failed: {e}")
            return False