import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer # Import the tokenizer
from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage

# Use the appropriate tokenizer for your model.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Define a maximum context length (tokens). Check your model's documentation!
MAX_CONTEXT_LENGTH = 4096  # Example: adjust this based on your model!

# Read the default prompt from a file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

# Initialize LangChain conversation memory
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))
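
# A quick sanity check (illustrative only; the exact numbers depend on the
# tokenizer's vocabulary and special-token handling):
#   count_tokens("Hello!")             # a small, single-digit token count
#   count_tokens(nvc_prompt_template)  # counted once per request as the system prompt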

def truncate_memory(memory, system_message: str, max_length: int):
    """
    Truncates the conversation memory messages to fit within the maximum token limit.

    Args:
        memory: The LangChain conversation memory object.
        system_message: The system message.
        max_length: The maximum number of tokens allowed.

    Returns:
        A list of messages (as dicts with role and content) that fit within the token limit.
    """
    truncated_messages = []
    system_tokens = count_tokens(system_message)
    current_length = system_tokens

    # Iterate backwards through the memory (newest to oldest)
    for msg in reversed(memory.chat_memory.messages):
        tokens = count_tokens(msg.content)
        if current_length + tokens <= max_length:
            role = "user" if isinstance(msg, HumanMessage) else "assistant"
            truncated_messages.insert(0, {"role": role, "content": msg.content})
            current_length += tokens
        else:
            break

    return truncated_messages
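
# Illustration of the truncation behaviour (a sketch, not executed here):
# with messages [m1, m2, m3] in memory, the loop walks m3 -> m2 -> m1 and keeps
# a message only while the running total (system prompt + kept messages) stays
# within max_length, so the oldest messages are dropped first and the surviving
# ones are returned in their original chronological order.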

def respond(
    message,
    history: list[tuple[str, str]],  # Required by Gradio, but we use LangChain memory instead
    system_message,
    max_tokens,
    temperature,
    top_p,
):
"""
Responds to a user message while maintaining conversation history via LangChain memory.
It builds the prompt with a system message and the (truncated) conversation history,
streams the response from the client, and finally updates the memory with the new response.
"""
# Use your prompt template as the system message.
formatted_system_message = nvc_prompt_template
# Prepare and add the new user message (with your special tokens) to memory.
new_user_message = f"<|user|>\n{message}</s>"
memory.chat_memory.add_message(HumanMessage(content=new_user_message))
# Truncate memory to ensure the context fits within the maximum token length (reserve space for generation).
truncated_history = truncate_memory(
memory, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100
)
# Ensure the current user message is present at the end.
if not truncated_history or truncated_history[-1]["content"] != new_user_message:
truncated_history.append({"role": "user", "content": new_user_message})
# Build the full message list: system prompt + conversation history.
messages = [{"role": "system", "content": formatted_system_message}] + truncated_history
response = ""
try:
for chunk in client.chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
token = chunk.choices[0].delta.content
response += token
yield response
except Exception as e:
print(f"An error occurred: {e}")
yield "I'm sorry, I encountered an error. Please try again."
# Once the full response is generated, add it to the LangChain memory.
memory.chat_memory.add_message(AIMessage(content=f"<|assistant|>\n{response}</s>"))

# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()