ruslanmv committed
Commit 4282ccc · verified · 1 Parent(s): 9d6a6b8

Update app.py
Files changed (1)
  1. app.py +43 -41
app.py CHANGED
@@ -1,75 +1,73 @@
import gradio as gr
from huggingface_hub import InferenceClient
-from transformers import AutoTokenizer # Import the tokenizer
+from transformers import AutoTokenizer
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.schema import HumanMessage, AIMessage, SystemMessage

-# Use the appropriate tokenizer for your model.
+# Initialize tokenizer and inference client
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

-# Define a maximum context length (tokens). Check your model's documentation!
-MAX_CONTEXT_LENGTH = 4096 # Example: Adjust this based on your model!
+MAX_CONTEXT_LENGTH = 4096

-# Read the default prompt from a file
+# Load prompt from file
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

+# Initialize LangChain Memory (buffer window to keep recent conversation)
+memory = ConversationBufferWindowMemory(k=10, return_messages=True)
+
def count_tokens(text: str) -> int:
-    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))

-def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
-    """Truncates the conversation history to fit within the maximum token limit.
-    Args:
-        history: The conversation history (list of user/assistant tuples).
-        system_message: The system message.
-        max_length: The maximum number of tokens allowed.
-    Returns:
-        The truncated history.
-    """
-    truncated_history = []
-    system_message_tokens = count_tokens(system_message)
-    current_length = system_message_tokens
-
-    # Iterate backwards through the history (newest to oldest)
-    for user_msg, assistant_msg in reversed(history):
-        user_tokens = count_tokens(user_msg) if user_msg else 0
-        assistant_tokens = count_tokens(assistant_msg) if assistant_msg else 0
-        turn_tokens = user_tokens + assistant_tokens
-
-        if current_length + turn_tokens <= max_length:
-            truncated_history.insert(0, (user_msg, assistant_msg)) # Add to the beginning
-            current_length += turn_tokens
+def truncate_history(messages, max_length):
+    truncated_messages = []
+    total_tokens = 0
+
+    for message in reversed(messages):
+        message_tokens = count_tokens(message.content)
+        if total_tokens + message_tokens <= max_length:
+            truncated_messages.insert(0, message)
+            total_tokens += message_tokens
        else:
-            break # Stop adding turns if we exceed the limit
+            break

-    return truncated_history
+    return truncated_messages

def respond(
    message,
-    history: list[tuple[str, str]],
+    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
-    """Responds to a user message, maintaining conversation history, using special tokens and message list."""
    formatted_system_message = nvc_prompt_template

-    truncated_history = truncate_history(history, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100) # Reserve space for the new message and some generation
+    # Retrieve conversation history from LangChain memory
+    memory.save_context({"input": message}, {"output": ""})
+    chat_history = memory.load_memory_variables({})["history"]
+
+    # Truncate history to ensure it fits within context window
+    max_history_tokens = MAX_CONTEXT_LENGTH - max_tokens - count_tokens(formatted_system_message) - 100
+    truncated_chat_history = truncate_history(chat_history, max_history_tokens)

-    messages = [{"role": "system", "content": formatted_system_message}] # Start with system message
-    for user_msg, assistant_msg in truncated_history:
-        if user_msg:
-            messages.append({"role": "user", "content": f"<|user|>\n{user_msg}</s>"})
-        if assistant_msg:
-            messages.append({"role": "assistant", "content": f"<|assistant|>\n{assistant_msg}</s>"})
+    # Construct the messages for inference
+    messages = [SystemMessage(content=formatted_system_message)]
+    messages.extend(truncated_chat_history)
+    messages.append(HumanMessage(content=message))

-    messages.append({"role": "user", "content": f"<|user|>\n{message}</s>"})
+    # Convert LangChain messages to the format required by HuggingFace client
+    formatted_messages = []
+    for msg in messages:
+        role = "system" if isinstance(msg, SystemMessage) else "user" if isinstance(msg, HumanMessage) else "assistant"
+        content = f"<|{role}|>\n{msg.content}</s>"
+        formatted_messages.append({"role": role, "content": content})

    response = ""
    try:
        for chunk in client.chat_completion(
-            messages,
+            formatted_messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
@@ -78,6 +76,10 @@ def respond(
            token = chunk.choices[0].delta.content
            response += token
            yield response
+
+        # Save AI's response in LangChain memory
+        memory.chat_memory.add_ai_message(response)
+
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."