Tonic committed
Commit b9eff4b · 1 Parent(s): dcb9c75

Update app.py

Files changed (1)
  1. app.py +4 -4
app.py CHANGED
@@ -28,7 +28,7 @@ class TuluChatBot:
         prompt = f"<|assistant|>\n {self.system_message}\n\n <|user|>{user_message}\n\n<|assistant|>\n"
         return prompt
 
-    def predict(self, user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample):
+    def Tulu(self, user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample):
         prompt = self.format_prompt(user_message)
         inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
         input_ids = inputs["input_ids"].to(self.model.device)
@@ -48,14 +48,14 @@ class TuluChatBot:
         response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
         return response
 
-def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample):
+def gradio_Tulu(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample):
     Tulu_bot.set_system_message(system_message)
     if not do_sample:
         max_length = 780
         temperature = 0.9
         top_p = 0.9
         repetition_penalty = 0.9
-    response = Tulu_bot.predict(user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample)
+    response = Tulu_bot.Tulu(user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample)
     return response
 
 Tulu_bot = TuluChatBot(model, tokenizer)
@@ -81,7 +81,7 @@ with gr.Blocks(theme = "ParityError/Anime") as demo:
     output_text = gr.Textbox(label="🌷Tulu Response")
 
     def process(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample):
-        return gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample)
+        return gradio_Tulu(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample)
 
     submit_button.click(
         process,
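
For context, here is a hypothetical sketch of the generation step that sits between the first two hunks (app.py lines 35-47 are elided from this diff). Only the tokenize and decode calls are visible in the hunks; the generate() keyword wiring below is an assumption inferred from the renamed method's signature, and tulu_generate is an illustrative name, not a function in the commit.

# Hypothetical reconstruction of the elided middle of TuluChatBot.Tulu
# (the diff above skips app.py lines 35-47). The generate() keyword wiring
# is an assumption from the method signature, not code shown in this commit.
def tulu_generate(model, tokenizer, prompt, temperature, max_new_tokens,
                  top_p, repetition_penalty, do_sample):
    # Visible in the diff: tokenize without special tokens, move to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs["input_ids"].to(model.device)
    # Assumed: a standard transformers generate() call taking the sampling knobs
    # passed through from the Gradio controls.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,          # cap on newly generated tokens
        temperature=temperature,                # sampling temperature
        top_p=top_p,                            # nucleus sampling threshold
        repetition_penalty=repetition_penalty,  # discourages repeated tokens
        do_sample=do_sample,                    # greedy decoding when False
    )
    # Visible in the diff: decode the full output sequence, dropping special tokens.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

In the commit itself this logic lives inside TuluChatBot.Tulu as a method, with model and tokenizer held on self rather than passed as arguments.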