# Source: Hugging Face Space file "app.py" (uploaded by Nymbo)
# Commit: 717cd1f (verified) - "Update app.py"
# File size: 39.4 kB
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
import requests
from smolagents.mcp_client import MCPClient
from mcp import ToolResult # For type hinting, good practice
from mcp.common.content_block import ValueContentBlock # To access the actual tool return value
import numpy as np # For handling audio array
import soundfile as sf # For converting audio array to WAV
# Shared Hugging Face token read from the environment. `respond` passes this
# value (possibly None) to InferenceClient when the user supplies no custom key.
ACCESS_TOKEN = os.getenv("HF_TOKEN")
if ACCESS_TOKEN:
    print("Access token loaded.")
else:
    # Do not claim success when the variable is missing or empty.
    print("Warning: HF_TOKEN environment variable is not set; a custom API key will be required.")
def encode_image(image_path):
    """Encode an image as a base64 JPEG string.

    Args:
        image_path: A path accepted by ``Image.open`` or an already-loaded
            ``PIL.Image.Image`` instance.

    Returns:
        The base64 text of the JPEG-encoded image, or ``None`` when no input
        was given or the image could not be opened or encoded.
    """
    if not image_path:
        print("No image path provided")
        return None

    try:
        print(f"Encoding image from path: {image_path}")
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            image = Image.open(image_path)

        # JPEG cannot be written from alpha, palette or high-bit-depth modes
        # (RGBA, LA, P, I;16, ...). Convert anything that is not plain RGB or
        # greyscale instead of only special-casing RGBA.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")

        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        # Best-effort helper: callers treat None as "skip this image".
        print(f"Error encoding image: {e}")
        return None
# Dictionary to store active MCP connections, keyed by server name.
# Each value is a dict with "client", "tools" and "url" entries
# (populated by connect_to_mcp_server).
mcp_connections: dict[str, dict] = {}
def _close_mcp_client(client):
    """Best-effort shutdown of an MCP client; never raises."""
    # NOTE(review): assumes the client exposes ``disconnect()`` (as smolagents'
    # MCPClient documents). If it does not, this is a silent no-op.
    disconnect = getattr(client, "disconnect", None)
    if callable(disconnect):
        try:
            disconnect()
        except Exception as e:
            print(f"Error disconnecting MCP client: {e}")


def connect_to_mcp_server(server_url, server_name=None):
    """Connect to an MCP server and register it in ``mcp_connections``.

    Args:
        server_url: URL handed to ``MCPClient`` as ``{"url": server_url}``.
        server_name: Optional friendly name. Blank names are replaced by a
            generated ``Server_<n>`` name that is not already registered.

    Returns:
        ``(name, status_message)`` on success, or ``(None, error_message)``
        when no URL was given or the connection failed.
    """
    if not server_url:
        return None, "No server URL provided"

    client = None
    try:
        client = MCPClient({"url": server_url})
        tools = client.get_tools()

        name = (server_name or "").strip()
        if not name:
            # Start from the historical index, but never overwrite an entry
            # that a user explicitly named "Server_<n>".
            index = len(mcp_connections)
            name = f"Server_{index}"
            while name in mcp_connections:
                index += 1
                name = f"Server_{index}"

        status = f"Successfully connected to {name} with {len(tools)} available tools"

        # Reconnecting under an existing name replaces the old entry; close the
        # old client first so it is not left dangling.
        previous = mcp_connections.get(name)
        if previous is not None:
            _close_mcp_client(previous.get("client"))

        mcp_connections[name] = {"client": client, "tools": tools, "url": server_url}
        return name, status
    except Exception as e:
        print(f"Error connecting to MCP server: {e}")
        # The client was never registered, so nobody else can clean it up.
        _close_mcp_client(client)
        return None, f"Error connecting to MCP server: {str(e)}"
def list_mcp_tools(server_name):
    """Return a text listing of the tools registered for ``server_name``.

    The listing has one ``- <name>: <description>`` line per tool. A fixed
    message is returned instead when the server is unknown or has no tools.
    """
    if server_name not in mcp_connections:
        return "Server not connected"

    listing = [
        f"- {entry.name}: {entry.description}"
        for entry in mcp_connections[server_name]["tools"]
    ]
    return "\n".join(listing) if listing else "No tools available for this server"
def _audio_to_wav_b64(sample_rate, samples, label):
    """Encode raw audio samples as a base64 WAV string.

    ``label`` is only used in the diagnostic message printed when float
    samples fall outside roughly [-1, 1].
    """
    audio_data = np.array(samples)
    # soundfile is fed float32 or int16; everything else is cast to float32.
    if audio_data.dtype != np.float32 and audio_data.dtype != np.int16:
        if (audio_data.size and np.issubdtype(audio_data.dtype, np.floating)
                and (np.min(audio_data) < -1.1 or np.max(audio_data) > 1.1)):
            print(f"Warning: Audio data for {label} might not be normalized. Min: {np.min(audio_data)}, Max: {np.max(audio_data)}")
        audio_data = audio_data.astype(np.float32)
    wav_io = io.BytesIO()
    sf.write(wav_io, audio_data, sample_rate, format='WAV')
    wav_io.seek(0)
    return base64.b64encode(wav_io.read()).decode('utf-8')


def call_mcp_tool(server_name, tool_name, **kwargs):
    """Call a tool on a connected MCP server and normalise its result.

    Args:
        server_name: Key of the connection in ``mcp_connections``.
        tool_name: Name of the tool to invoke on that server.
        **kwargs: Arguments forwarded to the tool.

    Returns:
        Always a dict:
        - ``{"error": ...}`` for an unknown server/tool or a failed call,
        - ``{"warning": ...}`` when the tool returned no content,
        - ``{"type": "audio_b64", "data": ..., "message": ...}`` for the
          kokoroTTS ``text_to_audio`` result,
        - the tool's own dict, or ``{"text": ...}`` / ``{"value": ...}``
          wrappers for anything else.
    """
    if server_name not in mcp_connections:
        return {"error": f"Server '{server_name}' not connected"}  # Return dict for consistency

    client_data = mcp_connections[server_name]
    client = client_data["client"]

    # Only the existence of the tool is checked here; the call goes through the client.
    if not any(t.name == tool_name for t in client_data["tools"]):
        return {"error": f"Tool '{tool_name}' not found on server '{server_name}'"}

    try:
        # NOTE(review): confirm that the installed client exposes `call_tool`
        # and that `ToolResult` / `ValueContentBlock` are importable from the
        # installed `mcp` package - neither can be verified from this file.
        mcp_tool_result: ToolResult = client.call_tool(tool_name=tool_name, arguments=kwargs)

        if not mcp_tool_result.content:
            return {"warning": "Tool returned no content."}

        # Only the first content block is inspected.
        content_block = mcp_tool_result.content[0]
        if isinstance(content_block, ValueContentBlock):
            actual_result = content_block.value
        elif hasattr(content_block, 'text'):  # e.g., TextContentBlock
            actual_result = content_block.text
        else:
            actual_result = str(content_block)  # Fallback

        # Special handling for audio result (e.g., from Kokoro TTS):
        # a (sample_rate, audio_samples) pair. A JSON transport delivers such a
        # pair as a list, so both tuple and list are accepted.
        if (server_name == "kokoroTTS" and tool_name == "text_to_audio"
                and isinstance(actual_result, (tuple, list)) and len(actual_result) == 2
                and isinstance(actual_result[0], int)
                and isinstance(actual_result[1], (list, np.ndarray))):
            print(f"Received audio data from {server_name}.{tool_name}")
            sample_rate, audio_samples = actual_result
            return {
                "type": "audio_b64",
                "data": _audio_to_wav_b64(sample_rate, audio_samples, f"{server_name}.{tool_name}"),
                "message": f"Audio generated by {server_name}.{tool_name}"
            }

        # Handle other types of results
        if isinstance(actual_result, dict):
            return actual_result
        if isinstance(actual_result, str):
            try:
                parsed = json.loads(actual_result)
            except json.JSONDecodeError:
                return {"text": actual_result}  # Wrap raw string
            # JSON text may decode to a list/number/bool; keep the dict contract.
            return parsed if isinstance(parsed, dict) else {"value": parsed}
        return {"value": str(actual_result)}  # Fallback for other primitive types
    except Exception as e:
        print(f"Error calling MCP tool: {e}")
        import traceback
        traceback.print_exc()
        return {"error": f"Error calling MCP tool: {str(e)}"}
def analyze_message_for_tool_call(message, active_mcp_servers, client, model_to_use, system_message):
    """Ask the LLM whether ``message`` should be served by an MCP tool.

    Args:
        message: The user's text for the current turn.
        active_mcp_servers: Names of the servers whose tools may be offered
            (``None`` is treated as empty).
        client: Object exposing ``chat_completion`` (the inference client).
        model_to_use: Model id passed to ``chat_completion``.
        system_message: Unused; kept for interface compatibility.

    Returns:
        ``(server_name, {"tool_name": ..., "parameters": {...}})`` when the
        model selects a tool, otherwise ``(None, None)``. Errors from the LLM
        call are logged and also reported as ``(None, None)``.
    """
    if not message or not message.strip():
        return None, None

    # Read tool metadata straight from the registry instead of re-parsing the
    # human-readable listing, which breaks on multi-line descriptions.
    tools_desc = []
    for server_name in active_mcp_servers or []:
        connection = mcp_connections.get(server_name)
        if not connection:
            continue
        for tool in connection["tools"]:
            # Collapse whitespace so each tool stays on a single prompt line.
            description = " ".join(str(tool.description).split())
            tools_desc.append(f"{server_name}.{tool.name}: {description}")

    if not tools_desc:
        return None, None

    tools_string = "\n".join(tools_desc)
    analysis_system_prompt = f"""You are an assistant that helps determine if a user message requires using an external tool.
Available tools:
{tools_string}
Your job is to:
1. Analyze the user's message.
2. Determine if they're asking to use one of the tools.
3. If yes, respond ONLY with a JSON object with "server_name", "tool_name", and "parameters".
4. If no, respond ONLY with the exact string "NO_TOOL_NEEDED".
Example 1 (User wants TTS):
User: "Please turn this text into speech: Hello world"
Response: {{"server_name": "kokoroTTS", "tool_name": "text_to_audio", "parameters": {{"text": "Hello world", "speed": 1.0}}}}
Example 2 (User wants TTS with different server name):
User: "Use mySpeechTool to say 'good morning'"
Response: {{"server_name": "mySpeechTool", "tool_name": "text_to_audio", "parameters": {{"text": "good morning"}}}}
Example 3 (User does not want a tool):
User: "What is the capital of France?"
Response: NO_TOOL_NEEDED"""

    try:
        response = client.chat_completion(
            model=model_to_use,
            messages=[
                {"role": "system", "content": analysis_system_prompt},
                {"role": "user", "content": message}
            ],
            temperature=0.1,
            max_tokens=300
        )
        # The content may be None; treat that like an empty answer.
        analysis = (response.choices[0].message.content or "").strip()
        print(f"Tool analysis LLM response: '{analysis}'")

        if not analysis or analysis == "NO_TOOL_NEEDED":
            return None, None

        try:
            tool_call = json.loads(analysis)
        except json.JSONDecodeError:
            # Models often wrap the object in a ```json fence or add prose;
            # retry on the outermost {...} span before giving up.
            start, end = analysis.find("{"), analysis.rfind("}")
            if start == -1 or end <= start:
                print(f"Failed to parse tool call JSON from LLM: {analysis}")
                return None, None
            try:
                tool_call = json.loads(analysis[start:end + 1])
            except json.JSONDecodeError:
                print(f"Failed to parse tool call JSON from LLM: {analysis}")
                return None, None

        if isinstance(tool_call, dict) and "server_name" in tool_call and "tool_name" in tool_call:
            parameters = tool_call.get("parameters", {})
            if not isinstance(parameters, dict):
                # Callers expand this with **, so it must be a mapping.
                parameters = {}
            return tool_call.get("server_name"), {
                "tool_name": tool_call.get("tool_name"),
                "parameters": parameters
            }

        print(f"LLM response for tool call was not a valid JSON with required keys: {analysis}")
        return None, None
    except Exception as e:
        print(f"Error analyzing message for tool calls: {str(e)}")
        return None, None
def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model,
    mcp_enabled=False,
    active_mcp_servers=None,
    mcp_interaction_mode="Natural Language"
):
    """Generate the assistant reply for one chat turn (generator).

    When MCP is enabled the message is first checked for an explicit
    ``/mcp <server_name> <tool_name> [arguments_json]`` command and, in
    "Natural Language" mode, for an LLM-detected tool request. A handled tool
    call yields a single formatted result and ends the turn. Otherwise the
    message (plus any images) is sent to the chat-completion endpoint and the
    accumulated response text is yielded after every streamed token.

    Args:
        message: Text of the current user turn.
        image_files: Paths (or PIL images) to attach to the current turn.
        history: Past ``(user, assistant)`` turns.
        system_message: Base system prompt.
        max_tokens, temperature, top_p, frequency_penalty: Sampling settings.
        seed: Sampling seed; ``-1`` means "do not send a seed".
        provider: Inference provider name for ``InferenceClient``.
        custom_api_key: User token; falls back to ``ACCESS_TOKEN`` when blank.
        custom_model: Model id that overrides ``selected_model`` when set.
        model_search_term: Only logged.
        selected_model: Model id chosen in the UI.
        mcp_enabled: Whether MCP handling is active.
        active_mcp_servers: Server names made available for this chat.
        mcp_interaction_mode: "Natural Language" or "Command Mode".

    Yields:
        Strings: either a tool result / error message, or the progressively
        growing LLM response.
    """
    # The UI may hand over None for blank text inputs; normalise once.
    custom_api_key = custom_api_key or ""
    custom_model = custom_model or ""

    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")
    print(f"MCP enabled: {mcp_enabled}")
    print(f"Active MCP servers: {active_mcp_servers}")
    print(f"MCP interaction mode: {mcp_interaction_mode}")

    if custom_api_key.strip() != "":
        token_to_use = custom_api_key
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        token_to_use = ACCESS_TOKEN
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")

    client = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")

    if seed == -1:
        seed = None

    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    if mcp_enabled and message:
        if message.startswith("/mcp"):
            # Split on any whitespace run so repeated spaces do not produce
            # empty server/tool names; the 4th part keeps the raw JSON.
            command_parts = message.split(None, 3)
            if len(command_parts) < 3:
                yield "Invalid MCP command. Format: /mcp <server_name> <tool_name> [arguments_json]"
                return

            _, server_name, tool_name = command_parts[:3]
            args_json = "{}" if len(command_parts) < 4 else command_parts[3]
            try:
                args_dict = json.loads(args_json)
                if not isinstance(args_dict, dict):
                    # Arguments are expanded with **, so they must be an object.
                    yield f"Invalid JSON arguments: {args_json}"
                    return
                result = call_mcp_tool(server_name, tool_name, **args_dict)
                if isinstance(result, dict) and result.get("type") == "audio_b64":
                    yield f"<audio controls src=\"data:audio/wav;base64,{result.get('data')}\"></audio>"
                elif isinstance(result, dict) and "error" in result:
                    yield f"Error: {result.get('error')}"
                elif isinstance(result, dict):
                    yield json.dumps(result, indent=2)
                else:
                    yield str(result)
                return
            except json.JSONDecodeError:
                yield f"Invalid JSON arguments: {args_json}"
                return
            except Exception as e:
                yield f"Error executing MCP command: {str(e)}"
                return
        elif mcp_interaction_mode == "Natural Language" and active_mcp_servers:
            print("Attempting natural language tool call detection...")
            server_name, tool_info = analyze_message_for_tool_call(
                message, active_mcp_servers, client, model_to_use, system_message
            )
            if server_name and tool_info and tool_info.get("tool_name"):
                try:
                    tool_params = tool_info.get('parameters', {})
                    print(f"Calling tool via natural language: {server_name}.{tool_info['tool_name']} with parameters: {tool_params}")
                    result = call_mcp_tool(server_name, tool_info['tool_name'], **tool_params)

                    response_message = f"I used the **{tool_info['tool_name']}** tool from **{server_name}**."
                    if isinstance(result, dict) and result.get("message"):
                        response_message += f" ({result.get('message')})"
                    response_message += "\n\n"

                    if isinstance(result, dict) and result.get("type") == "audio_b64":
                        result_str = f"<audio controls src=\"data:audio/wav;base64,{result.get('data')}\"></audio>"
                    elif isinstance(result, dict) and "error" in result:
                        result_str = f"Tool Error: {result.get('error')}"
                    elif isinstance(result, dict):
                        result_str = f"Result:\n```json\n{json.dumps(result, indent=2)}\n```"
                    else:
                        result_str = f"Result:\n{str(result)}"
                    yield response_message + result_str
                    return
                except Exception as e:
                    # Deliberately fall through to a normal LLM response.
                    print(f"Error executing MCP tool via natural language: {str(e)}")
            else:
                print("No tool call detected by natural language analysis or tool_info incomplete.")

    # Build the content parts of the current user turn.
    image_parts = []
    for img_path in image_files or []:
        if not img_path:
            continue
        try:
            encoded_image = encode_image(img_path)
            if encoded_image:
                image_parts.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
                })
        except Exception as e:
            print(f"Error encoding image {img_path}: {e}")

    has_text = bool(message and message.strip())
    if not has_text and not image_parts:
        print("No further content for LLM after MCP command processing.")

    # Chat-completion `content` must be a plain string or a list of parts.
    # A text-only turn is sent as a string (the bare {"type": "text"} dict used
    # previously is neither); any turn carrying images is sent as a parts list.
    if image_parts:
        text_parts = [{"type": "text", "text": message}] if has_text else []
        final_user_content = text_parts + image_parts
    elif has_text:
        final_user_content = message
    else:
        final_user_content = ""

    augmented_system_message = system_message
    if mcp_enabled and active_mcp_servers:
        tool_list_for_prompt = []
        for server_name_iter in active_mcp_servers:
            if server_name_iter in mcp_connections:
                server_tools_str = list_mcp_tools(server_name_iter)
                # Compare against the exact sentinel messages; a substring test
                # would misfire on tool descriptions containing those words.
                if server_tools_str and server_tools_str not in ("Server not connected", "No tools available for this server"):
                    tool_list_for_prompt.append(f"From server '{server_name_iter}':\n{server_tools_str}")
        if tool_list_for_prompt:
            mcp_tools_description = "\n\n".join(tool_list_for_prompt)
            if mcp_interaction_mode == "Command Mode":
                augmented_system_message += f"\n\nYou have access to the following MCP tools. To use them, type a command in the format: /mcp <server_name> <tool_name> <arguments_json>\nTools:\n{mcp_tools_description}"
            else:  # Natural Language
                augmented_system_message += f"\n\nYou have access to the following MCP tools. You can ask to use them in natural language, and I will try to detect when a tool is needed. If I miss it, you can try being more explicit about the tool name.\nTools:\n{mcp_tools_description}"

    messages_for_api = [{"role": "system", "content": augmented_system_message}]
    print("Initial messages array constructed.")

    for past_user_msg, past_assistant_msg in history or []:
        # Past user turns may be plain text or an already-multimodal parts list.
        if past_user_msg and isinstance(past_user_msg, (list, str)):
            messages_for_api.append({"role": "user", "content": past_user_msg})
        if past_assistant_msg:
            messages_for_api.append({"role": "assistant", "content": past_assistant_msg})

    if final_user_content:
        messages_for_api.append({"role": "user", "content": final_user_content})
        print(f"Latest user message appended (content type: {type(final_user_content)})")

    llm_response_text = ""
    print(f"Sending request to {provider} provider for model {model_to_use}.")

    parameters = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
    }
    if seed is not None:
        parameters["seed"] = seed

    try:
        stream = client.chat_completion(
            model=model_to_use,
            messages=messages_for_api,
            stream=True,
            **parameters
        )
        for chunk in stream:
            if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
                first_choice = chunk.choices[0]
                if hasattr(first_choice, 'delta') and hasattr(first_choice.delta, 'content'):
                    token_text = first_choice.delta.content
                    if token_text:
                        llm_response_text += token_text
                        yield llm_response_text
    except Exception as e:
        # Surface the failure in the chat instead of ending the turn silently.
        print(f"Error during LLM inference: {e}")
        llm_response_text += f"\nLLM Error: {str(e)}"
        yield llm_response_text

    print("Completed LLM response generation.")
# GRADIO UI
# NOTE(review): the original indentation of this section was lost; the nesting
# of layout containers below is a reconstruction - confirm against the
# intended layout.
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Supports multiple inference providers, multimodal inputs, and MCP tools.",
        layout="panel",
        show_label=False,
        render=False  # Delay rendering
    )
    print("Chatbot interface created.")

    with gr.Row():
        msg = gr.MultimodalTextbox(
            placeholder="Type a message or upload images...",
            show_label=False,
            container=False,
            scale=12,
            file_types=["image"],
            file_count="multiple",
            sources=["upload"],
            render=False  # Delay rendering
        )

    chatbot.render()
    msg.render()

    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text. If the user asks you to use a tool, try your best.",
            placeholder="You are a helpful assistant.",
            label="System Prompt"
        )
        with gr.Row():
            with gr.Column(scale=1):
                max_tokens_slider = gr.Slider(minimum=1, maximum=8192, value=1024, step=1, label="Max tokens")
                temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
            with gr.Column(scale=1):
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")

        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        byok_textbox = gr.Textbox(value="", label="BYOK (Bring Your Own Key)", info="Enter a custom Hugging Face API key here. If empty, only 'hf-inference' provider can be used with the shared token.", placeholder="Enter your Hugging Face API token", type="password")
        custom_model_box = gr.Textbox(value="", label="Custom Model ID", info="(Optional) Provide a custom Hugging Face model ID. Overrides selected featured model.", placeholder="meta-llama/Llama-3.1-70B-Instruct")
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)

        models_list = [
            "meta-llama/Llama-3.1-405B-Instruct-FP8",  # Large model, might be slow/expensive
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct",
            "mistralai/Mistral-Nemo-Instruct-2407",
            "Qwen/Qwen2-72B-Instruct",
            "Qwen/Qwen2-57B-A14B-Instruct",
            "CohereForAI/c4ai-command-r-plus",
            # Multimodal models
            "Salesforce/LlavaLlama3-8b-hf",
            "llava-hf/llava-v1.6-mistral-7b-hf",
            "llava-hf/llava-v1.6-vicuna-13b-hf",
            "microsoft/Phi-3-vision-128k-instruct",
            "google/paligemma-3b-mix-448",
            # Older but still popular
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.3",
        ]
        featured_model_radio = gr.Radio(label="Select a Featured Model", choices=models_list, value="meta-llama/Llama-3.1-8B-Instruct", interactive=True)
        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?pipeline_tag=image-to-text&sort=trending)")

    with gr.Accordion("MCP Settings", open=False):
        mcp_enabled_checkbox = gr.Checkbox(label="Enable MCP Support", value=False, info="Enable Model Context Protocol support to connect to external tools and services")
        with gr.Row():
            mcp_server_url = gr.Textbox(label="MCP Server URL", placeholder="https://your-mcp-server.hf.space/gradio_api/mcp/sse", info="URL of the MCP server (usually ends with /gradio_api/mcp/sse for Gradio MCP servers)")
            mcp_server_name = gr.Textbox(label="Server Name (Optional)", placeholder="e.g., kokoroTTS", info="A friendly name to identify this server")
        mcp_connect_button = gr.Button("Connect to MCP Server")
        mcp_status = gr.Textbox(label="MCP Connection Status", placeholder="No MCP servers connected", interactive=False)
        active_mcp_servers = gr.Dropdown(label="Active MCP Servers for Chat", choices=[], multiselect=True, info="Select which connected MCP servers to make available to the LLM for this chat session")
        mcp_mode = gr.Radio(label="MCP Interaction Mode", choices=["Natural Language", "Command Mode"], value="Natural Language", info="Choose how to interact with MCP tools")
        gr.Markdown("""
### MCP Interaction Modes & Examples
**Natural Language Mode**: Describe what you want.
`Please say 'Hello world' using the kokoroTTS server.`
`Use my speech tool to read this: "Welcome"`
**Command Mode**: Use structured commands (server name must match connected server's friendly name).
`/mcp <server_name> <tool_name> {"param1": "value1"}`
Example: `/mcp kokoroTTS text_to_audio {"text": "Hello world", "speed": 1.0}`
""")

    def filter_models_ui_update(search_term, current_selection=None):
        """Filter the featured-model radio by ``search_term``.

        ``current_selection`` is the radio's live value, passed as an event
        input; it is kept when it survives the filter.
        """
        print(f"Filtering models with search term: {search_term}")
        needle = (search_term or "").lower()
        filtered = [m for m in models_list if needle in m.lower()]
        if not filtered:  # If search yields no results, show all models
            filtered = models_list
        print(f"Filtered models: {filtered}")
        new_value = current_selection if current_selection in filtered else (filtered[0] if filtered else None)
        return gr.Radio(choices=filtered, label="Select a Featured Model", value=new_value)

    def set_custom_model_from_radio_ui_update(selected_featured_model):
        """Copy the chosen featured model into the custom model box."""
        print(f"Featured model selected: {selected_featured_model}")
        return selected_featured_model

    def connect_mcp_server_ui_update(url, name_optional, current_selection=None):
        """Connect to an MCP server and refresh the active-server dropdown.

        ``current_selection`` is the dropdown's live value, passed as an event
        input, so servers the user already activated stay selected.
        """
        actual_name, status_msg = connect_to_mcp_server(url, name_optional)
        updated_server_choices = list(mcp_connections.keys())
        valid_selection = [s for s in (current_selection or []) if s in updated_server_choices]
        if actual_name and actual_name not in valid_selection:  # Auto-select newly connected server
            valid_selection.append(actual_name)
        return status_msg, gr.Dropdown(choices=updated_server_choices, value=valid_selection, label="Active MCP Servers for Chat")

    def handle_user_input(multimodal_input, history_list: list):
        """Add the user's turn to the chat display and stash it for the bot step.

        Returns ``(history, pending_turn)`` where ``pending_turn`` is
        ``{"text": str, "files": [path, ...]}``, or ``None`` when the input was
        empty (the bot step then does nothing).
        """
        multimodal_input = multimodal_input or {}
        history_list = history_list or []
        text_content = (multimodal_input.get("text") or "").strip()
        # Keep the original upload paths: `respond` does the base64 encoding.
        image_paths = [p for p in (multimodal_input.get("files") or []) if p and isinstance(p, str)]

        if not text_content and not image_paths:
            return history_list, None  # Nothing to send

        display_files_md = "".join(
            f"\n<img src='file={file_path}' style='max-height:150px; display:block;' alt='uploaded image'>"
            for file_path in image_paths
        )
        history_list.append([text_content + display_files_md, None])
        return history_list, {"text": text_content, "files": image_paths}

    def call_bot_responder(history_list_for_display, current_user_api_content, sys_msg, max_tok, temp, top_p_val, freq_pen, seed_val, prov, api_key_val, cust_model, _search, sel_model, mcp_on, active_servs, mcp_inter_mode):
        """Stream the reply from `respond` into the last chat entry."""
        # Without a pending turn there is no entry to fill in; writing to
        # history[-1] would clobber the previous reply or fail on an empty chat.
        if not current_user_api_content or not history_list_for_display:
            print("Bot called with no current message and no history, skipping.")
            yield history_list_for_display
            return

        current_message_text = ""
        current_image_paths = []
        if isinstance(current_user_api_content, dict):
            current_message_text = current_user_api_content.get("text") or ""
            current_image_paths = list(current_user_api_content.get("files") or [])
        elif isinstance(current_user_api_content, str):  # Text only
            current_message_text = current_user_api_content

        # Past turns are passed as text only: image tags added for display are
        # stripped from earlier user messages.
        simplified_past_history = []
        for user_disp, assistant_text in history_list_for_display[:-1]:
            user_text_for_hist = user_disp
            if isinstance(user_disp, str) and "<img src" in user_disp:
                text_lines = [line for line in user_disp.splitlines() if not line.strip().startswith("<img")]
                user_text_for_hist = "\n".join(text_lines).strip() if text_lines else ""
            simplified_past_history.append([user_text_for_hist, assistant_text])

        bot_response_stream = respond(
            message=current_message_text,
            image_files=current_image_paths or None,  # Forward uploads so the model actually sees them
            history=simplified_past_history,
            system_message=sys_msg,
            max_tokens=max_tok,
            temperature=temp,
            top_p=top_p_val,
            frequency_penalty=freq_pen,
            seed=seed_val,
            provider=prov,
            custom_api_key=api_key_val,
            custom_model=cust_model,
            model_search_term="",  # Not directly used by respond
            selected_model=sel_model,
            mcp_enabled=mcp_on,
            active_mcp_servers=active_servs,
            mcp_interaction_mode=mcp_inter_mode
        )

        for response_chunk in bot_response_stream:
            history_list_for_display[-1][1] = response_chunk
            yield history_list_for_display

    # Holds the pending user turn between the two steps of the submit chain.
    current_api_message_state = gr.State(None)

    msg.submit(
        handle_user_input,
        [msg, chatbot],  # chatbot here is the history_list
        [chatbot, current_api_message_state]
    ).then(
        call_bot_responder,
        [chatbot, current_api_message_state, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio, mcp_enabled_checkbox, active_mcp_servers, mcp_mode],
        [chatbot]  # Update chatbot display with streaming response
    ).then(
        lambda: gr.MultimodalTextbox(value={"text": "", "files": []}),  # Clear MultimodalTextbox
        None,
        [msg]
    )

    mcp_connect_button.click(
        connect_mcp_server_ui_update,
        [mcp_server_url, mcp_server_name, active_mcp_servers],
        [mcp_status, active_mcp_servers]
    )

    model_search_box.change(fn=filter_models_ui_update, inputs=[model_search_box, featured_model_radio], outputs=featured_model_radio)
    featured_model_radio.change(fn=set_custom_model_from_radio_ui_update, inputs=featured_model_radio, outputs=custom_model_box)

    def validate_provider_ui_update(api_key, current_provider):
        """Force the 'hf-inference' provider when no custom API key is set."""
        if not (api_key or "").strip() and current_provider != "hf-inference":
            gr.Info("No API key provided. Defaulting to 'hf-inference' provider.")
            return gr.Radio(value="hf-inference")
        return gr.Radio(value=current_provider)

    byok_textbox.change(fn=validate_provider_ui_update, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider_ui_update, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")
if __name__ == "__main__":
    # Script entry point: enable request queuing, then start the server.
    print("Launching the demo application.")
    queued_demo = demo.queue()
    # mcp_server=False as this is a client
    queued_demo.launch(show_api=False, debug=False)