import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import HumanMessage, AIMessage, SystemMessage

# Initialize tokenizer and inference client
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

MAX_CONTEXT_LENGTH = 4096

# Load the NVC system prompt from file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

# Initialize LangChain memory (sliding window over the last k exchanges)
memory = ConversationBufferWindowMemory(k=10, return_messages=True)


def count_tokens(text: str) -> int:
    """Count tokens with the model's own tokenizer."""
    return len(tokenizer.encode(text))


def truncate_history(messages, max_length):
    """Keep the most recent messages that fit within max_length tokens."""
    truncated_messages = []
    total_tokens = 0
    # Walk backwards from the newest message so recency wins.
    for message in reversed(messages):
        message_tokens = count_tokens(message.content)
        if total_tokens + message_tokens <= max_length:
            truncated_messages.insert(0, message)
            total_tokens += message_tokens
        else:
            break
    return truncated_messages


def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Use the editable system message from the UI; it is prefilled with the
    # file template, so this also covers the default case.
    formatted_system_message = system_message or nvc_prompt_template

    # Retrieve prior turns from LangChain memory. (Gradio's own `history` is
    # ignored; LangChain memory is the source of truth here.) The current
    # message is appended below and saved only after a reply arrives, so it
    # never appears in memory twice or alongside an empty assistant turn.
    chat_history = memory.load_memory_variables({})["history"]

    # Truncate history so system prompt + history + reply all fit within the
    # context window; the extra 100 tokens is headroom for chat-template markup.
    max_history_tokens = (
        MAX_CONTEXT_LENGTH - max_tokens - count_tokens(formatted_system_message) - 100
    )
    truncated_chat_history = truncate_history(chat_history, max_history_tokens)

    # Construct the message list for inference
    messages = [SystemMessage(content=formatted_system_message)]
    messages.extend(truncated_chat_history)
    messages.append(HumanMessage(content=message))

    # Convert LangChain messages to the role/content dicts the Hugging Face
    # client expects. chat_completion applies the model's chat template
    # itself, so the content must not carry its own <|role|> markers.
    formatted_messages = []
    for msg in messages:
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        else:
            continue  # skip message types the API has no role for
        formatted_messages.append({"role": role, "content": msg.content})

    response = ""
    try:
        for chunk in client.chat_completion(
            formatted_messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Some stream chunks (e.g. the final one) carry no content.
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response

        # Persist the completed exchange in LangChain memory in one step.
        memory.save_context({"input": message}, {"output": response})
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."


# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
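
# Usage sketch (assumptions, not part of the original Space config: the
# script is saved as `app.py` and run locally with gradio, huggingface_hub,
# transformers, and langchain installed). A prompt.txt file holding the NVC
# system prompt must sit next to the script, e.g.:
#
#   echo "You are an empathetic assistant grounded in NVC." > prompt.txt
#   python app.py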