import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer  # Import the tokenizer
from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage

# Use the appropriate tokenizer for your model.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
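# Note: InferenceClient targets the hosted Inference API here; depending on the model's
# availability you may need to pass an access token, e.g. InferenceClient(..., token="hf_...").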

# Define a maximum context length (tokens).  Check your model's documentation!
MAX_CONTEXT_LENGTH = 4096  # Example: Adjust this based on your model!

# Read the default prompt from a file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()
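# Note: this assumes prompt.txt sits next to this script; a FileNotFoundError here means
# the file is missing or the working directory differs.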

# Initialize LangChain Conversation Memory
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
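# Note: this memory object lives at module level, so every Gradio session shares the same
# conversation history. Per-user isolation would need something like gr.State or a
# session-keyed store, which is beyond this sketch.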

def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))
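
# Rough usage (exact counts depend on the tokenizer and any special tokens it adds):
#   count_tokens("Hello there!")  # a handful of tokens for the Zephyr tokenizer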

def truncate_memory(memory, system_message: str, max_length: int):
    """
    Truncates the conversation memory messages to fit within the maximum token limit.
    
    Args:
        memory: The LangChain conversation memory object.
        system_message: The system message.
        max_length: The maximum number of tokens allowed.
        
    Returns:
        A list of messages (as dicts with role and content) that fit within the token limit.
    """
    truncated_messages = []
    system_tokens = count_tokens(system_message)
    current_length = system_tokens

    # Iterate backwards through the memory (newest to oldest)
    for msg in reversed(memory.chat_memory.messages):
        tokens = count_tokens(msg.content)
        if current_length + tokens <= max_length:
            role = "user" if isinstance(msg, HumanMessage) else "assistant"
            truncated_messages.insert(0, {"role": role, "content": msg.content})
            current_length += tokens
        else:
            break

    return truncated_messages
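
# Illustration (hypothetical numbers): if memory holds [m1, m2, m3] at roughly 10 tokens each
# and max_length only leaves room for two of them after the system message, the walk from
# newest to oldest keeps [m2, m3] and drops m1, so older turns are discarded first.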

def respond(
    message,
    history: list[tuple[str, str]],  # Required by Gradio's ChatInterface signature; conversation state is kept in LangChain memory
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Responds to a user message while maintaining conversation history via LangChain memory.
    It builds the prompt with a system message and the (truncated) conversation history,
    streams the response from the client, and finally updates the memory with the new response.
    """
    # Use the system message from the Gradio textbox, falling back to the file-based prompt template.
    formatted_system_message = system_message or nvc_prompt_template

    # Prepare and add the new user message (with your special tokens) to memory.
    new_user_message = f"<|user|>\n{message}</s>"
    memory.chat_memory.add_message(HumanMessage(content=new_user_message))

    # Truncate memory so the prompt fits within the context window, reserving room for the
    # generated reply (max_tokens) plus a 100-token safety margin.
    truncated_history = truncate_memory(
        memory, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100
    )
    # Ensure the current user message is present at the end.
    if not truncated_history or truncated_history[-1]["content"] != new_user_message:
        truncated_history.append({"role": "user", "content": new_user_message})

    # Build the full message list: system prompt + conversation history.
    messages = [{"role": "system", "content": formatted_system_message}] + truncated_history
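    # At this point `messages` has roughly this shape (contents abbreviated):
    #   [{"role": "system", "content": "<prompt.txt contents>"},
    #    {"role": "user", "content": "<|user|>\n...</s>"},
    #    ...]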

    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Some stream chunks (e.g. role-only deltas) can carry no content; guard against None.
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."

    # Once the full response is generated, add it to the LangChain memory.
    memory.chat_memory.add_message(AIMessage(content=f"<|assistant|>\n{response}</s>"))

# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
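
# To run locally (package names are indicative; pin versions as needed):
#   pip install gradio huggingface_hub transformers langchain
#   python app.py  # adjust if this file has a different name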