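"""Gradio chat app: Zephyr-7B-beta served via the HuggingFace Inference API,
with a LangChain sliding-window memory and token-budget truncation of history."""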
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import HumanMessage, AIMessage, SystemMessage

# Initialize tokenizer (used locally only for token counting) and the remote inference client
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Model context window in tokens, shared between the prompt and the generated reply
MAX_CONTEXT_LENGTH = 4096

# Load prompt from file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

# Initialize LangChain Memory (buffer window to keep recent conversation)
memory = ConversationBufferWindowMemory(k=10, return_messages=True)

def count_tokens(text: str) -> int:
    return len(tokenizer.encode(text))

def truncate_history(messages, max_length):
    """Keep the most recent messages whose combined token count fits within max_length."""
    truncated_messages = []
    total_tokens = 0

    for message in reversed(messages):
        message_tokens = count_tokens(message.content)
        if total_tokens + message_tokens <= max_length:
            truncated_messages.insert(0, message)
            total_tokens += message_tokens
        else:
            break

    return truncated_messages

def respond(
    message,
    history,  # Gradio-supplied history; unused because LangChain memory tracks the conversation
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Use the system prompt from the UI textbox, falling back to the file-loaded template
    formatted_system_message = system_message or nvc_prompt_template

    # Retrieve prior turns from LangChain memory. The current user message is
    # appended below and saved together with the model's reply after streaming
    # finishes; saving it here first would duplicate it in the prompt.
    chat_history = memory.load_memory_variables({})["history"]

    # Truncate history to ensure it fits within context window
    max_history_tokens = MAX_CONTEXT_LENGTH - max_tokens - count_tokens(formatted_system_message) - 100
    truncated_chat_history = truncate_history(chat_history, max_history_tokens)

    # Construct the messages for inference
    messages = [SystemMessage(content=formatted_system_message)]
    messages.extend(truncated_chat_history)
    messages.append(HumanMessage(content=message))

    # Convert LangChain messages to the plain role/content dicts the HuggingFace
    # client expects. chat_completion applies the model's chat template itself,
    # so the content must not be pre-wrapped in <|role|>...</s> markers.
    formatted_messages = []
    for msg in messages:
        role = "system" if isinstance(msg, SystemMessage) else "user" if isinstance(msg, HumanMessage) else "assistant"
        formatted_messages.append({"role": role, "content": msg.content})

    response = ""
    try:
        for chunk in client.chat_completion(
            formatted_messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content or ""  # delta can be None on the final chunk
            response += token
            yield response

        # Save the completed turn (user input and AI response) in LangChain memory
        memory.save_context({"input": message}, {"output": response})

    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."

# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
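
# Assumed local setup (inferred from the imports above, not pinned versions):
#   pip install gradio huggingface_hub transformers langchain
# A prompt.txt file must exist alongside this script; then run `python app.py`.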