michailroussos committed
Commit 9b00c4f · 1 Parent(s): 4668547
Files changed (1)
  1. app.py +32 -14
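The first hunk header below shows FastLanguageModel.for_inference(model) sitting just above line 16, so the top of app.py (not included in the diff) presumably imports torch, Gradio, and Unsloth and loads the model. A minimal sketch of that assumed preamble follows; the checkpoint name is a placeholder, since the commit never shows it:

import torch
import gradio as gr
from unsloth import FastLanguageModel

# Assumed setup for lines 1-15 of app.py; the checkpoint name below is a
# placeholder, not taken from the commit.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-Instruct",  # hypothetical checkpoint
    max_seq_length=2048,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference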
app.py CHANGED
@@ -16,19 +16,25 @@ FastLanguageModel.for_inference(model)  # Enable optimized inference
 
 # Define the response function
 def respond(message, history, system_message, max_tokens, temperature, top_p):
+    # Add the system message and include previous conversation history
     messages = [{"role": "system", "content": system_message}]
-    for exchange in history:
-        messages.append({"role": "user", "content": exchange["user"]})
-        messages.append({"role": "assistant", "content": exchange["assistant"]})
+    if history:
+        for entry in history:
+            messages.append({"role": "user", "content": entry["user"]})
+            messages.append({"role": "assistant", "content": entry["assistant"]})
+
+    # Add the user's new input
     messages.append({"role": "user", "content": message})
-
+
+    # Tokenize inputs
     inputs = tokenizer.apply_chat_template(
         messages,
         tokenize=True,
         add_generation_prompt=True,
         return_tensors="pt",
     ).to("cuda" if torch.cuda.is_available() else "cpu")
-
+
+    # Generate the response
     attention_mask = inputs.ne(tokenizer.pad_token_id).long()
     generated_tokens = model.generate(
         input_ids=inputs,
@@ -40,25 +46,37 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
     )
     response = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
 
+    # Update history
+    if history is None:
+        history = []
     history.append({"user": message, "assistant": response})
-    formatted_history = [
-        {"role": "user", "content": exchange["user"]} if "user" in exchange else
-        {"role": "assistant", "content": exchange["assistant"]}
-        for exchange in history
-    ]
+
+    print("history:")
+    print(history)
+    # Format history for Gradio (strictly enforce the role-content format)
+    formatted_history = []
+    for entry in history:
+        formatted_history.append({"role": "user", "content": entry["user"]})
+        formatted_history.append({"role": "assistant", "content": entry["assistant"]})
+
+    print("formatted_history:")
+    print(formatted_history)
+    # Return formatted history
    return formatted_history
 
+
 # Define the Gradio interface
 demo = gr.ChatInterface(
-    respond,
+    fn=respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Textbox(value="You are a helpful assistant.", label="System message"),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
     ],
     type="messages",
 )
 
+
 if __name__ == "__main__":
     demo.launch(share=False)  # Use share=False for local testing
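
A side note on the history format, not part of the commit: with type="messages", current Gradio versions hand respond() a history that is already a list of {"role": ..., "content": ...} dicts, and ChatInterface expects the function to return just the new assistant reply, which it appends to the chat itself. A sketch of a respond() written against that contract, reusing the commit's model and tokenizer; the generate() arguments and the prompt-stripping slice are assumptions, since the unchanged middle of the generate call is elided from the diff:

def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Under type="messages", history entries already carry "role"/"content" keys.
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history or [])
    messages.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    generated_tokens = model.generate(
        input_ids=inputs,
        attention_mask=inputs.ne(tokenizer.pad_token_id).long(),
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,  # sampling assumed, so temperature/top_p take effect
    )
    # Decode only the tokens generated after the prompt and return the reply;
    # ChatInterface takes care of appending it to the visible history.
    return tokenizer.decode(
        generated_tokens[0][inputs.shape[-1]:], skip_special_tokens=True
    )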