Nymbo's picture
Update app.py
c3b8601 verified
raw
history blame
18.4 kB
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io
# Default Hugging Face token, read exactly once when this module is imported.
# respond() falls back to it whenever the BYOK field is left blank.
ACCESS_TOKEN = os.environ.get("HF_TOKEN")
print(
    "Default HF_TOKEN from environment loaded: "
    + ("Present" if ACCESS_TOKEN else "Not set")
)
def _to_jpeg_base64(image):
    """Return a PIL image as a base64-encoded JPEG string (decoded as UTF-8 text).

    Any mode other than RGB or L is converted to RGB first. Pillow cannot write
    alpha or palette modes (RGBA, LA, P, ...) as JPEG, and RGB is the safe
    choice for every other mode.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def encode_image(image_path):
    """Encode an image as a base64 JPEG string suitable for a ``data:`` URL.

    Args:
        image_path: A filesystem path (anything ``Image.open`` accepts) or an
            already-loaded PIL ``Image``.

    Returns:
        The base64 string, or None when ``image_path`` is empty or the image
        cannot be opened/encoded. Errors are printed, not raised.
    """
    if not image_path:
        print("No image path provided")
        return None
    try:
        print(f"Encoding image from path: {image_path}")
        if isinstance(image_path, Image.Image):
            img_str = _to_jpeg_base64(image_path)
        else:
            # The context manager closes the file handle Image.open keeps open.
            with Image.open(image_path) as image:
                img_str = _to_jpeg_base64(image)
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
def _resolve_api_token(custom_api_key):
    """Choose the API token for a single request without touching process-wide state.

    Args:
        custom_api_key: Raw value of the BYOK textbox (may be empty or None).

    Returns:
        The stripped BYOK key when it is non-blank, otherwise the module-level
        ACCESS_TOKEN when it is set, otherwise None.
    """
    custom_key = (custom_api_key or "").strip()
    if custom_key:
        # Never log any part of the key itself.
        print("USING CUSTOM API KEY (BYOK).")
        return custom_key
    if ACCESS_TOKEN:
        print("USING DEFAULT API KEY (HF_TOKEN from environment variable at script start).")
        return ACCESS_TOKEN
    print("No custom API key provided AND no default HF_TOKEN was found in environment at script start.")
    print("InferenceClient will be initialized without an explicit token. May fail or use public access.")
    return None


def _image_part(encoded_image):
    """Wrap a base64 JPEG string as an ``image_url`` chat content part."""
    return {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}


def _build_user_content(message, image_files):
    """Build the content parts for the current user turn: text first, then images.

    Images that fail to encode are skipped. If nothing remains but image files
    were supplied, a single empty text part is returned. If nothing was
    supplied at all, an empty list is returned.
    """
    content = []
    if message and message.strip():
        content.append({"type": "text", "text": message})
    for img_path in image_files or []:
        encoded_image = encode_image(img_path) if img_path else None
        if encoded_image:
            content.append(_image_part(encoded_image))
    if not content and image_files:
        content.append({"type": "text", "text": ""})
    return content


def _build_chat_messages(system_message, history, user_content):
    """Assemble the chat_completion message list: system prompt, prior turns, current turn.

    Each history item is a ``(user_part, assistant_part)`` pair. ``user_part``
    may be a plain string or a MultimodalTextbox-style dict
    ``{"text": ..., "files": [...]}``; any other type is dropped. Empty
    assistant parts are omitted.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_part, assistant_part in history:
        if isinstance(user_part, dict) and "files" in user_part:
            parts = []
            history_text = user_part.get("text", "")
            if history_text:
                parts.append({"type": "text", "text": history_text})
            for h_img_path in user_part.get("files") or []:
                encoded_h_img = encode_image(h_img_path)
                if encoded_h_img:
                    parts.append(_image_part(encoded_h_img))
            if parts:
                messages.append({"role": "user", "content": parts})
        elif isinstance(user_part, str):
            messages.append({"role": "user", "content": user_part})
        if assistant_part:
            messages.append({"role": "assistant", "content": assistant_part})
    # A lone text part is sent as a plain string (presumably for models/providers
    # that only accept string content); anything else goes as a parts list.
    if len(user_content) == 1 and user_content[0].get("type") == "text":
        messages.append({"role": "user", "content": user_content[0]["text"]})
    else:
        messages.append({"role": "user", "content": user_content})
    return messages


def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,  # This is the value from the BYOK textbox
    custom_model,
    model_search_term,
    selected_model,
):
    """Stream a chat completion for one user turn, yielding the growing reply.

    The token is chosen per request (BYOK key, else the startup HF_TOKEN, else
    none) and passed explicitly to InferenceClient. ``os.environ`` is never
    modified, so concurrent queued requests cannot interfere with each other.

    Args:
        message: Text of the current user turn (may be empty when images are sent).
        image_files: Paths (or PIL images) attached to the current turn, or None.
        history: Prior ``(user_part, assistant_part)`` pairs.
        system_message: System prompt placed first in the conversation.
        max_tokens, temperature, top_p, frequency_penalty: Sampling parameters
            forwarded to ``chat_completion``.
        seed: Sampling seed; -1 means "do not send a seed".
        provider: Inference provider name passed to InferenceClient.
        custom_api_key: BYOK key; overrides the default token when non-blank.
        custom_model: Model ID that overrides ``selected_model`` when non-blank.
        model_search_term: UI filter text; logged only, not sent to the API.
        selected_model: Featured model chosen in the UI.

    Yields:
        The accumulated response text after each streamed chunk, or an
        ``"Error: ..."`` message (appended to any partial response) on failure.
    """
    has_custom_key = bool(custom_api_key and custom_api_key.strip())
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    # Only log whether a key was supplied; never any characters of it.
    print(f"Custom API key provided: {'yes' if has_custom_key else 'no'}")
    print(f"Selected model (custom_model input field): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")

    token_to_use = _resolve_api_token(custom_api_key)
    # Initialised before the try so the error path can always inspect it.
    response_text = ""
    try:
        client = InferenceClient(token=token_to_use, provider=provider)
        print(f"Hugging Face Inference Client initialized with {provider} provider.")

        if seed == -1:
            seed = None

        user_content = _build_user_content(message, image_files)
        if not user_content:
            yield "Error: Empty message content."
            return

        messages = _build_chat_messages(system_message, history or [], user_content)

        model_to_use = (custom_model or "").strip() or selected_model
        if not model_to_use:
            # e.g. the featured-model filter matched nothing and no custom ID was given.
            yield "Error: No model selected. Choose a featured model or enter a custom model ID."
            return
        print(f"Model selected for inference: {model_to_use}")
        print(f"Sending request to {provider} with model {model_to_use}.")

        parameters = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
        }
        if seed is not None:
            parameters["seed"] = seed

        stream = client.chat_completion(
            model=model_to_use,
            messages=messages,
            stream=True,
            **parameters,
        )
        print("Streaming response: ", end="", flush=True)
        for chunk in stream:
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            token_chunk = getattr(choices[0].delta, "content", None)
            if token_chunk:
                print(token_chunk, end="", flush=True)
                response_text += token_chunk
                yield response_text
        print("\nStream finished.")
    except Exception as e:
        error_message = f"Error during inference: {e}"
        print(error_message)
        # Keep any partial reply visible and append the error beneath it.
        if response_text:
            yield f"{response_text}\n{error_message}"
        else:
            yield error_message
    finally:
        print("Response generation attempt complete.")
def validate_provider(api_key, provider_choice):
    """Keep the provider radio compatible with the key that will be used.

    This app only allows the default HF_TOKEN (ACCESS_TOKEN) with the
    'hf-inference' provider. When the BYOK field is blank, a default token
    exists and another provider is selected, warn and switch to 'hf-inference'.
    A non-blank custom key is assumed to work with any provider.

    Args:
        api_key: Current BYOK textbox value (may be empty or None).
        provider_choice: Currently selected provider name.

    Returns:
        A ``gr.update`` for the provider radio: a switch to 'hf-inference', or
        an empty update that leaves the current selection untouched.
    """
    needs_switch = (
        not (api_key or "").strip()
        and provider_choice != "hf-inference"
        and ACCESS_TOKEN
    )
    if not needs_switch:
        # Empty update: do not re-write the radio with the value captured when
        # this event fired, which may already be stale.
        return gr.update()
    gr.Warning("Default HF_TOKEN can only be used with 'hf-inference' provider. Switching to 'hf-inference'.")
    return gr.update(value="hf-inference")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Supports multimodal inputs.",
        layout="panel",
        avatar_images=(None, "https://hf.co/front/assets/huggingface_logo.svg"),  # Bot avatar
    )
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"],
    )
    # The turn awaiting a reply: {"text": str, "files": [paths], "rows": int}, where
    # "rows" is how many chatbot rows that turn was rendered as. None = nothing pending.
    pending_turn = gr.State(None)

    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text.",
            placeholder="You are a helpful assistant.",
            label="System Prompt",
        )
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max tokens")
                temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")  # Allow 0 for deterministic
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")  # Allow 0
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        byok_textbox = gr.Textbox(
            value="", label="BYOK (Bring Your Own Key)",
            info="Enter your Hugging Face API key (or provider-specific key). Overrides default. If empty, uses Space's HF_TOKEN (if set) for 'hf-inference'.",
            placeholder="hf_... or provider_specific_key", type="password",
        )
        custom_model_box = gr.Textbox(
            value="", label="Custom Model ID",
            info="(Optional) Provide a model ID (e.g., 'meta-llama/Llama-3-8B-Instruct'). Overrides featured model selection.",
            placeholder="org/model-name",
        )
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search...", lines=1)
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.1-70B-Instruct",
            "mistralai/Mistral-Nemo-Instruct-2407", "Qwen/Qwen2.5-72B-Instruct",
            "microsoft/Phi-3.5-mini-instruct", "NousResearch/Hermes-3-Llama-3.1-8B",
            # Add more or fetch dynamically if possible
        ]
        featured_model_radio = gr.Radio(
            label="Select a Featured Model", choices=models_list,
            value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True,
        )
        gr.Markdown("[All Text Gen Models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | [All Multimodal Models](https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending)")

    def handle_user_message_submission(user_input_mmtb, chat_history_list):
        """Show a MultimodalTextbox submission in the chat and stash it for the bot step.

        The tuple-format Chatbot cannot display the raw {"text", "files"} dict, so
        each image becomes its own ``[(path,), None]`` row and the text (if any)
        becomes a final ``[text, None]`` row.

        Returns:
            (updated chat rows, pending turn dict or None, textbox update clearing the input)
        """
        rows = [list(row) for row in (chat_history_list or [])]
        user_input_mmtb = user_input_mmtb or {}
        text_content = user_input_mmtb.get("text") or ""
        files = [f for f in (user_input_mmtb.get("files") or []) if f]
        has_text = bool(text_content.strip())
        if not has_text and not files:
            # Nothing to send. A None pending turn makes the bot step a no-op instead
            # of regenerating (and overwriting) the previous reply.
            return rows, None, gr.update(value=None)
        for file_path in files:
            rows.append([(file_path,), None])
        if has_text:
            rows.append([text_content, None])
        pending = {"text": text_content, "files": files, "rows": len(files) + int(has_text)}
        return rows, pending, gr.update(value=None)

    def handle_bot_response_generation(
        pending, chat_history_list, system_msg, max_tokens, temp, top_p, freq_pen, seed_val,
        prov, api_key_val, cust_model_val, search_term_val, feat_model_val,
    ):
        """Stream the reply for the pending turn into the last chat row.

        Rows before the pending turn become (user_text, reply) pairs for respond().
        Past images are not re-sent. An image-only turn contributes ("", reply).
        Yields the updated chat rows after every streamed chunk.
        """
        rows = [list(row) for row in (chat_history_list or [])]
        if not pending or not rows:
            yield rows
            return
        earlier_rows = rows[: max(len(rows) - pending["rows"], 0)]
        api_history = []
        for user_cell, bot_cell in earlier_rows:
            if isinstance(user_cell, str):
                api_history.append((user_cell, bot_cell))
            elif bot_cell is not None:
                # Image-only turn: its reply sits on the turn's last image row.
                api_history.append(("", bot_cell))
        for partial_response in respond(
            message=pending["text"],
            image_files=pending["files"],
            history=api_history,
            system_message=system_msg,
            max_tokens=max_tokens,
            temperature=temp,
            top_p=top_p,
            frequency_penalty=freq_pen,
            seed=seed_val,
            provider=prov,
            custom_api_key=api_key_val,
            custom_model=cust_model_val,
            model_search_term=search_term_val,  # Note: search_term is for UI filtering, not API
            selected_model=feat_model_val,
        ):
            rows[-1][1] = partial_response
            yield rows

    msg.submit(
        handle_user_message_submission,
        [msg, chatbot],
        [chatbot, pending_turn, msg],
        queue=False,
    ).then(
        handle_bot_response_generation,
        [pending_turn, chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio],
        [chatbot],
    )

    def filter_models_ui(search_term):
        """Limit the featured-model radio to case-insensitive matches and select the first one."""
        filtered = [m for m in models_list if search_term.lower() in m.lower()] if search_term else models_list
        return gr.update(choices=filtered, value=filtered[0] if filtered else None)

    model_search_box.change(fn=filter_models_ui, inputs=model_search_box, outputs=featured_model_radio)
    # custom_model_box overrides featured_model_radio inside respond(), so no extra sync handler is needed.
    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    # queue() turns on Gradio's request queue (concurrent users, long streamed
    # generations). show_api=False hides the "Use via API" docs link.
    # On Spaces the platform serves the app; locally, pass share=True to launch()
    # for a public Gradio link.
    app = demo.queue()
    app.launch(show_api=False)