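"""Gradio chat app that streams replies from HuggingFaceH4/zephyr-7b-beta via the
Hugging Face InferenceClient, using a system prompt loaded from prompt.txt and a
LangChain buffer-window memory whose history is truncated by token count to fit
the model's context window.
"""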
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import HumanMessage, AIMessage, SystemMessage
# Initialize tokenizer and inference client
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
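# Token budget for the full request: system prompt + history + generated reply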
MAX_CONTEXT_LENGTH = 4096
# Load prompt from file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

# Initialize LangChain Memory (buffer window to keep recent conversation)
memory = ConversationBufferWindowMemory(k=10, return_messages=True)
def count_tokens(text: str) -> int:
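    """Return the number of tokens in `text` according to the model tokenizer."""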
    return len(tokenizer.encode(text))

def truncate_history(messages, max_length):
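    """Keep the most recent messages whose combined token count fits within max_length."""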
    truncated_messages = []
    total_tokens = 0
    for message in reversed(messages):
        message_tokens = count_tokens(message.content)
        if total_tokens + message_tokens <= max_length:
            truncated_messages.insert(0, message)
            total_tokens += message_tokens
        else:
            break
    return truncated_messages

def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
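    """Stream a reply to `message`, drawing context from the windowed LangChain memory."""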
    # Use the system message supplied by the UI, falling back to the file prompt
    formatted_system_message = system_message or nvc_prompt_template

    # Retrieve prior turns from LangChain memory; the current exchange is saved
    # only after a reply has been generated, so it is not duplicated in the prompt
    chat_history = memory.load_memory_variables({})["history"]

    # Truncate history so the system prompt, history, and reply fit in the context window
    max_history_tokens = MAX_CONTEXT_LENGTH - max_tokens - count_tokens(formatted_system_message) - 100
    truncated_chat_history = truncate_history(chat_history, max_history_tokens)

    # Construct the messages for inference
    messages = [SystemMessage(content=formatted_system_message)]
    messages.extend(truncated_chat_history)
    messages.append(HumanMessage(content=message))

    # Convert LangChain messages to plain role/content dicts. chat_completion applies
    # the model's chat template itself, so special tokens (<|user|>, </s>, ...) must
    # not be added by hand.
    formatted_messages = []
    for msg in messages:
        role = "system" if isinstance(msg, SystemMessage) else "user" if isinstance(msg, HumanMessage) else "assistant"
        formatted_messages.append({"role": role, "content": msg.content})

    response = ""
    try:
        for chunk in client.chat_completion(
            formatted_messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response
        # Save the completed exchange in LangChain memory
        memory.save_context({"input": message}, {"output": response})
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."

# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
if __name__ == "__main__":
    demo.launch()