import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io

# Load the default access token from the environment at startup.
# It is used only when the user does not provide a custom key.
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print(f"Default HF_TOKEN from environment loaded: {'Present' if ACCESS_TOKEN else 'Not set'}")
# Encode an image (file path or PIL.Image) to a base64 JPEG string.
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None
    try:
        print(f"Encoding image from path: {image_path}")
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            image = Image.open(image_path)
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
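

# respond() is a generator: it yields the accumulating assistant reply after each
# streamed chunk, so Gradio can update the chat display incrementally.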
def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,  # Value from the BYOK textbox
    custom_model,
    model_search_term,
    selected_model,
):
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    # print(f"History: {history}")  # Can be very verbose
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key input field value: '{custom_api_key[:10]}...' (truncated)")
    print(f"Custom model (overrides featured selection if set): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")

    token_to_use = None
    original_hf_token_env_value = os.environ.get("HF_TOKEN")
    env_hf_token_temporarily_modified = False

    if custom_api_key and custom_api_key.strip():
        token_to_use = custom_api_key.strip()
        print(f"USING CUSTOM API KEY (BYOK): '{token_to_use[:5]}...' (masked for security).")
        # Make sure the custom key really wins: temporarily remove HF_TOKEN from
        # os.environ so InferenceClient cannot pick it up implicitly.
        if "HF_TOKEN" in os.environ:
            print("Temporarily unsetting HF_TOKEN from the environment to prioritize the custom key.")
            del os.environ["HF_TOKEN"]
            env_hf_token_temporarily_modified = True
    elif ACCESS_TOKEN:  # No custom key: fall back to the default token loaded at startup
        token_to_use = ACCESS_TOKEN
        print(f"USING DEFAULT API KEY (HF_TOKEN from environment at script start): '{token_to_use[:5]}...' (masked for security).")
        # Ensure HF_TOKEN is set in the current environment; a previous call with
        # a custom key may have removed it. If it was set by other means, leave it.
        if original_hf_token_env_value is not None:
            os.environ["HF_TOKEN"] = original_hf_token_env_value
    else:
        print("No custom API key provided and no default HF_TOKEN was found at script start.")
        print("InferenceClient will be initialized without an explicit token; requests may fail or use public access.")
        # token_to_use stays None. Also make sure a stray HF_TOKEN in the
        # environment is not used implicitly.
        if "HF_TOKEN" in os.environ:
            print("Temporarily unsetting HF_TOKEN from the environment because no valid key was chosen.")
            del os.environ["HF_TOKEN"]
            env_hf_token_temporarily_modified = True  # Mark for restoration

    if token_to_use:
        print(f"Final token being passed to InferenceClient: '{token_to_use[:5]}...' (masked)")
    else:
        print("Final token being passed to InferenceClient: None")

    try:
        client = InferenceClient(token=token_to_use, provider=provider)
        print(f"Hugging Face Inference Client initialized with {provider} provider.")

        if seed == -1:
            seed = None  # -1 in the UI means "random seed"

        # Build the content of the current user turn: optional text plus any images.
        user_content = []
        if message and message.strip():
            user_content.append({"type": "text", "text": message})
        if image_files:
            for img_path in image_files:
                if img_path:
                    encoded_image = encode_image(img_path)
                    if encoded_image:
                        user_content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                        })
        if not user_content:  # Only images were sent and none encoded, or the message was empty
            if image_files:  # Image-only message: keep an empty text part for context
                user_content = [{"type": "text", "text": ""}]
            else:  # Truly empty input
                yield "Error: Empty message content."
                return
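
        # The request uses the OpenAI-style chat format accepted by
        # InferenceClient.chat_completion: a list of {"role", "content"} dicts,
        # where a multimodal user turn carries a list of content parts, e.g.
        # (illustrative values only):
        #   {"role": "user", "content": [
        #       {"type": "text", "text": "What is in this photo?"},
        #       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
        #   ]}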
        messages = [{"role": "system", "content": system_message}]
        for user_part, assistant_part in history:
            # Multimodal history entries arrive as dicts from the MultimodalTextbox.
            if isinstance(user_part, dict) and 'files' in user_part:
                history_text = user_part.get("text", "")
                history_files = user_part.get("files", [])
                current_user_content_history = []
                if history_text:
                    current_user_content_history.append({"type": "text", "text": history_text})
                for h_img_path in history_files:
                    encoded_h_img = encode_image(h_img_path)
                    if encoded_h_img:
                        current_user_content_history.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_h_img}"},
                        })
                if current_user_content_history:
                    messages.append({"role": "user", "content": current_user_content_history})
            elif isinstance(user_part, str):  # Plain text history entry
                messages.append({"role": "user", "content": user_part})
            if assistant_part:
                messages.append({"role": "assistant", "content": assistant_part})

        # Append the current turn: send a plain string when it is text-only,
        # otherwise send the list of content parts.
        if len(user_content) == 1 and user_content[0].get("type") == "text":
            messages.append({"role": "user", "content": user_content[0]["text"]})
        else:
            messages.append({"role": "user", "content": user_content})

        model_to_use = custom_model.strip() if custom_model.strip() else selected_model
        print(f"Model selected for inference: {model_to_use}")

        response_text = ""
        print(f"Sending request to {provider} with model {model_to_use}.")

        parameters = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
        }
        if seed is not None:
            parameters["seed"] = seed

        stream = client.chat_completion(
            model=model_to_use,
            messages=messages,
            stream=True,
            **parameters,
        )
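
        # Each streamed chunk mirrors the OpenAI delta format: the incremental
        # text (if any) is in chunk.choices[0].delta.content, and the reply is
        # accumulated below so every yield contains the response so far.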
print("Streaming response: ", end="", flush=True) | |
for chunk in stream: | |
if hasattr(chunk, 'choices') and chunk.choices: | |
delta = chunk.choices[0].delta | |
if hasattr(delta, 'content') and delta.content: | |
token_chunk = delta.content | |
print(token_chunk, end="", flush=True) | |
response_text += token_chunk | |
yield response_text | |
print("\nStream finished.") | |
except Exception as e: | |
error_message = f"Error during inference: {e}" | |
print(error_message) | |
# If there was already some response, append error. Otherwise, yield error. | |
if 'response_text' in locals() and response_text: | |
response_text += f"\n{error_message}" | |
yield response_text | |
else: | |
yield error_message | |
finally: | |
# Restore HF_TOKEN in os.environ if it was temporarily removed/modified | |
if env_hf_token_temporarily_modified: | |
if original_hf_token_env_value is not None: | |
os.environ["HF_TOKEN"] = original_hf_token_env_value | |
print("Restored HF_TOKEN in environment from its original value.") | |
else: | |
# If it was unset and originally not present, ensure it remains unset | |
if "HF_TOKEN" in os.environ: # Should not happen if original was None and we deleted | |
del os.environ["HF_TOKEN"] | |
print("HF_TOKEN was originally not set and was temporarily removed; ensuring it remains not set in env.") | |
print("Response generation attempt complete.") | |


def validate_provider(api_key, provider_choice):
    # This validation may need adjustment per provider; for now it assumes any
    # custom key might work with any provider. When only the Space's default
    # HF_TOKEN is available (no custom key), restrict usage to 'hf-inference'.
    if not api_key.strip() and provider_choice != "hf-inference" and ACCESS_TOKEN:
        gr.Warning("The default HF_TOKEN can only be used with the 'hf-inference' provider. Switching to 'hf-inference'.")
        return gr.update(value="hf-inference")
    return gr.update(value=provider_choice)
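

# UI layout: a Chatbot for the conversation, a MultimodalTextbox for text and
# image input, and an accordion with the generation, provider, key, and model
# controls that are passed into respond() on every turn.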
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Supports multimodal inputs.",
        layout="panel",
        avatar_images=(None, "https://hf.co/front/assets/huggingface_logo.svg"),  # Bot avatar
    )
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"],
    )

    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text.",
            placeholder="You are a helpful assistant.",
            label="System Prompt",
        )
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max tokens")
                temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")  # 0 allowed for deterministic output
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")  # 0 allowed
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")

        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")

        byok_textbox = gr.Textbox(
            value="",
            label="BYOK (Bring Your Own Key)",
            info="Enter your Hugging Face API key (or a provider-specific key). Overrides the default. If empty, the Space's HF_TOKEN (if set) is used with 'hf-inference'.",
            placeholder="hf_... or provider_specific_key",
            type="password",
        )
        custom_model_box = gr.Textbox(
            value="",
            label="Custom Model ID",
            info="(Optional) Provide a model ID (e.g., 'meta-llama/Llama-3-8B-Instruct'). Overrides the featured model selection.",
            placeholder="org/model-name",
        )
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search...", lines=1)

        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.1-70B-Instruct",
            "mistralai/Mistral-Nemo-Instruct-2407", "Qwen/Qwen2.5-72B-Instruct",
            "microsoft/Phi-3.5-mini-instruct", "NousResearch/Hermes-3-Llama-3.1-8B",
            # Add more here, or fetch the list dynamically if desired
        ]
        featured_model_radio = gr.Radio(
            label="Select a Featured Model",
            choices=models_list,
            value="meta-llama/Llama-3.2-11B-Vision-Instruct",
            interactive=True,
        )
        gr.Markdown("[All Text Gen Models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | [All Multimodal Models](https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending)")
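
    # The two handlers below implement the submit flow: the first appends the
    # raw user turn to the chatbot history, the second streams the assistant
    # reply into that last entry.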
    # Chat history is kept directly in the Chatbot component (no extra State);
    # each row is [user_mmtb_dict, assistant_text].
    def handle_user_message_submission(user_input_mmtb, chat_history_list):
        # user_input_mmtb is a dict: {"text": "...", "files": ["path1", "path2"]}
        text_content = user_input_mmtb.get("text", "")
        files = user_input_mmtb.get("files", [])
        if not text_content and not files:
            return chat_history_list  # Nothing to add
        # Store the raw MultimodalTextbox payload in the history for display;
        # the API-ready content (text plus base64 images) is built later in respond().
        chat_history_list.append([user_input_mmtb, None])
        return chat_history_list

    def handle_bot_response_generation(
        chat_history_list, system_msg, max_tokens, temp, top_p, freq_pen, seed_val,
        prov, api_key_val, cust_model_val, search_term_val, feat_model_val,
    ):
        if not chat_history_list or chat_history_list[-1][0] is None:
            yield chat_history_list
            return

        # The last user message is the MultimodalTextbox dict appended above:
        # {"text": "...", "files": ["path1", ...]}
        last_user_input_mmtb = chat_history_list[-1][0]
        current_message_text = last_user_input_mmtb.get("text", "")
        current_image_files = last_user_input_mmtb.get("files", [])

        # Prepare the history for respond(), excluding the current turn. Past
        # user messages are MultimodalTextbox dicts; for simplicity only their
        # text is forwarded (past images could be added here as well).
        api_history = []
        for user_msg_item, bot_msg_item in chat_history_list[:-1]:
            past_user_text = user_msg_item.get("text", "") if isinstance(user_msg_item, dict) else user_msg_item
            api_history.append((past_user_text, bot_msg_item))

        # Stream the response, updating the last history entry as chunks arrive.
        full_response = ""
        for stream_chunk in respond(
            message=current_message_text,
            image_files=current_image_files,
            history=api_history,
            system_message=system_msg,
            max_tokens=max_tokens,
            temperature=temp,
            top_p=top_p,
            frequency_penalty=freq_pen,
            seed=seed_val,
            provider=prov,
            custom_api_key=api_key_val,
            custom_model=cust_model_val,
            model_search_term=search_term_val,  # Used only for UI filtering, not by the API
            selected_model=feat_model_val,
        ):
            full_response = stream_chunk
            chat_history_list[-1][1] = full_response
            yield chat_history_list
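
    # Event chain: appending the user turn runs un-queued so it shows up
    # immediately, the streaming bot response runs on the queue, and a final
    # step clears the input box.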
    msg.submit(
        handle_user_message_submission,
        [msg, chatbot],
        [chatbot],
        queue=False,
    ).then(
        handle_bot_response_generation,
        [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio],
        [chatbot],
    ).then(
        lambda: gr.update(value=None),  # Clear the MultimodalTextbox
        [],  # No inputs needed
        [msg],
    )

    def filter_models_ui(search_term):
        filtered = [m for m in models_list if search_term.lower() in m.lower()] if search_term else models_list
        return gr.update(choices=filtered, value=filtered[0] if filtered else None)

    model_search_box.change(fn=filter_models_ui, inputs=model_search_box, outputs=featured_model_radio)

    # No extra handler is needed for model selection: custom_model_box overrides
    # featured_model_radio directly inside respond().
    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
print("Gradio interface initialized.") | |
if __name__ == "__main__": | |
print("Launching the demo application.") | |
# ForSpaces, share=True is often implied or handled by Spaces platform | |
# For local, share=True makes it public via Gradio link | |
demo.queue().launch(show_api=False) # .queue() is good for handling multiple users / long tasks |