import gradio as gr
from huggingface_hub import InferenceClient
from transformers import LlamaTokenizer # Use LlamaTokenizer instead of AutoTokenizer
# Load the correct tokenizer
tokenizer = LlamaTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# Define max context length (tokens)
MAX_CONTEXT_LENGTH = 4096
default_nvc_prompt_template = """You are Roos, an NVC (Nonviolent Communication) Chatbot. Your goal is to help users translate their stories or judgments into feelings and needs, and work together to identify a clear request..."""

def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))

def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
    """Truncates the conversation history to fit within the maximum token limit."""
    truncated_history = []
    system_message_tokens = count_tokens(system_message)
    current_length = system_message_tokens
    # Walk the history from newest to oldest, keeping turns while they still fit
    for user_msg, assistant_msg in reversed(history):
        user_tokens = count_tokens(user_msg) if user_msg else 0
        assistant_tokens = count_tokens(assistant_msg) if assistant_msg else 0
        turn_tokens = user_tokens + assistant_tokens
        if current_length + turn_tokens <= max_length:
            truncated_history.insert(0, (user_msg, assistant_msg))  # Add to the beginning
            current_length += turn_tokens
        else:
            break  # Stop once the limit would be exceeded
    return truncated_history

def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Handles the user message and streams a response."""
    if message.lower() == "clear memory":
        # respond() is a generator, so a plain `return value` would never reach Gradio;
        # yield an empty reply instead and stop.
        yield ""
        return
    formatted_system_message = system_message
    # Reserve room for the reply plus a small safety margin
    truncated_history = truncate_history(history, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100)
    messages = [{"role": "system", "content": formatted_system_message}]
    for user_msg, assistant_msg in truncated_history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # The final streamed chunk may carry no content
                response += token
                yield response
    except Exception as e:
        print(f"Error: {e}")
        yield "I'm sorry, I encountered an error. Please try again."

# Build Gradio UI
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value=default_nvc_prompt_template, label="System message", lines=10),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)