import gradio as gr
from huggingface_hub import InferenceClient  # Kept for direct use if needed, though the agent uses its own model
import os
import json
import base64
from PIL import Image
import io

# Smolagents imports
from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel as SmolInferenceClientModel
# We use PIL.Image directly for opening files; AgentImage is only needed for the agent's internal typing if a tool requires it
from smolagents.gradio_ui import pull_messages_from_step  # For formatting agent steps
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep  # For type-checking steps
from smolagents.models import ChatMessageStreamDelta  # For type-checking stream deltas

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")


# Encode an image to base64 (remains useful if we ever need to pass base64 to a non-smolagents component)
def encode_image(image_path_or_pil):
    if not image_path_or_pil:
        print("No image path or PIL Image provided")
        return None
    try:
        # print(f"Encoding image: {type(image_path_or_pil)}")  # Debug
        if isinstance(image_path_or_pil, Image.Image):
            image = image_path_or_pil
        else:  # Assume it's a path
            image = Image.open(image_path_or_pil)

        if image.mode == "RGBA":
            image = image.convert("RGB")

        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")  # JPEG is generally smaller for transfer
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # print("Image encoded successfully")  # Debug
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None


# This function sets up and runs the smolagents CodeAgent for a single turn
def respond(
    message_text,                  # Text from MultimodalTextbox
    image_file_paths,              # List of file paths from MultimodalTextbox
    gradio_history: list[tuple[str, str]],  # Gradio history (for context if needed; the agent is stateless per call here)
    system_message_for_agent,      # System prompt for the main LLM agent
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider_for_agent_llm,
    api_key_for_agent_llm,
    model_id_for_agent_llm,
    model_search_term,             # Unused directly by agent logic
    selected_model_for_agent_llm,  # Fallback model ID
):
    print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}")

    token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN
    model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm

    # --- Initialize the LLM for the CodeAgent ---
    agent_llm_params = {
        "model_id": model_to_use,
        "token": token_to_use,
        # smolagents' InferenceClientModel uses max_tokens for max_new_tokens
        "max_tokens": max_tokens,
        "temperature": temperature if temperature > 0.01 else None,  # Some models require temperature > 0
        "top_p": top_p if top_p < 1.0 else None,  # 1.0 usually means "no top-p"
        "seed": seed if seed != -1 else None,
    }
    if provider_for_agent_llm and provider_for_agent_llm != "hf-inference":
        agent_llm_params["provider"] = provider_for_agent_llm
    # InferenceClient-specific params; only add them when they differ from the defaults
    if frequency_penalty != 0.0:
        agent_llm_params["frequency_penalty"] = frequency_penalty

    agent_llm = SmolInferenceClientModel(**agent_llm_params)
    print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'")

    # --- Define Tools for the Agent ---
    agent_tools = []
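    # Tool.from_space wraps a Hugging Face Space as a callable tool; here the
    # FLUX.1-schnell text-to-image Space is exposed to the agent as `image_generator`.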
    try:
        image_gen_tool = Tool.from_space(
            space_id="black-forest-labs/FLUX.1-schnell",
            name="image_generator",
            description=(
                "Generates an image from a textual prompt. Input is a single string "
                "argument named 'prompt'. Output is an image file path."
            ),
            token=token_to_use,
        )
        agent_tools.append(image_gen_tool)
        print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
    except Exception as e:
        print(f"Error loading image generation tool: {e}")
        yield f"Error: Could not load image generation tool. {e}"
        return

    # --- Initialize the CodeAgent ---
    # If system_message_for_agent is empty, CodeAgent falls back to its default system prompt,
    # which is usually fine since it already explains how to use tools.
    agent = CodeAgent(
        tools=agent_tools,
        model=agent_llm,
        system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
        # add_base_tools=True,  # Consider adding the Python interpreter, etc.
        stream_outputs=True,  # Important for Gradio streaming
    )
    print("Smolagents CodeAgent initialized.")

    # --- Prepare task and image inputs for the agent ---
    agent_task_text = message_text
    pil_images_for_agent = []
    if image_file_paths:
        for file_path in image_file_paths:
            try:
                pil_images_for_agent.append(Image.open(file_path))
            except Exception as e:
                print(f"Error opening image file {file_path} for agent: {e}")

    print(f"Agent task: '{agent_task_text}'")
    if pil_images_for_agent:
        print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")

    # --- Run the agent and stream the response ---
    # The agent is reset each turn. For conversational memory, the agent instance would
    # need to be stored in session state and agent.run(..., reset=False) used instead.
    current_agent_response_text = ""
    try:
        # agent.run returns a generator when stream=True
        for step_item in agent.run(
            task=agent_task_text,
            images=pil_images_for_agent,
            stream=True,
            reset=True,  # Explicitly reset for stateless operation per call
        ):
            if isinstance(step_item, ChatMessageStreamDelta):
                if step_item.content:
                    current_agent_response_text += step_item.content
                    yield current_agent_response_text  # Yield the accumulated text
            elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)):
                # A structured step: format it for Gradio.
                # pull_messages_from_step yields gr.ChatMessage objects.
                for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs):
                    # The `bot` function handles these gr.ChatMessage objects.
                    yield gradio_chat_msg  # Yield the gr.ChatMessage object directly
                current_agent_response_text = ""  # Reset the text buffer after a structured step
            # else:
            #     print(f"Unhandled stream item type: {type(step_item)}")  # Debug

        # If any text remains that was not part of a gr.ChatMessage, yield it. This usually
        # shouldn't happen when the stream_to_gradio logic is followed, since text deltas
        # should be part of the last gr.ChatMessage or yielded before it, but it covers the
        # case where the agent's final textual answer arrives as pure deltas after all steps.
        if current_agent_response_text and not isinstance(step_item, FinalAnswerStep):
            # Check whether the last yielded item already contains this text
            if not (isinstance(step_item, gr.ChatMessage) and step_item.content == current_agent_response_text):
                yield current_agent_response_text

    except Exception as e:
        error_message = f"Error during agent execution: {str(e)}"
        print(error_message)
        yield error_message  # Yield the error message so it is displayed in the UI

    print("Agent run completed.")

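
# The event handlers at the bottom of the file wire up `filter_models` and
# `set_custom_model_from_radio`, but their definitions were missing from this file.
# The minimal implementations below are inferred from how they are used (filtering the
# featured-model radio by a search term, and copying a radio selection into the custom
# model textbox); adjust them if the original app defined them differently.
# Note: `models_list` is defined later inside the Blocks context, but it exists as a
# module-level name by the time these callbacks fire.
def filter_models(search_term):
    filtered = [m for m in models_list if search_term.lower() in m.lower()]
    return gr.update(choices=filtered)


def set_custom_model_from_radio(selected):
    return selected
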

# Validate the provider selection based on whether a BYOK API key is supplied
def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)


# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
        layout="panel",
        bubble_full_width=False,  # For better display of images/files
    )
    print("Chatbot interface created.")

    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"],
    )

    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.",
            placeholder="You are a helpful AI assistant.",
            label="System Prompt for Agent",
        )

        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
                temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")

        providers_list = [
            "hf-inference",
            "cerebras",
            "together",
            "sambanova",
            "novita",
            "cohere",
            "fireworks-ai",
            "hyperbolic",
            "nebius",
        ]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM")

        byok_textbox = gr.Textbox(
            value="",
            label="BYOK (Your HF Token or Provider API Key)",
            info="Enter the API key for the selected provider. Uses HF_TOKEN if empty.",
            placeholder="Enter your API token",
            type="password",
        )

        custom_model_box = gr.Textbox(
            value="",
            label="Custom Model ID for Agent's LLM",
            info="(Optional) Provide a custom model ID. Overrides the featured model.",
            placeholder="meta-llama/Llama-3.3-70B-Instruct",
        )

        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)

        models_list = [
            "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-3.0-70B-Instruct",
            "meta-llama/Llama-3.2-11B-Vision-Instruct",
            "meta-llama/Llama-3.2-3B-Instruct",
            "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct",
            "NousResearch/Hermes-3-Llama-3.1-8B",
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.3",
            "Qwen/Qwen3-235B-A22B",
            "Qwen/Qwen3-32B",
            "Qwen/Qwen2.5-72B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-Coder-32B-Instruct",
            "microsoft/Phi-3.5-mini-instruct",
            "microsoft/Phi-3-mini-128k-instruct",
        ]
        featured_model_radio = gr.Radio(
            label="Select a Featured Model for Agent's LLM",
            choices=models_list,
            value="meta-llama/Llama-3.3-70B-Instruct",
            interactive=True,
        )

        gr.Markdown(
            "[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | "
            "[View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)"
        )

    # Chat history state: the chatbot's value itself serves as the history display.
    # A separate gr.State would only be needed if the agent had to stay conversational
    # across turns; for now the agent is stateless per turn.

    # Chat interface callbacks
    def user(user_multimodal_input_dict, history):
        print(f"User input: {user_multimodal_input_dict}")
        text_content = user_multimodal_input_dict.get("text", "")
        files = user_multimodal_input_dict.get("files", [])

        user_display_parts = []
        if text_content and text_content.strip():
            user_display_parts.append(text_content)
        for file_path_obj in files:
            # file_path_obj may be a tempfile wrapper or a plain path string, depending on the Gradio version
            file_path = file_path_obj.name if hasattr(file_path_obj, "name") else file_path_obj
            user_display_parts.append((file_path, os.path.basename(file_path)))
        if not user_display_parts:
            return history

        # Append the user's multimodal message to the history for display;
        # `bot` recovers the raw text and file paths from this entry.
        history.append([user_display_parts if len(user_display_parts) > 1 else user_display_parts[0], None])
        return history

    def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed,
            provider, api_key, custom_model, search_term, selected_model):
        if not history or not history[-1][0]:  # No user input
            yield history
            return

        # The user's input lives in history[-1][0]. Depending on how `user` stored it, this
        # may be the raw MultimodalTextbox dict, a plain string, or a list of display parts
        # mixing text and (file_path, alt_text) tuples. (An alternative design is to pass
        # `msg`, the raw MultimodalTextbox value, into `bot` directly; here we recover the
        # text and image paths from whatever `user` appended to the history.)
        user_input_data = history[-1][0]
        if isinstance(user_input_data, dict):
            text_input_for_agent = user_input_data.get("text", "")
            # Files from MultimodalTextbox may be temp-file wrappers or plain paths
            image_file_paths_for_agent = [
                f.name if hasattr(f, "name") else f for f in user_input_data.get("files", [])
            ]
        elif isinstance(user_input_data, str):
            text_input_for_agent = user_input_data
            image_file_paths_for_agent = []
        else:  # List of display parts built by `user`
            text_input_for_agent = " ".join(p for p in user_input_data if isinstance(p, str))
            image_file_paths_for_agent = [p[0] for p in user_input_data if isinstance(p, tuple)]

        history[-1][1] = ""  # Initialize the assistant's part for streaming

        # Buffer for the current text stream from the agent.
        # Handles both pure text deltas and text content from gr.ChatMessage objects.
        current_text_for_turn = ""

        for item in respond(
            message_text=text_input_for_agent,
            image_file_paths=image_file_paths_for_agent,
            gradio_history=history[:-1],  # Pass previous turns for context if the agent uses them
            system_message_for_agent=system_msg,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=freq_penalty,
            seed=seed,
            provider_for_agent_llm=provider,
            api_key_for_agent_llm=api_key,
            model_id_for_agent_llm=custom_model,
            model_search_term=search_term,  # unused
            selected_model_for_agent_llm=selected_model,
        ):
            if isinstance(item, str):
                # LLM text delta from the agent's thoughts or textual answer
                current_text_for_turn = item
                history[-1][1] = current_text_for_turn
            elif isinstance(item, gr.ChatMessage):
                # A structured step (thought, tool output, image, etc.). Its string content
                # becomes part of the current assistant message; images and files are shown
                # as (file_path, alt_text) tuples.
                if isinstance(item.content, str):
                    current_text_for_turn = item.content  # Replace: this is a full message
                elif isinstance(item.content, dict) and "path" in item.content:
                    # Typically an image or audio file produced by a tool. Gradio needs an
                    # absolute path or URL, so temporary files written by the agent may need
                    # to be copied into a directory Gradio serves (or exposed via gr.File).
                    # Here we assume the local path is accessible and display it as a tuple,
                    # which means history[-1][1] has to become a list.
                    file_path = item.content["path"]
                    # If current_text_for_turn is not empty, turn history[-1][1] into a list
                    if current_text_for_turn and not isinstance(history[-1][1], list):
                        history[-1][1] = [current_text_for_turn]
                    elif not current_text_for_turn and not isinstance(history[-1][1], list):
                        history[-1][1] = []

                    alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
                    # Add the file as a new component of the current assistant message
                    if isinstance(history[-1][1], list):
                        history[-1][1].append((file_path, alt_text))
                    else:  # Should have been made a list above
                        history[-1][1] = [(file_path, alt_text)]
                    current_text_for_turn = ""  # Reset the text buffer after a file

                # If it was a full text message rather than a delta, replace the current text
                if not isinstance(history[-1][1], list):  # unless it already became a list due to a file
                    history[-1][1] = current_text_for_turn

            yield history

    # Event handlers
    # `msg.submit`'s first argument is the function to call; `inputs` are the components whose
    # values are passed to it, and `outputs` are the components updated by its return value.
    # When msg is submitted:
    #   1. `user` appends the user's input (display-formatted) to the history -> chatbot.
    #   2. `bot` runs the agent with the updated history and streams the reply -> chatbot.
    #   3. The MultimodalTextbox is cleared.
    msg.submit(
        user,
        [msg, chatbot],
        [chatbot],  # `user` returns the new history, updating the chatbot display
        queue=False,
    ).then(
        bot,
        [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio],
        [chatbot],  # `bot` yields history updates, streaming into the chatbot
    ).then(
        lambda: {"text": "", "files": []},  # Clear the MultimodalTextbox
        None,
        [msg],
    )

    model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
    featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=False)  # show_api=False for a cleaner launch; True exposes API docs