import os

import gradio as gr
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
**Note:** You might need to authenticate with Hugging Face for this to work reliably.
Run `huggingface-cli login` in your terminal or set the HUGGING_FACE_HUB_TOKEN environment variable.
Alternatively, pass your token directly: InferenceClient(token="hf_YOUR_TOKEN")
"""
# Initialize the Inference Client.
# With no explicit token, huggingface_hub falls back to the
# HUGGING_FACE_HUB_TOKEN environment variable or a cached
# `huggingface-cli login` session.
try:
    client = InferenceClient(
        "HuggingFaceH4/zephyr-7b-beta",
        token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
    )
except Exception as e:
    # Without a working client the app cannot run, so surface the error.
    print(f"Error initializing InferenceClient: {e}")
    raise
def respond(
message: str,
history: list[tuple[str | None, str | None]],
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
):
"""
Generates a response using the Hugging Face Inference API.
Args:
message: The user's input message.
history: A list of tuples representing the conversation history.
Each tuple is (user_message, bot_message).
system_message: The system prompt to guide the model.
max_tokens: The maximum number of new tokens to generate.
temperature: Controls randomness (higher = more random).
top_p: Nucleus sampling parameter.
Yields:
        The accumulated response text, updated as new tokens stream in.
"""
messages = [{"role": "system", "content": system_message}]
# Add conversation history
for user_msg, bot_msg in history:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if bot_msg:
messages.append({"role": "assistant", "content": bot_msg})
# Add the latest user message
messages.append({"role": "user", "content": message})
response = ""
try:
# Start streaming the response
for msg_chunk in client.chat_completion(
messages=messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
# Check if there's content in the delta
token = msg_chunk.choices[0].delta.content
if token: # Add check for empty/None token
response += token
yield response # Yield the accumulated response so far
    except HfHubHTTPError as e:
        print(f"Inference API Error: {e}")
        yield f"Sorry, I encountered an error: {e}"
except Exception as e:
print(f"An unexpected error occurred: {e}")
yield f"Sorry, an unexpected error occurred: {e}"
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
respond,
chatbot=gr.Chatbot(height=400), # Adjust chatbot height if desired
textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
title="Zephyr 7B Beta Chat",
description="Chat with the Zephyr 7B Beta model using the Hugging Face Inference API.",
theme="soft", # Optional: Apply a theme
examples=[
["Hello!"],
["Explain the concept of Large Language Models in simple terms."],
["Write a short poem about the moon."],
],
cache_examples=False, # Set to True to cache example results
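    # Note: retry_btn, undo_btn, and clear_btn are Gradio 4.x parameters;
    # Gradio 5 removed them from ChatInterface (retry/undo/clear are built in there).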
retry_btn="Retry",
undo_btn="Undo",
clear_btn="Clear",
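    # Each control below is passed positionally to `respond` after
    # (message, history): system_message, max_tokens, temperature, top_p.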
additional_inputs=[
gr.Textbox(value="You are a friendly and helpful chatbot.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),  # Higher = more random output
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
additional_inputs_accordion=gr.Accordion(label="Advanced Options", open=False), # Group additional inputs
)
if __name__ == "__main__":
demo.launch()