import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import requests  # Retained, though not directly used in the core logic below
from smolagents.mcp_client import MCPClient

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")


# Function to encode an image (file path or PIL Image) to base64
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None
    try:
        print(f"Encoding image from path: {image_path}")
        # If it's already a PIL Image
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            # Try to open the image file
            image = Image.open(image_path)

        # Convert to RGB if the image has an alpha channel (RGBA)
        if image.mode == 'RGBA':
            image = image.convert('RGB')

        # Encode to base64
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None


# Dictionary to store active MCP connections
mcp_connections = {}


def connect_to_mcp_server(server_url, server_name=None):
    """Connect to an MCP server and return available tools"""
    if not server_url:
        return None, "No server URL provided"
    try:
        # Create an MCP client and connect to the server
        client = MCPClient({"url": server_url})
        # Get available tools
        tools = client.get_tools()

        # Store the connection for later use, ensuring a unique name
        name = server_name or f"Server_{len(mcp_connections)}_{base64.urlsafe_b64encode(os.urandom(3)).decode()}"
        mcp_connections[name] = {"client": client, "tools": tools, "url": server_url}

        return name, f"Successfully connected to {name} with {len(tools)} available tools"
    except Exception as e:
        print(f"Error connecting to MCP server: {e}")
        return None, f"Error connecting to MCP server: {str(e)}"


def list_mcp_tools(server_name):
    """List available tools for a connected MCP server"""
    if server_name not in mcp_connections:
        return "Server not connected"

    tools = mcp_connections[server_name]["tools"]
    tool_info = []
    for tool in tools:
        tool_info.append(f"- {tool.name}: {tool.description}")

    if not tool_info:
        return "No tools available for this server"

    return "\n".join(tool_info)
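
# Usage sketch for the two helpers above (illustrative only: the URL and the
# server name are assumptions, not part of this app's configuration):
#
#   name, status = connect_to_mcp_server("http://localhost:7860/gradio_api/mcp/sse", "kokoroTTS")
#   print(status)                # e.g. "Successfully connected to kokoroTTS with 1 available tools"
#   print(list_mcp_tools(name))  # e.g. "- text_to_audio_b64: Convert text..."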

def call_mcp_tool(server_name, tool_name, **kwargs):
    """Call a specific tool from an MCP server"""
    if server_name not in mcp_connections:
        return f"Server '{server_name}' not connected"

    client = mcp_connections[server_name]["client"]
    tools = mcp_connections[server_name]["tools"]

    # Find the requested tool
    tool = next((t for t in tools if t.name == tool_name), None)
    if not tool:
        return f"Tool '{tool_name}' not found on server '{server_name}'"

    try:
        # Call the tool with the provided arguments. The MCP client's call_tool
        # is expected to return the direct result from the tool; that result may
        # be a string (e.g. base64 audio), a dict, or another type depending on
        # the tool. The `respond` function handles formatting.
        result = client.call_tool(tool_name, kwargs)
        return result
    except Exception as e:
        print(f"Error calling MCP tool: {e}")
        return f"Error calling MCP tool: {str(e)}"


def analyze_message_for_tool_call(message, active_mcp_servers, client_for_llm, model_to_use, system_message_for_llm):
    """Analyze a message to determine if an MCP tool should be called"""
    # Skip analysis if the message is empty
    if not message or not message.strip():
        return None, None

    # Gather information about available tools
    tool_info = []
    if active_mcp_servers:
        for server_name in active_mcp_servers:
            if server_name in mcp_connections:
                server_tools = mcp_connections[server_name]["tools"]
                for tool in server_tools:
                    tool_info.append({
                        "server_name": server_name,
                        "tool_name": tool.name,
                        "description": tool.description
                    })

    if not tool_info:
        return None, None

    # Create a structured query for the LLM to decide whether a tool call is needed
    tools_desc = []
    for info in tool_info:
        tools_desc.append(f"{info['server_name']}.{info['tool_name']}: {info['description']}")
    tools_string = "\n".join(tools_desc)

    # Prompt guiding the LLM, including examples for a TTS tool that returns base64 audio
    analysis_system_prompt = f"""You are an assistant that helps determine if a user message requires using an external tool.

Available tools:
{tools_string}

Your job is to:
1. Analyze the user's message.
2. Determine if they're asking to use one of the tools.
3. If yes, respond ONLY with a JSON object with "server_name", "tool_name", and "parameters".
4. If no, respond ONLY with the exact string "NO_TOOL_NEEDED".

Example 1 (for TTS that returns base64 audio):
User: "Please turn this text into speech: Hello world"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio_b64", "parameters": {{"text": "Hello world", "speed": 1.0}}}}

Example 2 (for TTS with different speed):
User: "Read 'This is faster' at speed 1.5"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio_b64", "parameters": {{"text": "This is faster", "speed": 1.5}}}}

Example 3 (general, non-tool):
User: "What is the capital of France?"
Response: NO_TOOL_NEEDED"""

    try:
        # Call the LLM to analyze the message
        response = client_for_llm.chat_completion(
            model=model_to_use,
            messages=[
                {"role": "system", "content": analysis_system_prompt},
                {"role": "user", "content": message}
            ],
            temperature=0.1,  # Low temperature for deterministic tool selection
            max_tokens=300
        )

        analysis = response.choices[0].message.content.strip()
        print(f"Tool analysis raw response: '{analysis}'")

        if analysis == "NO_TOOL_NEEDED":
            return None, None

        # Try to parse JSON directly from the response
        try:
            tool_call = json.loads(analysis)
            return tool_call.get("server_name"), {
                "tool_name": tool_call.get("tool_name"),
                "parameters": tool_call.get("parameters", {})
            }
        except json.JSONDecodeError:
            print(f"Failed to parse tool call JSON directly from: {analysis}")
            # Fall back to extracting a JSON object embedded in the response
            json_start = analysis.find("{")
            json_end = analysis.rfind("}") + 1
            if json_start != -1 and json_end != 0 and json_end > json_start:
                json_str = analysis[json_start:json_end]
                try:
                    tool_call = json.loads(json_str)
                    return tool_call.get("server_name"), {
                        "tool_name": tool_call.get("tool_name"),
                        "parameters": tool_call.get("parameters", {})
                    }
                except json.JSONDecodeError:
                    print(f"Failed to parse extracted tool call JSON: {json_str}")
                    return None, None
            else:
                print(f"No JSON object found in analysis: {analysis}")
                return None, None
    except Exception as e:
        print(f"Error analyzing message for tool calls: {str(e)}")
        return None, None
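
# Contract note for the analyzer above: it returns
#   (server_name, {"tool_name": ..., "parameters": {...}})
# when a tool call is warranted, or (None, None) otherwise. The Natural
# Language MCP path in `respond` below relies on exactly that shape.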

def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
    mcp_enabled=False,
    active_mcp_servers=None,
    mcp_interaction_mode="Natural Language"
):
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    # print(f"History: {history}")  # Can be very verbose
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")
    print(f"MCP enabled: {mcp_enabled}")
    print(f"Active MCP servers: {active_mcp_servers}")
    print(f"MCP interaction mode: {mcp_interaction_mode}")

    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")

    client_for_llm = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")

    if seed == -1:
        seed = None

    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    if mcp_enabled and message:
        if message.startswith("/mcp"):
            command_parts = message.split(" ", 3)
            if len(command_parts) < 3:
                yield "Invalid MCP command. Format: /mcp <server_name> <tool_name> [arguments_json]"
                return
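
            # A well-formed command looks like this (tool names illustrative,
            # matching the analyzer examples above):
            #   /mcp kokoroTTS text_to_audio_b64 {"text": "Hello world", "speed": 1.0}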
            _, server_name, tool_name = command_parts[:3]
            args_json_str = "{}" if len(command_parts) < 4 else command_parts[3]

            try:
                args_dict = json.loads(args_json_str)
                result = call_mcp_tool(server_name, tool_name, **args_dict)

                if "audio" in tool_name.lower() and "b64" in tool_name.lower() and isinstance(result, str):
                    # Embed the base64 payload in an HTML audio player
                    # (assumption: the tool returns WAV-encoded audio)
                    audio_html = f'<audio controls autoplay src="data:audio/wav;base64,{result}"></audio>'
                    yield f"Executed {tool_name} from {server_name}.\n\nResult:\n{audio_html}"
                elif isinstance(result, dict):
                    yield json.dumps(result, indent=2)
                else:
                    yield str(result)
                return  # MCP command handled, exit
            except json.JSONDecodeError:
                yield f"Invalid JSON arguments: {args_json_str}"
                return
            except Exception as e:
                yield f"Error executing MCP command: {str(e)}"
                return
        elif mcp_interaction_mode == "Natural Language" and active_mcp_servers:
            server_name, tool_info = analyze_message_for_tool_call(
                message, active_mcp_servers, client_for_llm, model_to_use,
                system_message  # Original system message for context; the analyzer uses its own prompt
            )

            if server_name and tool_info and tool_info.get("tool_name"):
                try:
                    print(f"Calling tool via natural language: {server_name}.{tool_info['tool_name']} with parameters: {tool_info.get('parameters', {})}")
                    result = call_mcp_tool(server_name, tool_info['tool_name'], **tool_info.get('parameters', {}))
                    tool_display_name = tool_info['tool_name']

                    if "audio" in tool_display_name.lower() and "b64" in tool_display_name.lower() and isinstance(result, str) and len(result) > 100:  # Heuristic for base64 audio
                        # Same assumption as above: WAV-encoded base64 audio
                        audio_html = f'<audio controls autoplay src="data:audio/wav;base64,{result}"></audio>'
                        yield f"I used the {tool_display_name} tool from {server_name} with your request.\n\nResult:\n{audio_html}"
                    elif isinstance(result, dict):
                        result_str = json.dumps(result, indent=2)
                        yield f"I used the {tool_display_name} tool from {server_name} with your request.\n\nResult:\n{result_str}"
                    else:
                        result_str = str(result)
                        yield f"I used the {tool_display_name} tool from {server_name} with your request.\n\nResult:\n{result_str}"
                    return  # MCP tool call handled via natural language
                except Exception as e:
                    print(f"Error executing MCP tool via natural language: {str(e)}")
                    yield f"I tried to use a tool but encountered an error: {str(e)}. I will try to respond without it."
                    # Fall through to a normal LLM response if the tool call fails
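
    # Build the current user turn as OpenAI-style multimodal content parts: a
    # list of {"type": "text", ...} / {"type": "image_url", ...} dicts, which
    # huggingface_hub's InferenceClient.chat_completion accepts for
    # vision-capable models.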
    user_content = []
    if message and message.strip():
        user_content.append({"type": "text", "text": message})

    if image_files and len(image_files) > 0:
        for img_path in image_files:
            if img_path is not None:
                try:
                    encoded_image = encode_image(img_path)
                    if encoded_image:
                        user_content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
                        })
                except Exception as e:
                    print(f"Error encoding image for user content: {e}")

    if not user_content:
        # Message was empty and there were no images, or only an MCP command was handled.
        # Guard against a None message; avoid yielding empty output for an MCP command.
        if not (message and message.startswith("/mcp")):
            yield ""  # Or handle appropriately; there is nothing to send to the model
        return

    augmented_system_message = system_message
    if mcp_enabled and active_mcp_servers:
        tool_desc_list = []
        for server_name_active in active_mcp_servers:
            if server_name_active in mcp_connections:
                # Get the tools for this specific server; list_mcp_tools returns a
                # string like "- tool1: desc1\n- tool2: desc2"
                server_tools_str = list_mcp_tools(server_name_active)
                if server_tools_str != "Server not connected" and server_tools_str != "No tools available for this server":
                    for line in server_tools_str.split('\n'):
                        if line.startswith("- "):
                            tool_desc_list.append(f"{server_name_active}.{line[2:]}")  # e.g., kokoroTTS.text_to_audio_b64: Convert text...

        if tool_desc_list:
            mcp_tools_description_for_llm = "\n".join(tool_desc_list)
            # This informs the main LLM about available tools for general conversation,
            # distinct from the specialized analyzer LLM. The main LLM doesn't call
            # tools directly but can use this info to guide the user.
            if mcp_interaction_mode == "Command Mode":
                augmented_system_message += f"\n\nYou have access to the following MCP tools which the user can invoke:\n{mcp_tools_description_for_llm}\n\nTo use these tools, the user can type a command in the format: /mcp <server_name> <tool_name> [arguments_json]"
            else:  # Natural Language
                augmented_system_message += f"\n\nYou have access to the following MCP tools. The system will try to use them automatically if the user's request matches their capability:\n{mcp_tools_description_for_llm}\n\nIf the user asks to do something a tool can do, the system will attempt to use it. For example, if a 'text_to_audio_b64' tool is available, and the user says 'read this text aloud', the system will try to use that tool."
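
    # Replay prior turns below so the model keeps conversational context. The
    # final payload is a standard chat list of {"role": ..., "content": ...}
    # dicts, where user turns may carry the multimodal content parts built above.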
    messages_for_llm = [{"role": "system", "content": augmented_system_message}]
    print("Initial messages array constructed.")

    for hist_user, hist_assistant in history:
        # hist_user can be complex if it included images from MultimodalTextbox;
        # reconstruct it properly for the LLM.
        current_hist_user_content = []
        if isinstance(hist_user, dict) and 'text' in hist_user and 'files' in hist_user:  # From MultimodalTextbox
            if hist_user['text'] and hist_user['text'].strip():
                current_hist_user_content.append({"type": "text", "text": hist_user['text']})
            if hist_user['files']:
                for img_file_path in hist_user['files']:
                    encoded_img = encode_image(img_file_path)
                    if encoded_img:
                        current_hist_user_content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}
                        })
        elif isinstance(hist_user, str):  # Simple text history
            current_hist_user_content.append({"type": "text", "text": hist_user})

        if current_hist_user_content:
            messages_for_llm.append({"role": "user", "content": current_hist_user_content})

        if hist_assistant:  # Assistant message is always text
            # If the assistant message was an HTML audio tag, send a placeholder to
            # the LLM instead of the raw base64 payload. (Assumed completion: the
            # original branch is truncated; this mirrors the comment above.)
            if "<audio" in hist_assistant:
                messages_for_llm.append({"role": "assistant", "content": "[Generated audio was played for the user]"})
            else:
                messages_for_llm.append({"role": "assistant", "content": hist_assistant})