import gradio as gr
from huggingface_hub import InferenceClient
# Import the correct exception class
from huggingface_hub.utils import HfHubHTTPError
import os
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
**Note:** You might need to authenticate with Hugging Face for this to work reliably.
Run `huggingface-cli login` in your terminal or set the HUGGING_FACE_HUB_TOKEN environment variable.
Alternatively, pass your token directly: InferenceClient(token="hf_YOUR_TOKEN")
"""
# Initialize the Inference Client.
# It will use the HUGGING_FACE_HUB_TOKEN environment variable or a cached CLI login.
try:
    # You might need to provide a token if you haven't logged in via the CLI:
    # token = os.getenv("HUGGING_FACE_HUB_TOKEN")
    # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
except Exception as e:
    print(f"Error initializing InferenceClient: {e}")
    raise ValueError("Could not initialize InferenceClient. Ensure you are logged in or provide a token.") from e
def respond(
    message: str,
    history: list[tuple[str | None, str | None]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
"""
Generates a response using the Hugging Face Inference API.
Args:
message: The user's input message.
history: A list of tuples representing the conversation history.
Each tuple is (user_message, bot_message).
system_message: The system prompt to guide the model.
max_tokens: The maximum number of new tokens to generate.
temperature: Controls randomness (higher = more random).
top_p: Nucleus sampling parameter.
Yields:
The generated response incrementally.
"""
    messages = [{"role": "system", "content": system_message}]

    # Add the conversation history
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})

    # Add the latest user message
    messages.append({"role": "user", "content": message})
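    # For illustration, after one prior exchange `messages` would look like:
    #   [{"role": "system", "content": system_message},
    #    {"role": "user", "content": "Hello!"},
    #    {"role": "assistant", "content": "Hi! How can I help?"},
    #    {"role": "user", "content": message}]
    # (the example strings are hypothetical; the structure is what chat_completion expects)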
    response = ""
    try:
        # Stream the response token by token
        for msg_chunk in client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = msg_chunk.choices[0].delta.content
            if token:  # Skip empty/None deltas (e.g., role-only chunks)
                response += token
                yield response  # Yield the accumulated response so far
    # Catch HTTP errors from the Hugging Face Hub API
    except HfHubHTTPError as e:
        error_message = f"Inference API Error: {e}"
        # Try to extract more detail from the response body, if present.
        # Compare against None rather than relying on truthiness: a
        # requests.Response evaluates to False for 4xx/5xx status codes,
        # which is exactly the case being handled here.
        if e.response is not None:
            try:
                details = e.response.json()
                error_message += f"\nDetails: {details.get('error', 'N/A')}"
            except Exception:  # The body may not be valid JSON
                pass  # Keep the original error message
        print(error_message)
        # Display a user-friendly message in the chat window
        yield f"Sorry, I encountered an error communicating with the model service: {e}"
    # Catch any other unexpected errors
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        yield f"Sorry, an unexpected error occurred: {e}"
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(height=400),  # Adjust chatbot height if desired
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Zephyr 7B Beta Chat",
    description="Chat with the Zephyr 7B Beta model using the Hugging Face Inference API.",
    theme="soft",  # Optional: apply a theme
    examples=[
        ["Hello!"],
        ["Explain the concept of Large Language Models in simple terms."],
        ["Write a short poem about the moon."],
    ],
    cache_examples=False,  # Set to True to cache example results
    # Note: the button-label kwargs below exist in Gradio 3.x/4.x but were
    # removed in Gradio 5; drop them if you upgrade.
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    additional_inputs=[
        gr.Textbox(value="You are a friendly and helpful chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # Temperature is capped at 1.0 here, though the API generally accepts higher values (e.g., up to 2.0)
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    additional_inputs_accordion=gr.Accordion(label="Advanced Options", open=False),  # Group the extra inputs
)
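# Streaming generators rely on Gradio's queue. Recent Gradio versions enable it
# by default; on older 3.x releases you may need to opt in explicitly before
# launching, e.g. (a hedged sketch, parameter names vary by version):
#
#     demo.queue(concurrency_count=2)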
if __name__ == "__main__":
    # Ensure the huggingface_hub library is up to date: pip install --upgrade huggingface_hub
    print("Launching Gradio Interface...")
    demo.launch()