File size: 5,278 Bytes
e04246a
 
c137ca3
 
a8e97ac
e04246a
43de95d
 
a8e97ac
43de95d
 
 
 
 
 
a8e97ac
c137ca3
 
 
43de95d
a8e97ac
 
c137ca3
 
e04246a
 
a8e97ac
43de95d
 
 
 
 
e04246a
a8e97ac
43de95d
 
 
 
 
 
 
 
 
 
 
 
 
a8e97ac
e04246a
 
43de95d
 
a8e97ac
 
43de95d
 
e04246a
43de95d
e04246a
 
43de95d
a8e97ac
43de95d
 
a8e97ac
 
 
 
 
 
43de95d
 
 
 
 
 
c137ca3
 
 
 
 
 
 
 
 
 
 
 
 
 
a8e97ac
43de95d
 
e04246a
a8e97ac
43de95d
 
 
e04246a
 
43de95d
 
 
 
 
a8e97ac
43de95d
 
 
a8e97ac
43de95d
 
 
 
e04246a
43de95d
 
 
e04246a
 
 
43de95d
e04246a
 
 
 
43de95d
e04246a
 
 
 
c137ca3
 
43de95d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
from huggingface_hub import InferenceClient
# Import the correct exception class
from huggingface_hub.utils import HfHubHTTPError
import os

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

**Note:** You might need to authenticate with Hugging Face for this to work reliably.
Run `huggingface-cli login` in your terminal, or set the HF_TOKEN (or the legacy
HUGGING_FACE_HUB_TOKEN) environment variable before starting the app.
"""
# Model queried through the Inference API.
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"

# Use an explicitly configured token when one is present in the environment.
# When neither variable is set this is None, and InferenceClient falls back to
# its own default credential lookup (e.g. a cached `huggingface-cli login`).
_hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")

# Initialize the shared Inference Client used by `respond`.
try:
    client = InferenceClient(MODEL_ID, token=_hf_token)
except Exception as e:
    print(f"Error initializing InferenceClient: {e}")
    raise ValueError("Could not initialize InferenceClient. Ensure you are logged in or provide a token.") from e


def respond(
    message: str,
    history: list[tuple[str | None, str | None]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a streamed reply to ``message`` using the Hugging Face Inference API.

    The request sent to the model is the system message, then every non-empty
    turn from ``history``, then ``message``.

    Args:
        message: The user's input message.
        history: Prior conversation turns as ``(user_message, bot_message)``
                 pairs; empty or ``None`` entries are skipped.
        system_message: The system prompt to guide the model.
        max_tokens: The maximum number of new tokens to generate.
        temperature: Controls randomness (higher = more random).
        top_p: Nucleus sampling parameter.

    Yields:
        The accumulated response text after each non-empty streamed token.
        If the request fails, a final user-facing error message is yielded;
        any text already streamed before the failure is kept in front of it.
    """
    messages = [{"role": "system", "content": system_message}]

    # Replay the conversation so far, skipping empty/None turns.
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})

    # Add the latest user message
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        # Start streaming the response
        for msg_chunk in client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # A stream chunk may carry no choices (e.g. a trailing usage-only
            # chunk). Indexing [0] would raise IndexError and turn an otherwise
            # successful reply into an error message, so skip such chunks.
            if not msg_chunk.choices:
                continue
            token = msg_chunk.choices[0].delta.content
            if token:  # Skip empty/None deltas
                response += token
                yield response  # Yield the accumulated response so far

    # Catch HTTP errors from the Hugging Face Hub API
    except HfHubHTTPError as e:
        error_message = f"Inference API Error: {e}"
        # Compare against None explicitly: requests.Response defines __bool__
        # as `response.ok`, which is False for every 4xx/5xx response, so a
        # plain truthiness test would never read the error body.
        if e.response is not None:
            try:
                details = e.response.json()
                error_message += f"\nDetails: {details.get('error', 'N/A')}"
            except Exception:
                # Best effort only: the body may not be JSON, or not a dict.
                # The base error message is still logged below.
                pass
        print(error_message)
        notice = f"Sorry, I encountered an error communicating with the model service: {e}"
        # Keep any partially streamed answer visible instead of replacing it.
        yield f"{response}\n\n{notice}" if response else notice

    # Catch other potential errors
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        notice = f"Sorry, an unexpected error occurred: {e}"
        yield f"{response}\n\n{notice}" if response else notice


# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface
#
# NOTE(review): retry_btn / undo_btn / clear_btn are version-specific
# ChatInterface arguments; confirm the installed Gradio version still accepts
# them before upgrading Gradio.
demo = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(height=400),  # Adjust chatbot height if desired
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Zephyr 7B Beta Chat",
    description="Chat with the Zephyr 7B Beta model using the Hugging Face Inference API.",
    theme="soft",  # Optional: apply a theme
    examples=[
        ["Hello!"],
        ["Explain the concept of Large Language Models in simple terms."],
        ["Write a short poem about the moon."],
    ],
    cache_examples=False,  # Set to True to cache example results
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    # These values are passed to `respond` positionally after (message, history):
    # system_message, max_tokens, temperature, top_p.
    additional_inputs=[
        gr.Textbox(value="You are a friendly and helpful chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # This slider caps temperature at 1.0; raise `maximum` if the model
        # endpoint accepts (and you want) more randomness.
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    additional_inputs_accordion=gr.Accordion(label="Advanced Options", open=False),  # Group additional inputs
)


if __name__ == "__main__":
    # Ensure huggingface_hub library is up-to-date: pip install --upgrade huggingface_hub
    print("Launching Gradio Interface...")
    demo.launch()