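"""
Gradio chat application that streams responses from Hugging Face inference
providers via `InferenceClient`, accepts text and image input, and can
optionally call tools on connected MCP (Model Context Protocol) servers,
either through /mcp slash commands or natural-language tool detection.
"""
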
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import requests
from smolagents.mcp_client import MCPClient
from mcp import ToolResult  # For type hinting, good practice
from mcp.common.content_block import ValueContentBlock  # To access the actual tool return value
import numpy as np  # For handling audio array
import soundfile as sf  # For converting audio array to WAV

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")


# Function to encode image to base64
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None
    try:
        print(f"Encoding image from path: {image_path}")
        # If it's already a PIL Image
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            # Try to open the image file
            image = Image.open(image_path)
        # Convert to RGB if image has an alpha channel (RGBA)
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        # Encode to base64
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None


# Dictionary to store active MCP connections
mcp_connections = {}
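# Each entry maps a friendly server name to its connection data, e.g.
# {"kokoroTTS": {"client": <MCPClient>, "tools": [...], "url": "https://.../gradio_api/mcp/sse"}}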


def connect_to_mcp_server(server_url, server_name=None):
    """Connect to an MCP server and return available tools"""
    if not server_url:
        return None, "No server URL provided"
    try:
        # Create an MCP client and connect to the server
        client = MCPClient({"url": server_url})
        # Get available tools
        tools = client.get_tools()
        # Store the connection for later use
        name = server_name or f"Server_{len(mcp_connections)}"
        mcp_connections[name] = {"client": client, "tools": tools, "url": server_url}
        return name, f"Successfully connected to {name} with {len(tools)} available tools"
    except Exception as e:
        print(f"Error connecting to MCP server: {e}")
        return None, f"Error connecting to MCP server: {str(e)}"


def list_mcp_tools(server_name):
    """List available tools for a connected MCP server"""
    if server_name not in mcp_connections:
        return "Server not connected"
    tools = mcp_connections[server_name]["tools"]
    tool_info = []
    for tool in tools:
        tool_info.append(f"- {tool.name}: {tool.description}")
    if not tool_info:
        return "No tools available for this server"
    return "\n".join(tool_info)


def call_mcp_tool(server_name, tool_name, **kwargs):
    """Call a specific tool from an MCP server"""
    if server_name not in mcp_connections:
        return {"error": f"Server '{server_name}' not connected"}  # Return dict for consistency
    client_data = mcp_connections[server_name]
    client = client_data["client"]
    server_tools = client_data["tools"]
    # Find the requested tool
    tool = next((t for t in server_tools if t.name == tool_name), None)
    if not tool:
        return {"error": f"Tool '{tool_name}' not found on server '{server_name}'"}
    try:
        # Call the tool with provided arguments
        mcp_tool_result: ToolResult = client.call_tool(tool_name=tool_name, arguments=kwargs)
        actual_result = None
        if mcp_tool_result.content:
            content_block = mcp_tool_result.content[0]
            if isinstance(content_block, ValueContentBlock):
                actual_result = content_block.value
            elif hasattr(content_block, 'text'):  # e.g., TextContentBlock
                actual_result = content_block.text
            else:
                actual_result = str(content_block)  # Fallback
        else:  # No content
            return {"warning": "Tool returned no content."}
        # Special handling for audio results (e.g., from Kokoro TTS).
        # This checks if the result is a tuple (sample_rate, audio_data_list);
        # the Gradio MCP server serializes numpy arrays to lists.
        if (server_name == "kokoroTTS" and tool_name == "text_to_audio" and
                isinstance(actual_result, tuple) and len(actual_result) == 2 and
                isinstance(actual_result[0], int) and
                isinstance(actual_result[1], (list, np.ndarray))):
            print(f"Received audio data from {server_name}.{tool_name}")
            sample_rate, audio_data_list = actual_result
            # Convert list to numpy array if necessary
            audio_data = np.array(audio_data_list)
            # Ensure a dtype soundfile can write (float32 is common, or int16).
            # Kokoro returns floats, likely in the [-1, 1] range.
            if audio_data.dtype != np.float32 and audio_data.dtype != np.int16:
                # Warn if float data does not look normalized to [-1, 1]
                if np.issubdtype(audio_data.dtype, np.floating) and (np.min(audio_data) < -1.1 or np.max(audio_data) > 1.1):
                    print(f"Warning: Audio data for {server_name}.{tool_name} might not be normalized. Min: {np.min(audio_data)}, Max: {np.max(audio_data)}")
                audio_data = audio_data.astype(np.float32)
            wav_io = io.BytesIO()
            sf.write(wav_io, audio_data, sample_rate, format='WAV')
            wav_io.seek(0)
            wav_b64 = base64.b64encode(wav_io.read()).decode('utf-8')
            return {
                "type": "audio_b64",
                "data": wav_b64,
                "message": f"Audio generated by {server_name}.{tool_name}"
            }

        # Handle other types of results
        if isinstance(actual_result, dict):
            return actual_result
        elif isinstance(actual_result, str):
            try:  # If the string is JSON, parse it to a dict
                return json.loads(actual_result)
            except json.JSONDecodeError:
                return {"text": actual_result}  # Wrap raw string
        else:
            return {"value": str(actual_result)}  # Fallback for other primitive types
    except Exception as e:
        print(f"Error calling MCP tool: {e}")
        import traceback
        traceback.print_exc()
        return {"error": f"Error calling MCP tool: {str(e)}"}


def analyze_message_for_tool_call(message, active_mcp_servers, client, model_to_use, system_message):
    """Analyze a message to determine if an MCP tool should be called"""
    if not message or not message.strip():
        return None, None

    tool_info = []
    for server_name in active_mcp_servers:
        if server_name in mcp_connections:
            server_tools_raw = list_mcp_tools(server_name)  # This returns a string
            if server_tools_raw != "Server not connected" and server_tools_raw != "No tools available for this server":
                # Parse the string from list_mcp_tools
                for line in server_tools_raw.split("\n"):
                    if line.startswith("- "):
                        parts = line[2:].split(":", 1)
                        if len(parts) == 2:
                            tool_info.append({
                                "server_name": server_name,
                                "tool_name": parts[0].strip(),
                                "description": parts[1].strip()
                            })
    if not tool_info:
        return None, None

    tools_desc = []
    for info in tool_info:
        tools_desc.append(f"{info['server_name']}.{info['tool_name']}: {info['description']}")
    tools_string = "\n".join(tools_desc)

    analysis_system_prompt = f"""You are an assistant that helps determine if a user message requires using an external tool.

Available tools:
{tools_string}

Your job is to:
1. Analyze the user's message.
2. Determine if they're asking to use one of the tools.
3. If yes, respond ONLY with a JSON object with "server_name", "tool_name", and "parameters".
4. If no, respond ONLY with the exact string "NO_TOOL_NEEDED".

Example 1 (User wants TTS):
User: "Please turn this text into speech: Hello world"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio", "parameters": {{"text": "Hello world", "speed": 1.0}}}}

Example 2 (User wants TTS with a different server name):
User: "Use mySpeechTool to say 'good morning'"
Response: {{"server_name": "mySpeechTool", "tool_name": "text_to_audio", "parameters": {{"text": "good morning"}}}}

Example 3 (User does not want a tool):
User: "What is the capital of France?"
Response: NO_TOOL_NEEDED"""

    try:
        response = client.chat_completion(
            model=model_to_use,
            messages=[
                {"role": "system", "content": analysis_system_prompt},
                {"role": "user", "content": message}
            ],
            temperature=0.1,
            max_tokens=300
        )
        analysis = response.choices[0].message.content.strip()
        print(f"Tool analysis LLM response: '{analysis}'")

        if analysis == "NO_TOOL_NEEDED":
            return None, None

        try:
            tool_call = json.loads(analysis)
            if isinstance(tool_call, dict) and "server_name" in tool_call and "tool_name" in tool_call:
                return tool_call.get("server_name"), {
                    "tool_name": tool_call.get("tool_name"),
                    "parameters": tool_call.get("parameters", {})
                }
            else:
                print(f"LLM response for tool call was not a valid JSON with required keys: {analysis}")
                return None, None
        except json.JSONDecodeError:
            print(f"Failed to parse tool call JSON from LLM: {analysis}")
            return None, None
    except Exception as e:
        print(f"Error analyzing message for tool calls: {str(e)}")
        return None, None
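# On success this returns a (server_name, tool_request) pair, e.g.
#   ("kokoroTTS", {"tool_name": "text_to_audio", "parameters": {"text": "Hello world", "speed": 1.0}})
# and (None, None) when no tool is needed or the LLM output cannot be parsed.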


def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
    mcp_enabled=False,
    active_mcp_servers=None,
    mcp_interaction_mode="Natural Language"
):
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    # print(f"History: {history}")  # Can be verbose
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")
    print(f"MCP enabled: {mcp_enabled}")
    print(f"Active MCP servers: {active_mcp_servers}")
    print(f"MCP interaction mode: {mcp_interaction_mode}")

    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")

    client = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")

    if seed == -1:
        seed = None

    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    # Handle MCP commands and natural-language tool requests before falling back to the LLM
    if mcp_enabled and message:
        if message.startswith("/mcp"):
            command_parts = message.split(" ", 3)
            if len(command_parts) < 3:
                yield "Invalid MCP command. Format: /mcp <server_name> <tool_name> [arguments_json]"
                return
            _, server_name, tool_name = command_parts[:3]
            args_json = "{}" if len(command_parts) < 4 else command_parts[3]
            try:
                args_dict = json.loads(args_json)
                result = call_mcp_tool(server_name, tool_name, **args_dict)
                if isinstance(result, dict) and result.get("type") == "audio_b64":
                    yield f"<audio controls src=\"data:audio/wav;base64,{result.get('data')}\"></audio>"
                elif isinstance(result, dict) and "error" in result:
                    yield f"Error: {result.get('error')}"
                elif isinstance(result, dict):
                    yield json.dumps(result, indent=2)
                else:
                    yield str(result)
                return
            except json.JSONDecodeError:
                yield f"Invalid JSON arguments: {args_json}"
                return
            except Exception as e:
                yield f"Error executing MCP command: {str(e)}"
                return
        elif mcp_interaction_mode == "Natural Language" and active_mcp_servers:
            print("Attempting natural language tool call detection...")
            server_name, tool_info = analyze_message_for_tool_call(
                message, active_mcp_servers, client, model_to_use, system_message
            )
            if server_name and tool_info and tool_info.get("tool_name"):
                try:
                    print(f"Calling tool via natural language: {server_name}.{tool_info['tool_name']} with parameters: {tool_info['parameters']}")
                    result = call_mcp_tool(server_name, tool_info['tool_name'], **tool_info.get('parameters', {}))
                    response_message = f"I used the **{tool_info['tool_name']}** tool from **{server_name}**."
                    if isinstance(result, dict) and result.get("message"):
                        response_message += f" ({result.get('message')})"
                    response_message += "\n\n"
                    if isinstance(result, dict) and result.get("type") == "audio_b64":
                        audio_html = f"<audio controls src=\"data:audio/wav;base64,{result.get('data')}\"></audio>"
                        yield response_message + audio_html
                    elif isinstance(result, dict) and "error" in result:
                        result_str = f"Tool Error: {result.get('error')}"
                        yield response_message + result_str
                    elif isinstance(result, dict):
                        result_str = f"Result:\n```json\n{json.dumps(result, indent=2)}\n```"
                        yield response_message + result_str
                    else:
                        result_str = f"Result:\n{str(result)}"
                        yield response_message + result_str
                    return
                except Exception as e:
                    print(f"Error executing MCP tool via natural language: {str(e)}")
                    # yield f"Sorry, I encountered an error trying to use the tool: {str(e)}"
                    # Fall through to normal LLM response if tool call fails here
            else:
                print("No tool call detected by natural language analysis or tool_info incomplete.")

    # Build the current user turn (text and any images) as OpenAI-style content parts
    user_content_parts = []
    if message and message.strip():
        user_content_parts.append({"type": "text", "text": message})
    if image_files and len(image_files) > 0:
        for img_path in image_files:
            if img_path:
                try:
                    encoded_image = encode_image(img_path)
                    if encoded_image:
                        user_content_parts.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
                        })
                except Exception as e:
                    print(f"Error encoding image {img_path}: {e}")

    if not user_content_parts:
        # If the message was only an /mcp command, it has already been handled above
        # (those paths yield and return), so make sure we don't fall through to the LLM here.
        print("No further content for LLM after MCP command processing.")
        if message and message.startswith("/mcp"):
            return
    # Use a plain string when the turn is text-only; otherwise send the list of content parts
    # (a single bare content-part dict as `content` is non-standard, so avoid it).
    if len(user_content_parts) == 1 and user_content_parts[0]["type"] == "text":
        final_user_content = user_content_parts[0]["text"]
    else:
        final_user_content = user_content_parts if user_content_parts else ""

    augmented_system_message = system_message
    if mcp_enabled and active_mcp_servers:
        tool_list_for_prompt = []
        for server_name_iter in active_mcp_servers:
            if server_name_iter in mcp_connections:
                server_tools_str = list_mcp_tools(server_name_iter)
                if server_tools_str and "not connected" not in server_tools_str and "No tools available" not in server_tools_str:
                    tool_list_for_prompt.append(f"From server '{server_name_iter}':\n{server_tools_str}")
        if tool_list_for_prompt:
            mcp_tools_description = "\n\n".join(tool_list_for_prompt)
            if mcp_interaction_mode == "Command Mode":
                augmented_system_message += f"\n\nYou have access to the following MCP tools. To use them, type a command in the format: /mcp <server_name> <tool_name> <arguments_json>\nTools:\n{mcp_tools_description}"
            else:  # Natural Language
                augmented_system_message += f"\n\nYou have access to the following MCP tools. You can ask to use them in natural language, and I will try to detect when a tool is needed. If I miss it, you can try being more explicit about the tool name.\nTools:\n{mcp_tools_description}"

    messages_for_api = [{"role": "system", "content": augmented_system_message}]
    print("Initial messages array constructed.")

    for val in history:
        past_user_msg, past_assistant_msg = val
        # Handle past user messages (could be text or multimodal)
        if past_user_msg:
            if isinstance(past_user_msg, list):  # Already multimodal
                messages_for_api.append({"role": "user", "content": past_user_msg})
            elif isinstance(past_user_msg, str):  # Text only
                messages_for_api.append({"role": "user", "content": past_user_msg})
        if past_assistant_msg:
            messages_for_api.append({"role": "assistant", "content": past_assistant_msg})

    if final_user_content:  # Add current user message if it exists
        messages_for_api.append({"role": "user", "content": final_user_content})
        print(f"Latest user message appended (content type: {type(final_user_content)})")
    # print(f"Full messages_for_api: {json.dumps(messages_for_api, indent=2)}")  # Can be very verbose
llm_response_text = "" | |
print(f"Sending request to {provider} provider for model {model_to_use}.") | |
parameters = { | |
"max_tokens": max_tokens, | |
"temperature": temperature, | |
"top_p": top_p, | |
"frequency_penalty": frequency_penalty, | |
} | |
if seed is not None: | |
parameters["seed"] = seed | |
try: | |
stream = client.chat_completion( | |
model=model_to_use, | |
messages=messages_for_api, | |
stream=True, | |
**parameters | |
) | |
# print("Received tokens: ", end="", flush=True) # Can be too noisy | |
for chunk in stream: | |
if hasattr(chunk, 'choices') and len(chunk.choices) > 0: | |
if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'): | |
token_text = chunk.choices[0].delta.content | |
if token_text: | |
# print(token_text, end="", flush=True) # Can be too noisy | |
llm_response_text += token_text | |
yield llm_response_text | |
# print() # Newline after tokens | |
except Exception as e: | |
print(f"Error during LLM inference: {e}") | |
llm_response_text += f"\nLLM Error: {str(e)}" | |
yield llm_response_text | |
print("Completed LLM response generation.") | |


# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Supports multiple inference providers, multimodal inputs, and MCP tools.",
        layout="panel",
        show_label=False,
        render=False  # Delay rendering
    )
    print("Chatbot interface created.")

    with gr.Row():
        msg = gr.MultimodalTextbox(
            placeholder="Type a message or upload images...",
            show_label=False,
            container=False,
            scale=12,
            file_types=["image"],
            file_count="multiple",
            sources=["upload"],
            render=False  # Delay rendering
        )

    chatbot.render()
    msg.render()
with gr.Accordion("Settings", open=False): | |
system_message_box = gr.Textbox( | |
value="You are a helpful AI assistant that can understand images and text. If the user asks you to use a tool, try your best.", | |
placeholder="You are a helpful assistant.", | |
label="System Prompt" | |
) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
max_tokens_slider = gr.Slider(minimum=1, maximum=8192, value=1024, step=1, label="Max tokens") | |
temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature") | |
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P") | |
with gr.Column(scale=1): | |
frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty") | |
seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)") | |
providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"] | |
provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider") | |
byok_textbox = gr.Textbox(value="", label="BYOK (Bring Your Own Key)", info="Enter a custom Hugging Face API key here. If empty, only 'hf-inference' provider can be used with the shared token.", placeholder="Enter your Hugging Face API token", type="password") | |
custom_model_box = gr.Textbox(value="", label="Custom Model ID", info="(Optional) Provide a custom Hugging Face model ID. Overrides selected featured model.", placeholder="meta-llama/Llama-3.1-70B-Instruct") | |
model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1) | |
models_list = [ | |
"meta-llama/Llama-3.1-405B-Instruct-FP8", # Large model, might be slow/expensive | |
"meta-llama/Llama-3.1-70B-Instruct", | |
"meta-llama/Llama-3.1-8B-Instruct", | |
"mistralai/Mistral-Nemo-Instruct-2407", | |
"Qwen/Qwen2-72B-Instruct", | |
"Qwen/Qwen2-57B-A14B-Instruct", | |
"CohereForAI/c4ai-command-r-plus", | |
# Multimodal models | |
"Salesforce/LlavaLlama3-8b-hf", | |
"llava-hf/llava-v1.6-mistral-7b-hf", | |
"llava-hf/llava-v1.6-vicuna-13b-hf", | |
"microsoft/Phi-3-vision-128k-instruct", | |
"google/paligemma-3b-mix-448", | |
# Older but still popular | |
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", | |
"mistralai/Mixtral-8x7B-Instruct-v0.1", | |
"mistralai/Mistral-7B-Instruct-v0.3", | |
] | |
featured_model_radio = gr.Radio(label="Select a Featured Model", choices=models_list, value="meta-llama/Llama-3.1-8B-Instruct", interactive=True) | |
gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?pipeline_tag=image-to-text&sort=trending)") | |
with gr.Accordion("MCP Settings", open=False): | |
mcp_enabled_checkbox = gr.Checkbox(label="Enable MCP Support", value=False, info="Enable Model Context Protocol support to connect to external tools and services") | |
with gr.Row(): | |
mcp_server_url = gr.Textbox(label="MCP Server URL", placeholder="https://your-mcp-server.hf.space/gradio_api/mcp/sse", info="URL of the MCP server (usually ends with /gradio_api/mcp/sse for Gradio MCP servers)") | |
mcp_server_name = gr.Textbox(label="Server Name (Optional)", placeholder="e.g., kokoroTTS", info="A friendly name to identify this server") | |
mcp_connect_button = gr.Button("Connect to MCP Server") | |
mcp_status = gr.Textbox(label="MCP Connection Status", placeholder="No MCP servers connected", interactive=False) | |
active_mcp_servers = gr.Dropdown(label="Active MCP Servers for Chat", choices=[], multiselect=True, info="Select which connected MCP servers to make available to the LLM for this chat session") | |
mcp_mode = gr.Radio(label="MCP Interaction Mode", choices=["Natural Language", "Command Mode"], value="Natural Language", info="Choose how to interact with MCP tools") | |
gr.Markdown(""" | |
### MCP Interaction Modes & Examples | |
**Natural Language Mode**: Describe what you want. | |
`Please say 'Hello world' using the kokoroTTS server.` | |
`Use my speech tool to read this: "Welcome"` | |
**Command Mode**: Use structured commands (server name must match connected server's friendly name). | |
`/mcp <server_name> <tool_name> {"param1": "value1"}` | |
Example: `/mcp kokoroTTS text_to_audio {"text": "Hello world", "speed": 1.0}` | |
""") | |

    # Chat history state
    # The chatbot component itself manages history for display.
    # The `respond` function receives this display history and reconstructs API history.

    def filter_models_ui_update(search_term):
        print(f"Filtering models with search term: {search_term}")
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        if not filtered:  # If search yields no results, show all models
            filtered = models_list
        print(f"Filtered models: {filtered}")
        return gr.Radio(choices=filtered, label="Select a Featured Model", value=featured_model_radio.value if featured_model_radio.value in filtered else (filtered[0] if filtered else None))

    def set_custom_model_from_radio_ui_update(selected_featured_model):
        print(f"Featured model selected: {selected_featured_model}")
        return selected_featured_model  # This updates the custom_model_box

    def connect_mcp_server_ui_update(url, name_optional):
        actual_name, status_msg = connect_to_mcp_server(url, name_optional)
        updated_server_choices = list(mcp_connections.keys())
        # Keep existing selection if possible
        current_selection = active_mcp_servers.value if active_mcp_servers.value else []
        valid_selection = [s for s in current_selection if s in updated_server_choices]
        if actual_name and actual_name not in valid_selection:  # Auto-select newly connected server
            valid_selection.append(actual_name)
        return status_msg, gr.Dropdown(choices=updated_server_choices, value=valid_selection, label="Active MCP Servers for Chat")

    # This function processes the user's multimodal input and adds it to the chatbot history,
    # preparing the history in a form that the bot responder can understand.
    def handle_user_input(multimodal_input, history_list: list):
        text_content = multimodal_input.get("text", "").strip()
        files = multimodal_input.get("files", [])

        # The user's turn is kept in two forms: API content parts and a display string
        user_turn_for_api = []
        user_turn_for_display = ""
        if text_content:
            user_turn_for_api.append({"type": "text", "text": text_content})
            user_turn_for_display = text_content

        if files:
            display_files_md = ""
            for file_path in files:
                if file_path and isinstance(file_path, str):  # Gradio provides a temp path
                    encoded_img = encode_image(file_path)  # For the API
                    if encoded_img:
                        user_turn_for_api.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}})
                        # For display, Gradio handles showing the image from the MultimodalTextbox output;
                        # we just make a note in the display string (Gradio can render this <img> tag)
                        display_files_md += f"\n<img src='file={file_path}' style='max-height:150px; display:block;' alt='uploaded image'>"
            if user_turn_for_display:
                user_turn_for_display += display_files_md
            else:
                user_turn_for_display = display_files_md if display_files_md else "Image(s) uploaded"

        if not user_turn_for_display and not user_turn_for_api:  # Empty input
            return history_list, multimodal_input  # No change

        # The chatbot's display history stores [user_display_content, assistant_text]; the
        # API-formatted user turn is returned separately and passed to the bot responder
        # as the current message.
        history_list.append([user_turn_for_display, None])
        return history_list, user_turn_for_api  # Updated display history and the API-formatted current message
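    # Hypothetical example: a turn with text plus one image produces roughly
    #   history_list[-1] == ["<user text>\n<img src='file=/tmp/...' ...>", None]
    #   user_turn_for_api == [{"type": "text", "text": "<user text>"},
    #                         {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]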

    # The bot function that drives the `respond` generator and streams its output into the chat
    def call_bot_responder(history_list_for_display, current_user_api_content, sys_msg, max_tok, temp, top_p_val, freq_pen, seed_val, prov, api_key_val, cust_model, _search, sel_model, mcp_on, active_servs, mcp_inter_mode):
        if not current_user_api_content and not (history_list_for_display and history_list_for_display[-1][0]):
            print("Bot called with no current message and no history, skipping.")
            yield history_list_for_display  # No change
            return

        # `respond` expects the current turn as text plus image paths, and past turns as
        # [user_text, assistant_text] pairs. The chatbot history holds user turns in *display*
        # format (text plus <img> tags), not API format, so only the text of past turns is
        # passed on here. A more robust solution would store the API-formatted history separately.
        simplified_past_history = []
        if len(history_list_for_display) > 1:  # Exclude the current turn
            for user_disp, assistant_text in history_list_for_display[:-1]:
                user_text_for_hist = user_disp
                if isinstance(user_disp, str) and "<img src" in user_disp:  # Strip inline image tags from the display text
                    lines = user_disp.splitlines()
                    text_lines = [line for line in lines if not line.strip().startswith("<img")]
                    user_text_for_hist = "\n".join(text_lines).strip() if text_lines else ""
                simplified_past_history.append([user_text_for_hist, assistant_text])

        # Extract the current turn's text from the API-formatted content prepared by
        # `handle_user_input`. Its images are already base64-encoded, so their original file
        # paths are no longer available; we therefore pass image_files=None below. This is a
        # simplification, and how `respond` should receive already-encoded images needs review.
        _current_text_for_respond = ""
        if isinstance(current_user_api_content, list):  # Multimodal content parts
            for item in current_user_api_content:
                if item['type'] == 'text':
                    _current_text_for_respond = item['text']
        elif isinstance(current_user_api_content, str):  # Text only
            _current_text_for_respond = current_user_api_content

        bot_response_stream = respond(
            message=_current_text_for_respond,  # Current text
            image_files=None,  # See note above: current images are not re-sent as paths
            history=simplified_past_history,  # Past turns
            system_message=sys_msg,
            max_tokens=max_tok,
            temperature=temp,
            top_p=top_p_val,
            frequency_penalty=freq_pen,
            seed=seed_val,
            provider=prov,
            custom_api_key=api_key_val,
            custom_model=cust_model,
            model_search_term="",  # Not directly used by respond
            selected_model=sel_model,
            mcp_enabled=mcp_on,
            active_mcp_servers=active_servs,
            mcp_interaction_mode=mcp_inter_mode
        )

        for response_chunk in bot_response_stream:
            history_list_for_display[-1][1] = response_chunk
            yield history_list_for_display

    # This state will hold the API-formatted content of the current user message
    current_api_message_state = gr.State(None)

    msg.submit(
        handle_user_input,
        [msg, chatbot],  # chatbot here is the history_list
        [chatbot, current_api_message_state]  # Update history display and current_api_message_state
    ).then(
        call_bot_responder,
        [chatbot, current_api_message_state, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio, mcp_enabled_checkbox, active_mcp_servers, mcp_mode],
        [chatbot]  # Update chatbot display with streaming response
    ).then(
        lambda: gr.MultimodalTextbox(value={"text": "", "files": []}),  # Clear MultimodalTextbox
        None,
        [msg]
    )
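    # Submit flow: handle_user_input records the user's turn (display + API form),
    # call_bot_responder streams the assistant reply into the last history entry,
    # and the final step clears the multimodal input box.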

    mcp_connect_button.click(
        connect_mcp_server_ui_update,
        [mcp_server_url, mcp_server_name],
        [mcp_status, active_mcp_servers]
    )

    model_search_box.change(fn=filter_models_ui_update, inputs=model_search_box, outputs=featured_model_radio)
    featured_model_radio.change(fn=set_custom_model_from_radio_ui_update, inputs=featured_model_radio, outputs=custom_model_box)

    def validate_provider_ui_update(api_key, current_provider):
        if not api_key.strip() and current_provider != "hf-inference":
            gr.Info("No API key provided. Defaulting to 'hf-inference' provider.")
            return gr.Radio(value="hf-inference")  # Update provider_radio
        return gr.Radio(value=current_provider)  # No change needed or keep current

    byok_textbox.change(fn=validate_provider_ui_update, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider_ui_update, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")


if __name__ == "__main__":
    print("Launching the demo application.")
    demo.queue().launch(show_api=False, debug=False)  # mcp_server=False as this is a client