Guanzheng committed
Commit c5710ac · 1 Parent(s): 5505020

Update app.py

Files changed (1)
  1. app.py +14 -9
app.py CHANGED
@@ -86,6 +86,8 @@ def build_chat():
     prompt = conv.get_prompt()
     return prompt
 
+from fastchat.model import get_conversation_template
+
 @spaces.GPU
 def generate(
     message: str,
@@ -97,15 +99,18 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    chat = tokenizer.apply_chat_template(conversation, tokenize=False)
-    inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to("cuda")
+    conv = get_conversation_template("vicuna")
+    conv.append_message(conv.roles[0], message)
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+    # if system_prompt:
+    #     conversation.append({"role": "system", "content": system_prompt})
+    # for user, assistant in chat_history:
+    #     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    # conversation.append({"role": "user", "content": message})
+
+    # chat = tokenizer.apply_chat_template(conversation, tokenize=False)
+    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
     if len(inputs) > MAX_INPUT_TOKEN_LENGTH:
         inputs = inputs[-MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")