import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer  # Import the tokenizer

# Use the appropriate tokenizer for your model.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Define a maximum context length (tokens). Check your model's documentation and adjust as needed.
MAX_CONTEXT_LENGTH = 4096

# Read the default prompt from a file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))

def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
    """Truncates the conversation history to fit within the maximum token limit.
    Args:
        history: The conversation history (list of user/assistant tuples).
        system_message: The system message.
        max_length: The maximum number of tokens allowed.
    Returns:
        The truncated history.
    """
    truncated_history = []
    system_message_tokens = count_tokens(system_message)
    current_length = system_message_tokens

    # Iterate backwards through the history (newest to oldest)
    for user_msg, assistant_msg in reversed(history):
        user_tokens = count_tokens(user_msg) if user_msg else 0
        assistant_tokens = count_tokens(assistant_msg) if assistant_msg else 0
        turn_tokens = user_tokens + assistant_tokens

        if current_length + turn_tokens <= max_length:
            truncated_history.insert(0, (user_msg, assistant_msg))  # Add to the beginning
            current_length += turn_tokens
        else:
            break  # Stop adding turns if we exceed the limit

    return truncated_history

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Responds to a user message, maintaining conversation history, using special tokens and message list."""
    formatted_system_message = nvc_prompt_template

    truncated_history = truncate_history(history, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100)  # Reserve space for the new message and some generation

    messages = [{"role": "system", "content": formatted_system_message}]  # Start with system message
    for user_msg, assistant_msg in truncated_history:
        if user_msg:
            messages.append({"role": "user", "content": f"<|user|>\n{user_msg}</s>"})
        if assistant_msg:
            messages.append({"role": "assistant", "content": f"<|assistant|>\n{assistant_msg}</s>"})

    messages.append({"role": "user", "content": f"<|user|>\n{message}</s>"})

    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content or ""  # delta.content can be None on some chunks
            response += token
            yield response
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."

# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=False),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()