import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import requests  # Kept for potential future use; not used in the core logic below
from smolagents.mcp_client import MCPClient  # Ensure smolagents is installed and importable

ACCESS_TOKEN = os.getenv("HF_TOKEN")
if ACCESS_TOKEN:
    print("Access token loaded from HF_TOKEN environment variable.")
else:
    print("Warning: HF_TOKEN environment variable not set. Some operations might fail.")

# Encode an image (file path or PIL Image) to a base64 JPEG string
def encode_image(image_path_or_pil):
    if not image_path_or_pil:
        print("No image path or PIL Image provided")
        return None
    try:
        if isinstance(image_path_or_pil, Image.Image):
            image = image_path_or_pil
            print("Encoding PIL Image object.")
        elif isinstance(image_path_or_pil, str):
            print(f"Encoding image from path: {image_path_or_pil}")
            if not os.path.exists(image_path_or_pil):
                print(f"Error: Image file not found at {image_path_or_pil}")
                return None
            image = Image.open(image_path_or_pil)
        else:
            print(f"Error: Unsupported image input type: {type(image_path_or_pil)}")
            return None
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")  # Or PNG if preferred; keep the format consistent
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully to base64.")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        import traceback
        traceback.print_exc()
        return None
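
# Hedged usage sketch (assumes a local file "example.jpg" exists); it mirrors how the encoder's
# output gets wrapped into an OpenAI-style image_url block later in respond():
#
#     b64 = encode_image("example.jpg")
#     if b64:
#         image_block = {"type": "image_url",
#                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}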

# Dictionary to store active MCP connections
mcp_connections = {}

def connect_to_mcp_server(server_url, server_name=None):
    """Connect to an MCP server and return its available tools."""
    if not server_url:
        return None, "No server URL provided. Please enter a valid URL."
    try:
        print(f"Attempting to connect to MCP server at URL: {server_url}")
        client = MCPClient({"url": server_url})  # May block or raise if the connection fails
        tools = client.get_tools()  # Blocks until the tool list has been fetched
        name = server_name.strip() if server_name and server_name.strip() else f"Server_{len(mcp_connections) + 1}"
        mcp_connections[name] = {"client": client, "tools": tools, "url": server_url}
        print(f"Successfully connected to MCP server: {name} with {len(tools)} tools.")
        return name, f"Successfully connected to '{name}' ({server_url}). Found {len(tools)} tool(s)."
    except Exception as e:
        print(f"Error connecting to MCP server at {server_url}: {e}")
        import traceback
        traceback.print_exc()
        return None, f"Error connecting to MCP server '{server_url}': {str(e)}"

def list_mcp_tools(server_name):
    """List the available tools for a connected MCP server as markdown."""
    if server_name not in mcp_connections:
        return "Server not connected or name not found."
    tools = mcp_connections[server_name]["tools"]
    tool_info = []
    for tool in tools:
        tool_info.append(f"- **{tool.name}**: {tool.description}")
    if not tool_info:
        return "No tools available for this server."
    return "\n".join(tool_info)

def call_mcp_tool(server_name, tool_name, **kwargs):
    """Call a specific tool from an MCP server and process its result."""
    if server_name not in mcp_connections:
        return {"type": "error", "message": f"Server '{server_name}' not connected."}
    mcp_client_instance = mcp_connections[server_name]["client"]
    try:
        print(f"Calling MCP tool: {server_name}.{tool_name} with args: {kwargs}")
        # Assumes call_tool returns a ToolResult-like object whose `content` is a list of blocks
        tool_result = mcp_client_instance.call_tool(tool_name, kwargs)
        if tool_result and tool_result.content:
            # Process multiple blocks if present: prioritize audio, otherwise concatenate text/JSON
            audio_block_found = None
            text_parts = []
            json_parts = []
            other_parts = []
            for block in tool_result.content:
                if hasattr(block, 'uri') and isinstance(block.uri, str) and block.uri.startswith('data:audio/'):
                    audio_block_found = {
                        "type": "audio",
                        "data_uri": block.uri,
                        "name": getattr(block, 'name', 'audio_output.wav')
                    }
                    break  # Prioritize the first audio block
                elif hasattr(block, 'text') and block.text is not None:
                    text_parts.append(str(block.text))
                elif hasattr(block, 'json_data') and block.json_data is not None:
                    try:
                        json_parts.append(json.dumps(block.json_data, indent=2))
                    except TypeError:
                        json_parts.append(str(block.json_data))  # Fallback for non-serializable data
                else:
                    other_parts.append(str(block))
            if audio_block_found:
                print(f"MCP tool returned audio: {audio_block_found['name']}")
                return audio_block_found
            elif text_parts:
                full_text = "\n".join(text_parts)
                print(f"MCP tool returned text: {full_text[:100]}...")
                return {"type": "text", "value": full_text}
            elif json_parts:
                full_json_str = "\n".join(json_parts)
                print("MCP tool returned a JSON string.")
                return {"type": "json_string", "value": full_json_str}  # Treated as a string for display
            elif other_parts:
                print("MCP tool returned other content types.")
                return {"type": "text", "value": "\n".join(other_parts)}
            else:
                print("MCP tool executed but returned no interpretable primary content blocks.")
                return {"type": "text", "value": "Tool executed, but returned no standard content (audio/text/json)."}
        print("MCP tool executed, but the ToolResult or its content was empty.")
        return {"type": "text", "value": "Tool executed, but returned no content."}
    except Exception as e:
        print(f"Error calling MCP tool '{tool_name}' or processing its result: {e}")
        import traceback
        traceback.print_exc()
        return {"type": "error", "message": f"Error during MCP tool '{tool_name}' execution: {str(e)}"}

def analyze_message_for_tool_call(message, active_mcp_servers, llm_client, llm_model_to_use, base_system_message):
    """Analyze a message to determine whether an MCP tool should be called."""
    if not message or not message.strip() or not active_mcp_servers:
        return None, None
    tool_info_for_llm = []
    for server_name_iter in active_mcp_servers:
        if server_name_iter in mcp_connections:
            server_tools = mcp_connections[server_name_iter]["tools"]
            for tool in server_tools:
                # Provide a concise description for the LLM
                tool_info_for_llm.append(
                    f"- Server: '{server_name_iter}', Tool: '{tool.name}', Description: '{tool.description}'"
                )
    if not tool_info_for_llm:
        print("No active MCP tools found for analysis.")
        return None, None
    tools_string_for_llm = "\n".join(tool_info_for_llm)

    # System prompt used for tool detection
    analysis_system_prompt = f"""You are an expert assistant that determines if a user's request requires an external tool.
You have access to the following tools:
{tools_string_for_llm}
Based on the user's message, decide if any of these tools are appropriate.
If a tool is needed, respond ONLY with a JSON object containing:
"server_name": The name of the server providing the tool.
"tool_name": The name of the tool to be called.
"parameters": A dictionary of parameters for the tool, inferred from the user's message. Ensure parameter names match what the tool expects (often 'text', 'query', 'speed', etc.).
If NO tool is needed, respond ONLY with the exact string: NO_TOOL_NEEDED
Example 1 (TTS tool):
User: "Can you say 'hello world' for me at a slightly faster speed?"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio", "parameters": {{"text": "hello world", "speed": 1.2}}}}
Example 2 (File tool):
User: "Read the content of my_document.txt"
Response: {{"server_name": "FileSystemServer", "tool_name": "readFile", "parameters": {{"path": "my_document.txt"}}}}
Example 3 (No tool):
User: "What's the weather like today?" (Assuming no weather tool is listed)
Response: NO_TOOL_NEEDED
User's current message is: "{message}"
Now, provide your decision:"""

    try:
        print(f"Sending tool analysis request to LLM model: {llm_model_to_use}")
        response = llm_client.chat_completion(
            model=llm_model_to_use,
            messages=[
                # {"role": "system", "content": base_system_message},  # Optional: provide the original system message for context
                {"role": "user", "content": analysis_system_prompt}  # The analysis prompt itself is the user message here
            ],
            temperature=0.1,  # Low temperature for deterministic tool selection
            max_tokens=300,
            stop=["\n\n"]  # Stop early if the LLM adds extra verbiage
        )
        analysis_text = response.choices[0].message.content.strip()
        print(f"LLM tool analysis response: '{analysis_text}'")
        if "NO_TOOL_NEEDED" in analysis_text:
            print("LLM determined no tool is needed.")
            return None, None
        # Try to extract JSON from the response (handle potential markdown code fences)
        if analysis_text.startswith("```json"):
            analysis_text = analysis_text.replace("```json", "").replace("```", "").strip()
        elif analysis_text.startswith("```"):
            analysis_text = analysis_text.replace("```", "").strip()
        json_start = analysis_text.find("{")
        json_end = analysis_text.rfind("}") + 1
        if json_start == -1 or json_end <= json_start:
            print(f"Could not find a valid JSON object in LLM response: '{analysis_text}'")
            return None, None
        json_str = analysis_text[json_start:json_end]
        try:
            tool_call_data = json.loads(json_str)
            if "server_name" in tool_call_data and "tool_name" in tool_call_data:
                print(f"LLM suggested tool call: {tool_call_data}")
                return tool_call_data.get("server_name"), {
                    "tool_name": tool_call_data.get("tool_name"),
                    "parameters": tool_call_data.get("parameters", {})
                }
            else:
                print(f"LLM response parsed as JSON but is missing server_name or tool_name: {json_str}")
                return None, None
        except json.JSONDecodeError as e:
            print(f"Failed to parse tool call JSON from LLM response: '{json_str}'. Error: {e}")
            return None, None
    except Exception as e:
        print(f"Error during LLM analysis for tool calls: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
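
# On success this returns a (server_name, tool_info) pair that respond() feeds directly into
# call_mcp_tool(). A hypothetical result for a TTS request might look like:
#
#     ("kokoroTTS", {"tool_name": "text_to_audio",
#                    "parameters": {"text": "hello world", "speed": 1.2}})
#
# (None, None) means no tool was deemed necessary or the LLM output could not be parsed.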

def respond(
    message_text_input,   # From the user function: just the text part of the message
    message_files_input,  # From the user function: the list of uploaded file paths
    history_tuples: list[tuple[tuple[str, list], str]],  # History: list of ((user_text, [user_files]), assistant_response)
    system_message_prompt,
    max_tokens_val,
    temperature_val,
    top_p_val,
    frequency_penalty_val,
    seed_val,
    provider_choice,
    custom_api_key_val,
    custom_model_id,
    # model_search_term_val,  # Not used in respond; kept commented for signature consistency with the UI
    selected_hf_model_id,
    mcp_is_enabled,
    active_mcp_server_names,  # List of selected server names
    mcp_interaction_mode_choice
):
    print("\n--- RESPOND FUNCTION CALLED ---")
    print(f"Message Text: '{message_text_input}'")
    print(f"Message Files: {message_files_input}")
    # print(f"History (first item type if exists): {type(history_tuples) if history_tuples else 'No history'}")
    print(f"System Prompt: '{system_message_prompt}'")
    print(f"Provider: {provider_choice}, MCP Enabled: {mcp_is_enabled}, MCP Mode: {mcp_interaction_mode_choice}")
    print(f"Active MCP Servers: {active_mcp_server_names}")

    token_to_use_for_llm = custom_api_key_val if custom_api_key_val.strip() else ACCESS_TOKEN
    if not token_to_use_for_llm and provider_choice != "hf-inference":  # Basic sanity check
        yield "Error: API Key required for non-hf-inference providers."
        return
    llm_client_instance = InferenceClient(token=token_to_use_for_llm, provider=provider_choice)
    current_seed = None if seed_val == -1 else seed_val
    model_id_for_llm = custom_model_id.strip() if custom_model_id.strip() else selected_hf_model_id
    print(f"Using LLM model: {model_id_for_llm} via {provider_choice}")

    # --- MCP tool call logic ---
    if mcp_is_enabled and (message_text_input or message_files_input) and active_mcp_server_names:
        tool_call_output_dict = None
        invoked_tool_display_name = "a tool"
        invoked_server_display_name = "an MCP server"
        if message_text_input and message_text_input.startswith("/mcp"):
            print("Processing explicit MCP command...")
            command_parts = message_text_input.split(" ", 3)
            if len(command_parts) < 3:
                yield "Invalid MCP command. Format: /mcp <server_name> <tool_name> [arguments_json]"
                return
            _, server_name_cmd, tool_name_cmd = command_parts[:3]
            invoked_server_display_name = server_name_cmd
            invoked_tool_display_name = tool_name_cmd
args_json_str = "{}" if len(command_parts) < 4 else command_parts | |
            try:
                args_dict_cmd = json.loads(args_json_str)
                tool_call_output_dict = call_mcp_tool(invoked_server_display_name, invoked_tool_display_name, **args_dict_cmd)
            except json.JSONDecodeError:
                yield f"Invalid JSON arguments for MCP command: {args_json_str}"
                return
            except Exception as e_cmd:
                yield f"Error preparing MCP command: {str(e_cmd)}"
                return
        elif mcp_interaction_mode_choice == "Natural Language":
            print("Analyzing message for a natural-language tool call...")
            # For natural language, only message_text_input is used; files could become context later.
            detected_server_nl, tool_info_nl = analyze_message_for_tool_call(
                message_text_input,
                active_mcp_server_names,
                llm_client_instance,
                model_id_for_llm,
                system_message_prompt
            )
            if detected_server_nl and tool_info_nl and tool_info_nl.get("tool_name"):
                invoked_server_display_name = detected_server_nl
                invoked_tool_display_name = tool_info_nl['tool_name']
                tool_params_nl = tool_info_nl.get("parameters", {})
                tool_call_output_dict = call_mcp_tool(invoked_server_display_name, invoked_tool_display_name, **tool_params_nl)

        # --- Handle the MCP tool result (if a tool was called) ---
        if tool_call_output_dict:
            response_message_parts = [f"I attempted to use the **{invoked_tool_display_name}** tool from **{invoked_server_display_name}**."]
            if tool_call_output_dict.get("type") == "audio":
                audio_data_uri = tool_call_output_dict["data_uri"]
                audio_html_tag = f"<audio controls src='{audio_data_uri}' title='{tool_call_output_dict.get('name', 'Audio Output')}'></audio>"
                response_message_parts.append(f"Here's the audio output:\n{audio_html_tag}")
            elif tool_call_output_dict.get("type") == "text":
                response_message_parts.append(f"\nResult:\n```\n{tool_call_output_dict['value']}\n```")
            elif tool_call_output_dict.get("type") == "json_string":  # Named "json_string" to avoid confusion with a dict
                response_message_parts.append(f"\nResult (JSON):\n```json\n{tool_call_output_dict['value']}\n```")
            elif tool_call_output_dict.get("type") == "error":
                response_message_parts.append(f"\nUnfortunately, there was an error: {tool_call_output_dict['message']}")
            else:  # Fallback for an unexpected result structure
                response_message_parts.append(f"\nThe tool returned: {str(tool_call_output_dict)}")
            yield "\n".join(response_message_parts)
            return  # Stop here if a tool was called and its result was processed

    # --- Regular LLM response logic (if no MCP tool was called or it returned no primary content) ---
    print("Proceeding with standard LLM response generation.")

    # Prepare the current user message for the LLM (multimodal if files were attached)
    current_user_llm_content = []
    if message_text_input and message_text_input.strip():
        current_user_llm_content.append({"type": "text", "text": message_text_input})
    if message_files_input:
        for file_path in message_files_input:
            if file_path:  # file_path is already the actual temp path from the gr.MultimodalTextbox upload
                encoded_img_str = encode_image(file_path)
                if encoded_img_str:
                    current_user_llm_content.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_img_str}"}
                    })
                else:
                    print(f"Warning: Failed to encode image {file_path} for the LLM.")
    if not current_user_llm_content:
        print("No content (text or valid files) in the current user message for the LLM.")
        yield ""  # Or some indicator that no action was taken
        return

    # Augment the system message with MCP tool info if MCP is enabled
    augmented_sys_msg = system_message_prompt
    if mcp_is_enabled and active_mcp_server_names:
        mcp_tool_descriptions_for_llm = []
        for server_name_iter in active_mcp_server_names:
            if server_name_iter in mcp_connections:
                # Use the more detailed list_mcp_tools output (markdown) for the system prompt
                tools_list_str = list_mcp_tools(server_name_iter)
                mcp_tool_descriptions_for_llm.append(f"From server '{server_name_iter}':\n{tools_list_str}")
        if mcp_tool_descriptions_for_llm:
            full_tools_info_str = "\n\n".join(mcp_tool_descriptions_for_llm)
            interaction_advice = ""
            if mcp_interaction_mode_choice == "Command Mode":
                interaction_advice = "The user can invoke these tools using '/mcp <server_name> <tool_name> <json_args>'."
            # In Natural Language mode the LLM needs no explicit instruction in the system prompt,
            # because analyze_message_for_tool_call handles tool detection separately.
            augmented_sys_msg += f"\n\nYou also have access to the following external tools via the Model Context Protocol (MCP):\n{full_tools_info_str}\n{interaction_advice}"

    # Build the messages list for the LLM API
    messages_for_llm_api = [{"role": "system", "content": augmented_sys_msg}]
    for hist_user_turn, hist_assist_response in history_tuples:
        hist_user_text, hist_user_files = hist_user_turn  # Unpack (text, [files])
        history_user_llm_content = []
        if hist_user_text and hist_user_text.strip():
            history_user_llm_content.append({"type": "text", "text": hist_user_text})
        if hist_user_files:
            for hist_file_path in hist_user_files:
                encoded_hist_img = encode_image(hist_file_path)
                if encoded_hist_img:
                    history_user_llm_content.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_hist_img}"}
                    })
        if history_user_llm_content:  # Only add the turn if it has actual content
            messages_for_llm_api.append({"role": "user", "content": history_user_llm_content})
        if hist_assist_response and hist_assist_response.strip():
            messages_for_llm_api.append({"role": "assistant", "content": hist_assist_response})
    messages_for_llm_api.append({"role": "user", "content": current_user_llm_content})
    # print(f"Final messages for LLM API: {json.dumps(messages_for_llm_api, indent=2)}")

    llm_parameters = {
        "max_tokens": max_tokens_val, "temperature": temperature_val, "top_p": top_p_val,
        "frequency_penalty": frequency_penalty_val,
    }
    if current_seed is not None:
        llm_parameters["seed"] = current_seed
    print(f"Sending request to LLM: Model={model_id_for_llm}, Params={llm_parameters}")

    streamed_response_text = ""
    try:
        llm_stream = llm_client_instance.chat_completion(
            model=model_id_for_llm,
            messages=messages_for_llm_api,
            stream=True,
            **llm_parameters
        )
        # print("Streaming LLM response: ", end="", flush=True)
        for chunk in llm_stream:
            if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
                delta = chunk.choices[0].delta
                if hasattr(delta, 'content') and delta.content:
                    token = delta.content
                    # print(token, end="", flush=True)
                    streamed_response_text += token
                    yield streamed_response_text
# print("\nLLM Stream finished.") | |
except Exception as e_llm: | |
error_msg = f"Error during LLM inference: {str(e_llm)}" | |
print(error_msg) | |
import traceback | |
traceback.print_exc() | |
streamed_response_text += f"\n{error_msg}" # Append error to existing stream if any | |
yield streamed_response_text | |
print(f"--- RESPOND FUNCTION COMPLETED ---") | |


# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme", title="Serverless TextGen Hub + MCP") as demo:
    gr.Markdown("# Serverless TextGen Hub with MCP Client")
    chatbot = gr.Chatbot(
        label="Chat",
        height=600,
        show_copy_button=True,
        placeholder="Select a model, connect MCP servers (optional), and start chatting!",
        bubble_full_width=False,
        avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-square.png")
    )
    with gr.Row():
        msg_textbox = gr.MultimodalTextbox(  # gr.MultimodalTextbox instead of gr.Textbox so images can be attached
            placeholder="Type a message or upload images... (Use /mcp for commands)",
            show_label=False,
            container=False,
            scale=12,
            file_types=["image"],  # More types (e.g. "audio", "video") could be added if the model supports them
            file_count="multiple"  # Allow multiple image uploads
        )
        # submit_button = gr.Button("Send", variant="primary", scale=1, min_width=100)  # Optional explicit send button

    with gr.Accordion("LLM Settings", open=False):
        system_message_prompt_box = gr.Textbox(
            value="You are a helpful and versatile AI assistant. You can understand text and images. If you have access to MCP tools, you can use them when appropriate or when the user asks.",
            label="System Prompt", lines=3
        )
        with gr.Row():
            with gr.Column(scale=1):
                max_tokens_slider_ui = gr.Slider(minimum=128, maximum=8192, value=1024, step=128, label="Max New Tokens")
                temperature_slider_ui = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Temperature")
                top_p_slider_ui = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (Nucleus Sampling)")
            with gr.Column(scale=1):
                frequency_penalty_slider_ui = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider_ui = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
        providers_list_ui = [
            "hf-inference", "cerebras", "together", "sambanova", "novita",
            "cohere", "fireworks-ai", "hyperbolic", "nebius",
        ]
        provider_radio_ui = gr.Radio(choices=providers_list_ui, value="hf-inference", label="Inference Provider")
        byok_textbox_ui = gr.Textbox(label="Your Hugging Face API Key (Optional)", placeholder="Enter an HF token if using non-hf-inference providers or private models", type="password")
        custom_model_id_box = gr.Textbox(label="Custom Model ID (Overrides selection below)", placeholder="e.g., meta-llama/Llama-3-8B-Instruct")
        model_search_box_ui = gr.Textbox(label="Filter Featured Models", placeholder="Search...", lines=1)
        # A mix of text-only and multimodal models
        featured_models_list_data = [
            "meta-llama/Meta-Llama-3.1-8B-Instruct",  # Good default
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "mistralai/Mistral-Nemo-Instruct-2407",
            "mistralai/Mixtral-8x22B-Instruct-v0.1",
            "Qwen/Qwen2-7B-Instruct",
            "microsoft/Phi-3-medium-128k-instruct",
            # Multimodal
            "Salesforce/blip-image-captioning-large",  # Captioning example; may not support chat
            "llava-hf/llava-1.5-7b-hf",  # LLaVA example
            "microsoft/kosmos-2-patch14-224",  # Kosmos-2
            "google/paligemma-3b-mix-448",  # PaliGemma
        ]
        featured_model_radio_ui = gr.Radio(label="Select a Featured Model", choices=featured_models_list_data, value="meta-llama/Meta-Llama-3.1-8B-Instruct", interactive=True)
        gr.Markdown("Tip: For multimodal chat, make sure the selected model supports image inputs (e.g., LLaVA, PaliGemma, Kosmos-2).")

    with gr.Accordion("MCP Client Settings", open=False):
        mcp_enabled_checkbox_ui = gr.Checkbox(label="Enable MCP Support", value=False, info="Connect to external tools and services via MCP.")
        with gr.Row():
            mcp_server_url_textbox = gr.Textbox(label="MCP Server URL", placeholder="e.g., https://your-mcp-server.hf.space/gradio_api/mcp/sse")
            mcp_server_name_textbox = gr.Textbox(label="Friendly Server Name (Optional)", placeholder="MyTTS_Server")
            mcp_connect_button_ui = gr.Button("Connect", variant="secondary")
        mcp_connection_status_textbox = gr.Textbox(label="MCP Connection Status", placeholder="No MCP servers connected.", interactive=False, lines=2)
        active_mcp_servers_dropdown = gr.Dropdown(
            label="Use Tools From (Select Active MCP Servers)", choices=[], multiselect=True,
            info="Choose which connected servers the LLM can use tools from."
        )
        mcp_interaction_mode_radio = gr.Radio(
            label="MCP Interaction Mode", choices=["Natural Language", "Command Mode"], value="Natural Language",
            info="Natural Language: the AI tries to detect tool use. Command Mode: use the '/mcp ...' syntax."
        )
        gr.Markdown("Example MCP Command: `/mcp MyTTS text_to_audio {\"text\": \"Hello world!\"}`")

    # --- Event Handlers ---
    # History is stored as a list of tuples: [ ((user_text, [user_files]), assistant_response), ... ]
    chat_history_state = gr.State([])

    def user_interaction(user_multimodal_input, current_chat_history):
        user_text = user_multimodal_input["text"] if user_multimodal_input and "text" in user_multimodal_input else ""
        user_files = user_multimodal_input["files"] if user_multimodal_input and "files" in user_multimodal_input else []
        # Only add to the history if there is text or at least one file
        if user_text or user_files:
            current_chat_history.append(((user_text, user_files), None))  # Append the user turn; the assistant response is filled in later
        return current_chat_history, gr.update(value={"text": "", "files": []})  # Clear the input textbox

    def bot_response_generator(
        current_chat_history, system_prompt, max_tokens, temp, top_p_val, freq_penalty, seed_val,
        provider_val, api_key_val, custom_model_val, selected_model_val,  # the model search term is not passed; respond does not use it
        mcp_enabled_val, active_servers_val, mcp_mode_val
    ):
        if not current_chat_history or current_chat_history[-1][1] is not None:  # No user message, or the last turn already has a bot response
            yield current_chat_history
            return
        user_turn_content, _ = current_chat_history[-1]  # Latest user turn: (text, [files])
        message_text, message_files = user_turn_content
        # The history passed to `respond` must contain all turns *before* the current one
        history_for_respond = current_chat_history[:-1]
        response_stream = respond(
            message_text, message_files, history_for_respond,
            system_prompt, max_tokens, temp, top_p_val, freq_penalty, seed_val,
            provider_val, api_key_val, custom_model_val, selected_model_val,
            mcp_enabled_val, active_servers_val, mcp_mode_val
        )
        full_bot_message = ""
        for chunk in response_stream:
            full_bot_message = chunk
            current_chat_history[-1] = (user_turn_content, full_bot_message)  # Update the assistant part of the last turn
            yield current_chat_history
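
    # Streaming note: because the event chain below sets outputs=[chatbot], every
    # `yield current_chat_history` repaints the Chatbot with the partially streamed
    # assistant message, which produces the token-by-token typing effect in the UI.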

    # Link UI components to functions
    msg_textbox.submit(
        user_interaction,
        inputs=[msg_textbox, chat_history_state],
        outputs=[chat_history_state, msg_textbox]  # Update the history and clear the input
    ).then(
        bot_response_generator,
        inputs=[
            chat_history_state, system_message_prompt_box, max_tokens_slider_ui, temperature_slider_ui,
            top_p_slider_ui, frequency_penalty_slider_ui, seed_slider_ui, provider_radio_ui,
            byok_textbox_ui, custom_model_id_box, featured_model_radio_ui,
            mcp_enabled_checkbox_ui, active_mcp_servers_dropdown, mcp_interaction_mode_radio
        ],
        outputs=[chatbot]  # Stream the updated history to the chatbot
    )

    # MCP connection handling
    def handle_mcp_connect(url, name_suggestion):
        if not url or not url.strip():
            return "MCP Server URL cannot be empty.", gr.update(choices=list(mcp_connections.keys()))
        _, status_msg = connect_to_mcp_server(url, name_suggestion)
        # Update the dropdown choices with the current server names
        new_choices = list(mcp_connections.keys())
        # Preserve selected values if they are still valid connections:
        # current_selected = active_mcp_servers_dropdown.value  # This may not work directly from inside the handler
        # new_selected = [s for s in current_selected if s in new_choices]
        return status_msg, gr.update(choices=new_choices)  # , value=new_selected)

    mcp_connect_button_ui.click(
        handle_mcp_connect,
        inputs=[mcp_server_url_textbox, mcp_server_name_textbox],
        outputs=[mcp_connection_status_textbox, active_mcp_servers_dropdown]
    )

    # Model filtering
    def filter_featured_models(search_query):
        if not search_query:
            return gr.update(choices=featured_models_list_data)
        filtered = [m for m in featured_models_list_data if search_query.lower() in m.lower()]
        return gr.update(choices=filtered if filtered else ["No models match your search"])

    model_search_box_ui.change(filter_featured_models, inputs=model_search_box_ui, outputs=featured_model_radio_ui)

    # Fall back to hf-inference if the API key is empty and another provider is chosen
    def validate_api_key_for_provider(api_key_text, current_provider):
        if not api_key_text.strip() and current_provider != "hf-inference":
            gr.Warning("An API key is needed for non-hf-inference providers. Defaulting to hf-inference.")
            return gr.update(value="hf-inference")
        return current_provider  # No change if a key is provided or hf-inference is selected

    byok_textbox_ui.change(validate_api_key_for_provider, inputs=[byok_textbox_ui, provider_radio_ui], outputs=provider_radio_ui)
    provider_radio_ui.change(validate_api_key_for_provider, inputs=[byok_textbox_ui, provider_radio_ui], outputs=provider_radio_ui)


if __name__ == "__main__":
    print("Launching Gradio demo...")
    demo.queue().launch(debug=True, show_api=False)  # This Space acts as an MCP client, not an MCP server