import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
# Import smolagents components
from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel as SmolInferenceClientModel # Alias to avoid conflict
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")
# --- Smolagents Setup for Image Generation ---
print("Initializing smolagents components for image generation...")
try:
    image_generation_tool = Tool.from_space(
        "black-forest-labs/FLUX.1-schnell",  # The Space ID of the image generation tool
        name="image_generator",
        description="Generates an image from a textual prompt. Use this tool if the user asks to 'generate an image of X', 'draw X', 'create a picture of X', or similar requests for visual content based on a description.",
        # Ensure the HF_TOKEN is available to gradio-client if the space is private or requires auth
        token=ACCESS_TOKEN if ACCESS_TOKEN and ACCESS_TOKEN.strip() != "" else None
    )
    print("Image generation tool loaded successfully.")

    # Initialize a model for the CodeAgent. This can be a simpler/faster model
    # as it's mainly for orchestrating the tool call.
    # Using a default InferenceClientModel from smolagents
    smol_agent_model = SmolInferenceClientModel(token=ACCESS_TOKEN if ACCESS_TOKEN and ACCESS_TOKEN.strip() != "" else None)
    print(f"Smolagent model initialized with: {smol_agent_model.model_id if hasattr(smol_agent_model, 'model_id') else 'default'}")

    image_agent = CodeAgent(
        tools=[image_generation_tool],
        model=smol_agent_model,
        verbosity_level=1  # Set to 0 for less verbose agent logging, 1 for info, 2 for debug
    )
    print("Image generation agent initialized successfully.")
except Exception as e:
    print(f"Error initializing smolagents components: {e}")
    image_agent = None
# --- End Smolagents Setup ---
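
# Sketch (comments only, not executed): when the agent initialized above is
# available, the "/generate_image" branch in respond() performs a call roughly
# like the following; the prompt string is an illustrative placeholder.
#
#   result = image_agent.run("a watercolor painting of a lighthouse at dusk")
#   # smolagents typically wraps image results in an AgentImage-like object;
#   # result.to_string() gives a local file path that gr.Chatbot can display.
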
# Function to encode image to base64
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None

    try:
        print(f"Encoding image from path: {image_path}")

        # If it's already a PIL Image
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            # Try to open the image file
            image = Image.open(image_path)

        # Convert to RGB if image has an alpha channel (RGBA)
        if image.mode == 'RGBA':
            image = image.convert('RGB')

        # Encode to base64
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
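

# Illustrative usage (not executed): the base64 string returned by encode_image()
# is embedded into an OpenAI-style "image_url" content part, which is how
# respond() builds multimodal user messages below. The path is hypothetical.
#
#   encoded = encode_image("/tmp/example.jpg")
#   part = {"type": "image_url",
#           "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}}
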
def respond(
    message,
    image_files,  # list of image file paths coming from the MultimodalTextbox
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model
):
print(f"Received message: {message}")
print(f"Received {len(image_files) if image_files else 0} images")
print(f"History: {history}")
print(f"System message: {system_message}")
print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
print(f"Selected provider: {provider}")
print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
print(f"Selected model (custom_model): {custom_model}")
print(f"Model search term: {model_search_term}")
print(f"Selected model from radio: {selected_model}")
# --- Agent-based Image Generation ---
if message.startswith("/generate_image"):
if image_agent is None:
yield "Image generation agent is not initialized. Please check server logs."
return
prompt_for_agent = message.replace("/generate_image", "").strip()
if not prompt_for_agent:
yield "Please provide a prompt for image generation. Usage: /generate_image <your prompt>"
return
print(f"Image generation requested with prompt: {prompt_for_agent}")
try:
# Agent run is blocking and returns the final result
# Ensure the image_agent's model also has a token if needed for its operations (though it's for orchestration)
agent_response = image_agent.run(prompt_for_agent)
if isinstance(agent_response, str) and agent_response.lower().startswith("error"):
yield f"Agent error: {agent_response}"
elif hasattr(agent_response, 'to_string'): # Check if it's an AgentImage or similar
image_path = agent_response.to_string() # This is a local path to the generated image
print(f"Agent returned image path: {image_path}")
# Gradio's chatbot can display images if the content is a file path string
# or a tuple (filepath, alt_text)
yield image_path
else:
yield f"Agent returned an unexpected response: {str(agent_response)}"
return
except Exception as e:
print(f"Error running image agent: {e}")
yield f"Error generating image: {str(e)}"
return
# --- End Agent-based Image Generation ---
# Determine which token to use for text generation
token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
if custom_api_key.strip() != "":
print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
else:
print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
# Initialize the Inference Client with the provider and appropriate token
client = InferenceClient(token=token_to_use, provider=provider)
print(f"Hugging Face Inference Client initialized with {provider} provider for text generation.")
# Convert seed to None if -1 (meaning random)
if seed == -1:
seed = None
# Create multimodal content if images are present
if image_files and len(image_files) > 0:
user_content = []
if message and message.strip():
user_content.append({
"type": "text",
"text": message
})
for img_path in image_files: # Assuming image_files contains paths from MultimodalTextbox
if img_path is not None:
try:
encoded_image = encode_image(img_path) # img_path is already a path from MultimodalTextbox
if encoded_image:
user_content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_image}"
}
})
except Exception as e:
print(f"Error encoding image: {e}")
else:
# Text-only message
user_content = message
# Prepare messages in the format expected by the API
messages = [{"role": "system", "content": system_message}]
print("Initial messages array constructed.")
# Add conversation history to the context
for val in history:
user_part = val[0]
assistant_part = val[1]
# Handle user messages (could be text or image markdown)
if user_part:
if isinstance(user_part, str) and user_part.startswith("![Image]("):
# This is an image path from a previous agent generation
# or a user upload represented as markdown
history_image_path = user_part.replace("![Image](", "").replace(")", "")
encoded_history_image = encode_image(history_image_path)
if encoded_history_image:
messages.append({"role": "user", "content": [{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_history_image}"}
}]})
elif isinstance(user_part, tuple) and len(user_part) == 2: # Multimodal input from user
history_content_list = []
if user_part[0]: # Text part
history_content_list.append({"type": "text", "text": user_part[0]})
for img_hist_path in user_part[1]: # List of image paths
encoded_img_hist = encode_image(img_hist_path)
if encoded_img_hist:
history_content_list.append({
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_img_hist}"}
})
if history_content_list:
messages.append({"role": "user", "content": history_content_list})
else: # Regular text message
messages.append({"role": "user", "content": user_part})
print(f"Added user message to context (type: {type(user_part)})")
if assistant_part:
messages.append({"role": "assistant", "content": assistant_part})
print(f"Added assistant message to context: {assistant_part}")
# Append the latest user message
messages.append({"role": "user", "content": user_content})
print(f"Latest user message appended (content type: {type(user_content)})")
# Determine which model to use, prioritizing custom_model if provided
model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
print(f"Model selected for inference: {model_to_use}")
# Start with an empty string to build the response as tokens stream in
response = ""
print(f"Sending request to {provider} provider.")
# Prepare parameters for the chat completion request
parameters = {
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"frequency_penalty": frequency_penalty,
}
if seed is not None:
parameters["seed"] = seed
# Use the InferenceClient for making the request
try:
# Create a generator for the streaming response
stream = client.chat_completion(
model=model_to_use,
messages=messages,
stream=True,
**parameters
)
print("Received tokens: ", end="", flush=True)
# Process the streaming response
for chunk in stream:
if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
# Extract the content from the response
if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
token_text = chunk.choices[0].delta.content
if token_text:
print(token_text, end="", flush=True)
response += token_text
yield response
print()
except Exception as e:
print(f"Error during inference: {e}")
response += f"\nError: {str(e)}"
yield response
print("Completed response generation.")
# Function to validate provider selection based on BYOK
def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)
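
# Behavior sketch (illustrative, not executed): with an empty key, any
# non-default provider is forced back to "hf-inference"; with a key present,
# the current selection is kept. The token value below is a placeholder.
#
#   validate_provider("", "cerebras")        # -> gr.update(value="hf-inference")
#   validate_provider("hf_xxx", "cerebras")  # -> gr.update(value="cerebras")
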
# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    # Create the chatbot component
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Use '/generate_image your prompt' to create images.",
        layout="panel",
        show_share_button=True  # Added for ease of sharing if deployed
    )
    print("Chatbot interface created.")

    # Multimodal textbox for messages (combines text and file uploads)
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images... (e.g., /generate_image a cat wearing a hat)",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"]
    )

    # Create accordion for settings
    with gr.Accordion("Settings", open=False):
        # System message
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text. If asked to generate an image, use the image_generator tool.",
            placeholder="You are a helpful assistant.",
            label="System Prompt"
        )

        # Generation parameters
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(
                    minimum=1,
                    maximum=4096,
                    value=512,
                    step=1,
                    label="Max tokens"
                )

                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=4.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )

                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-P"
                )

            with gr.Column():
                frequency_penalty_slider = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Frequency Penalty"
                )

                seed_slider = gr.Slider(
                    minimum=-1,
                    maximum=65535,
                    value=-1,
                    step=1,
                    label="Seed (-1 for random)"
                )

        # Provider selection
        providers_list = [
            "hf-inference",  # Default Hugging Face Inference
            "cerebras",      # Cerebras provider
            "together",      # Together AI
            "sambanova",     # SambaNova
            "novita",        # Novita AI
            "cohere",        # Cohere
            "fireworks-ai",  # Fireworks AI
            "hyperbolic",    # Hyperbolic
            "nebius",        # Nebius
        ]

        provider_radio = gr.Radio(
            choices=providers_list,
            value="hf-inference",
            label="Inference Provider",
        )

        # BYOK (Bring Your Own Key) textbox
        byok_textbox = gr.Textbox(
            value="",
            label="BYOK (Bring Your Own Key)",
            info="Enter a custom Hugging Face API key here. When empty, only the 'hf-inference' provider can be used.",
            placeholder="Enter your Hugging Face API token",
            type="password"  # Hide the API key for security
        )

        # Custom model box
        custom_model_box = gr.Textbox(
            value="",
            label="Custom Model",
            info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.",
            placeholder="meta-llama/Llama-3.3-70B-Instruct"
        )

        # Model search
        model_search_box = gr.Textbox(
            label="Filter Models",
            placeholder="Search for a featured model...",
            lines=1
        )

        # Featured models list (includes multimodal models)
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct",
            "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-3.0-70B-Instruct",
            "meta-llama/Llama-3.2-3B-Instruct",
            "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct",
            "NousResearch/Hermes-3-Llama-3.1-8B",
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mistral-7B-Instruct-v0.2",
            "Qwen/Qwen3-235B-A22B",
            "Qwen/Qwen3-32B",
            "Qwen/Qwen2.5-72B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-0.5B-Instruct",
            "Qwen/QwQ-32B",
            "Qwen/Qwen2.5-Coder-32B-Instruct",
            "microsoft/Phi-3.5-mini-instruct",
            "microsoft/Phi-3-mini-128k-instruct",
            "microsoft/Phi-3-mini-4k-instruct",
        ]

        featured_model_radio = gr.Radio(
            label="Select a model below",
            choices=models_list,
            value="meta-llama/Llama-3.2-11B-Vision-Instruct",  # Default to a multimodal model
            interactive=True
        )

        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

    # Chat history state
    chat_history = gr.State([])

    # Function to filter models
    def filter_models(search_term):
        print(f"Filtering models with search term: {search_term}")
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        print(f"Filtered models: {filtered}")
        return gr.update(choices=filtered if filtered else models_list, value=filtered[0] if filtered else models_list[0])

    # Function to set the custom model box from the radio selection
    def set_custom_model_from_radio(selected):
        print(f"Featured model selected: {selected}")
        return selected

    # Function for the chat interface: record the user's turn in the history
    def user(user_message_obj, history):
        print(f"User message object received: {user_message_obj}")

        text_content = user_message_obj.get("text", "").strip()
        files = user_message_obj.get("files", [])  # list of uploaded (temp) file paths

        if not text_content and not files:
            print("Empty message (no text, no files), skipping history update.")
            return history  # Or raise gr.Error("Please enter a message or upload an image.")

        # Represent uploaded images in the history using markdown syntax for local paths.
        # For multimodal models, the actual file paths from 'files' are re-encoded in 'respond'.
        display_message_parts = []
        if text_content:
            display_message_parts.append(text_content)

        processed_files_for_history = []
        if files:
            for file_path_obj in files:
                # Gradio's MultimodalTextbox may hand back plain path strings or
                # file objects carrying the path in a .name attribute
                file_path = file_path_obj.name if hasattr(file_path_obj, 'name') else str(file_path_obj)
                display_message_parts.append(f"![Uploaded Image]({file_path})")
                processed_files_for_history.append(file_path)  # Store the actual path for 'respond'

        # The history stores the text and a list of file paths;
        # the 'respond' function re-encodes these for the API.
        history_entry_user = (text_content, processed_files_for_history)
        history.append([history_entry_user, None])
        print(f"History updated with user input: {history_entry_user}")
        return history

    # Define the bot response function
    def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
        if not history or len(history) == 0 or history[-1][0] is None:
            print("No user message in history to process for bot.")
            yield history
            return

        user_input_tuple = history[-1][0]  # This is (text, [file_paths])
        text_message_from_history = user_input_tuple[0]
        image_files_from_history = user_input_tuple[1]

        print(f"Bot processing: text='{text_message_from_history}', images={image_files_from_history}")

        history[-1][1] = ""

        # Pass the text and image file paths to the respond function
        for response_chunk in respond(
            message=text_message_from_history,
            image_files=image_files_from_history,
            history=history[:-1],  # Pass history excluding the current user turn
            system_message=system_msg,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=freq_penalty,
            seed=seed,
            provider=provider,
            custom_api_key=api_key,
            custom_model=custom_model,
            model_search_term=search_term,
            selected_model=selected_model
        ):
            history[-1][1] = response_chunk
            yield history

    # Event handlers: record the user turn, stream the bot reply, then clear the textbox
    msg.submit(
        user,
        [msg, chatbot],  # msg yields {"text": str, "files": [file paths]}
        [chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio],
        [chatbot]
    ).then(
        lambda: gr.update(value={"text": "", "files": []}),  # Clear the MultimodalTextbox
        None,
        [msg]
    )

    # Connect the model filter to update the radio choices
    model_search_box.change(
        fn=filter_models,
        inputs=model_search_box,
        outputs=featured_model_radio
    )
    print("Model search box change event linked.")

    # Connect the featured model radio to update the custom model box
    featured_model_radio.change(
        fn=set_custom_model_from_radio,
        inputs=featured_model_radio,
        outputs=custom_model_box
    )
    print("Featured model radio button change event linked.")

    # Connect the BYOK textbox to validate provider selection
    byok_textbox.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio
    )
    print("BYOK textbox change event linked.")

    # Also validate the provider when the radio changes to ensure consistency
    provider_radio.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio
    )
    print("Provider radio button change event linked.")

print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=False)  # show_api=False for a cleaner public interface, True for debugging
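    # Other launch options (not enabled here, shown as a hint only) include
    # demo.launch(share=True) for a temporary public link or
    # demo.launch(server_name="0.0.0.0") to bind to all interfaces; see the
    # Gradio Blocks.launch() documentation for the full set of parameters.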