import gradio as gr
from openai import OpenAI
import os
# Retrieve the access token from the environment variable
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")
# Initialize the OpenAI client with the Hugging Face Inference API endpoint
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=ACCESS_TOKEN,
)
print("OpenAI client initialized.")
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    model,
    custom_model
):
"""
This function handles the chatbot response. It takes in:
- message: the user's new message
- history: the list of previous messages, each as a tuple (user_msg, assistant_msg)
- system_message: the system prompt
- max_tokens: the maximum number of tokens to generate in the response
- temperature: sampling temperature
- top_p: top-p (nucleus) sampling
- frequency_penalty: penalize repeated tokens in the response
- seed: a fixed seed for reproducibility; -1 will mean 'random'
- model: the selected model
- custom_model: the custom model path
"""
print(f"Received message: {message}")
print(f"History: {history}")
print(f"system message: {system_message}")
print(f"max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
print(f"Selected Model: {model}")
print(f"Custom model: {custom_model}")
    # Convert seed to None if -1 (meaning random)
    if seed == -1:
        seed = None
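    # For example (hypothetical values): seed == -1 becomes None, so the API picks
    # a fresh random seed per call; seed == 42 keeps sampling reproducible.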
    # Construct the messages array required by the API
    messages = [{"role": "system", "content": system_message}]

    # Add conversation history to the context
    for val in history:
        user_part = val[0]
        assistant_part = val[1]
        if user_part:
            messages.append({"role": "user", "content": user_part})
            print(f"Added user message to context: {user_part}")
        if assistant_part:
            messages.append({"role": "assistant", "content": assistant_part})
            print(f"Added assistant message to context: {assistant_part}")

    # Append the latest user message
    messages.append({"role": "user", "content": message})
    # Start with an empty string to build the response as tokens stream in
    response = ""
    print("Sending request to the Hugging Face Inference API.")

    # Make the streaming request to the HF Inference API via the OpenAI-compatible client
    for message_chunk in client.chat.completions.create(
        model=custom_model if custom_model.strip() != "" else model,
        max_tokens=max_tokens,
        stream=True,  # Stream the response
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        seed=seed,
        messages=messages,
    ):
        # Streaming chunks carry the token text in choices[0].delta,
        # not choices[0].message; the final chunk's content may be None
        token_text = message_chunk.choices[0].delta.content
        if token_text is not None:
            print(f"Received token: {token_text}")
            response += token_text
            yield response

    print("Completed response generation.")
# Create a Chatbot component with a specified height
# (instantiated here but not placed in the layout below; responses stream to a Textbox instead)
chatbot = gr.Chatbot(height=600)
print("Chatbot interface created.")
# Define the Gradio interface
with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:
    # Tab for basic settings
    with gr.Tab("Basic Settings"):
        with gr.Column(elem_id="prompt-container"):
            with gr.Row():
                # Textbox for the user to input the message
                text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=3, elem_id="prompt-text-input")
            with gr.Row():
                # Textbox for custom model input
                custom_model = gr.Textbox(label="Custom Model", info="Hugging Face model path (optional)", placeholder="meta-llama/Llama-3.3-70B-Instruct", lines=1, elem_id="custom-model-input")
        # Accordion for selecting the model
        with gr.Accordion("Featured models", open=True):
            # Textbox for searching models
            model_search = gr.Textbox(label="Filter models", placeholder="Search for a featured model...", lines=1, elem_id="model-search-input")

            # Featured model list, defined once so the filter function below can reuse it
            models_list = [
                "meta-llama/Llama-3.3-70B-Instruct",
                "anthropic/claude-3",
                "anthropic/claude-instant-3",
                "anthropic/claude-2",
                "anthropic/claude-instant-2",
                "anthropic/claude-1.3",
                "anthropic/claude-instant-1.3",
                "anthropic/claude-1",
                "anthropic/claude-instant-1",
                "anthropic/claude-0.3",
                "anthropic/claude-instant-0.3",
                "anthropic/claude-0.1",
                "anthropic/claude-instant-0.1",
                "anthropic/claude-v2",
                "anthropic/claude-instant-v2",
                "anthropic/claude-v1",
                "anthropic/claude-instant-v1",
                "anthropic/claude-v0.3",
                "anthropic/claude-instant-v0.3",
                "anthropic/claude-v0.1",
                "anthropic/claude-instant-v0.1",
            ]

            # Radio buttons to select the desired model
            model = gr.Radio(label="Select a model below", value="meta-llama/Llama-3.3-70B-Instruct", choices=models_list, interactive=True, elem_id="model-radio")
    # Filter the featured-model list based on the search input
    def filter_models(search_term):
        filtered_models = [m for m in models_list if search_term.lower() in m.lower()]
        return gr.update(choices=filtered_models)

    # Update the model list whenever the search box changes
    model_search.change(filter_models, inputs=model_search, outputs=model)
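    # A minimal sketch of the filter behavior (hypothetical call; not executed by the app):
    #   filter_models("llama")  ->  gr.update(choices=["meta-llama/Llama-3.3-70B-Instruct"])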
    # Tab for advanced settings
    with gr.Tab("Advanced Settings"):
        with gr.Row():
            # Textbox for specifying the system message
            system_message = gr.Textbox(value="", label="System message")

        with gr.Row():
            # Slider for setting the maximum new tokens
            max_tokens = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens")

        with gr.Row():
            # Slider for setting the temperature
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")

        with gr.Row():
            # Slider for setting top-p
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P")

        with gr.Row():
            # Slider for setting the frequency penalty
            frequency_penalty = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")

        with gr.Row():
            # Slider for setting the seed
            seed = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
    # Tab for information
    with gr.Tab("Information"):
        with gr.Row():
            # Display a sample prompt
            gr.Textbox(label="Sample prompt", value="Enter a prompt | ultra detail, ultra elaboration, ultra quality, perfect.")
with gr.Accordion("Featured Models (WiP)", open=False):
gr.html(
"""
<p><a href="https://huggingface.co/models?inferences=warm&pipeline_tag=text-to-text&sort=trending">View more models</a></p>
<table style="width:100%; text-align:center; margin:auto;">
<tr>
<th>Model</th>
<th>Description</th>
</tr>
<tr>
<td>meta-llama/Llama-3.3-70B-Instruct</td>
<td>High-quality, large-scale language model</td>
</tr>
<tr>
<td>anthropic/claude-3</td>
<td> Advanced conversational AI model</td>
</tr>
<tr>
<td>anthropic/claude-instant-3</td>
<td> Fast and efficient conversational AI model</td>
</tr>
</table>
"""
)
with gr.Accordion("Parameters Overview", open=False):
gr.markdown(
"""
## System Message
- **Description**: The system message provides context and instructions to the model.
- **Default**: ""
## Max New Tokens
- **Description**: The maximum number of tokens to generate in the response.
- **Default**: 512
- **Range**: 1 to 4096
## Temperature
- **Description**: Controls the randomness of the output. Lower values make the output more deterministic, higher values make it output more varied.
- **Default**: 0.7
- **Range**: 0.1 to 4.0
## Top-P
- **Description**: Controls the diversity of the output. Lower values make the output more focused, higher values make it more varied.
- **Default**: 0.7
- **Range**: 0.1 to 1.0
## Frequency Penalty
- **Description**: Penalizes repeated tokens in the response. Higher values makes the output less repetitive.
- **Default**: 0.0
- **Range**: -2.0 to 2.0
## Seed
- **Description**: A fixed seed for reproducibility. -1 for random.
- **Default**: -1
- **Range**: -1 to 65535
"""
)
"""
    # Row containing the 'Run' button to trigger the query function
    with gr.Row():
        text_button = gr.Button("Run", variant='primary', elem_id="gen-button")

    # Row for displaying the generated response
    with gr.Row():
        response_output = gr.Textbox(label="Response Output", elem_id="response-output")

    # Hidden state holding the conversation history expected by respond()
    history_state = gr.State([])

    # Set up the button to call the respond function; inputs are ordered to match its signature
    text_button.click(
        respond,
        inputs=[
            text_prompt, history_state, system_message, max_tokens,
            temperature, top_p, frequency_penalty, seed, model, custom_model
        ],
        outputs=[response_output]
    )
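# Note: because respond() is a generator, Gradio streams each yielded partial
# response into the output Textbox as it arrives; no extra streaming wiring is needed.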
print("Gradio interface initialized.")
if __name__ == "__main__":
demo.launch(show_api=False, share=False)