import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer # Import the tokenizer
# Use the appropriate tokenizer for your model.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
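# Gated or rate-limited models may need authentication. Assumption: an HF_TOKEN
# environment variable is available (requires `import os`), e.g.:
#   client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.environ.get("HF_TOKEN"))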
# Define a maximum context length (tokens). Check your model's documentation!
MAX_CONTEXT_LENGTH = 4096 # Example: Adjust this based on your model!
# Read the default prompt from a file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()
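# A sketch of a more forgiving load, assuming a missing prompt.txt should fall
# back to a generic default instead of crashing at startup:
#   try:
#       with open("prompt.txt", "r") as file:
#           nvc_prompt_template = file.read()
#   except FileNotFoundError:
#       nvc_prompt_template = "You are a helpful assistant."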
def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))
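# Illustrative usage (exact counts vary by tokenizer, which may also prepend
# special tokens such as <s>):
#   count_tokens("Hello there!")  # -> a small integer, e.g. 5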
def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
    """Truncates the conversation history to fit within the maximum token limit.

    Args:
        history: The conversation history (list of user/assistant tuples).
        system_message: The system message.
        max_length: The maximum number of tokens allowed.

    Returns:
        The truncated history.
    """
    truncated_history = []
    system_message_tokens = count_tokens(system_message)
    current_length = system_message_tokens
    # Iterate backwards through the history (newest to oldest)
    for user_msg, assistant_msg in reversed(history):
        user_tokens = count_tokens(user_msg) if user_msg else 0
        assistant_tokens = count_tokens(assistant_msg) if assistant_msg else 0
        turn_tokens = user_tokens + assistant_tokens
        if current_length + turn_tokens <= max_length:
            truncated_history.insert(0, (user_msg, assistant_msg))  # Add to the beginning
            current_length += turn_tokens
        else:
            break  # Stop adding turns once the budget would be exceeded
    return truncated_history
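# Illustrative behaviour with a tiny token budget (numbers are hypothetical):
#   history = [("old question", "old answer"), ("new question", "new answer")]
#   truncate_history(history, "system prompt", max_length=20)
#   # -> keeps only the newest turn(s) whose combined token count, plus the
#   #    system message, still fits within max_length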
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Streams a reply to the user message, keeping the conversation history within the context window."""
    # The hidden Textbox below supplies nvc_prompt_template as the default system message.
    formatted_system_message = system_message or nvc_prompt_template
    # Reserve space for the new user message and some generation headroom.
    truncated_history = truncate_history(
        history, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100
    )
    # chat_completion applies the model's chat template server-side, so the
    # messages carry plain text; hand-inserting special tokens like <|user|>
    # and </s> here would get them templated a second time.
    messages = [{"role": "system", "content": formatted_system_message}]  # Start with the system message
    for user_msg, assistant_msg in truncated_history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # The final streamed chunk can carry an empty/None delta
                response += token
                yield response
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."
# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=False),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
if __name__ == "__main__":
    demo.launch()
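    # Assumption: during local development you can get a shareable public URL
    # with demo.launch(share=True) instead.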