import gradio as gr
from huggingface_hub import InferenceClient
from transformers import LlamaTokenizer  # Use LlamaTokenizer instead of AutoTokenizer

# Load the correct tokenizer
tokenizer = LlamaTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
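# Note: InferenceClient talks to the hosted Inference API; depending on the model's
# availability/gating, an access token may be required (huggingface_hub can pick one
# up from the HF_TOKEN environment variable or a prior `huggingface-cli login`).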

# Define max context length (tokens)
MAX_CONTEXT_LENGTH = 4096  

default_nvc_prompt_template = """You are Roos, an NVC (Nonviolent Communication) Chatbot. Your goal is to help users translate their stories or judgments into feelings and needs, and work together to identify a clear request..."""

def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))

def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
    """Truncates the conversation history to fit within the maximum token limit."""
    truncated_history = []
    system_message_tokens = count_tokens(system_message)
    current_length = system_message_tokens

    for user_msg, assistant_msg in reversed(history):
        user_tokens = count_tokens(user_msg) if user_msg else 0
        assistant_tokens = count_tokens(assistant_msg) if assistant_msg else 0
        turn_tokens = user_tokens + assistant_tokens

        if current_length + turn_tokens <= max_length:
            truncated_history.insert(0, (user_msg, assistant_msg))  # Add to the beginning
            current_length += turn_tokens
        else:
            break  # Stop if limit exceeded

    return truncated_history
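
# Illustration of the budget arithmetic above (hypothetical token counts): with a
# 10-token system message and max_length=50, the newest 30-token turn fits (10+30=40),
# but the next older 20-token turn would exceed the budget (60 > 50), so it and every
# older turn are dropped. The most recent exchanges are always the ones kept.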

def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Handles user message and generates a response."""
    
    if message.lower() == "clear memory":
        # respond() is a generator, so it cannot return (message, history); just stop.
        yield ""
        return

    formatted_system_message = system_message
    # Reserve room for the model's reply (max_tokens) plus a margin for the new user message.
    truncated_history = truncate_history(history, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100)

    messages = [{"role": "system", "content": formatted_system_message}]
    
    for user_msg, assistant_msg in truncated_history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p
        ):
            token = chunk.choices[0].delta.content or ""  # delta.content can be None on some chunks
            response += token
            yield response
    except Exception as e:
        print(f"Error: {e}")
        yield "I'm sorry, I encountered an error. Please try again."

# Build Gradio UI
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value=default_nvc_prompt_template, label="System message", lines=10),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
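# Note: gr.ChatInterface passes the additional_inputs to respond() positionally after
# (message, history), in the order listed above: system_message, max_tokens,
# temperature, top_p.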

if __name__ == "__main__":
    demo.launch(share=True)