hertogateis commited on
Commit
5957ac9
·
verified ·
1 Parent(s): 6422bb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -25
app.py CHANGED
@@ -15,30 +15,17 @@ model.to(device)
15
  st.title("Mental Health Chatbot with T5")
16
 
17
  def generate_response(input_text):
18
- # Add conversational context to input
19
- input_text = f"You are a helpful assistant. {input_text}"
20
-
21
- # Tokenize input text
22
- input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
23
-
24
- # Generate a response from the model with advanced generation settings
25
  outputs = model.generate(input_ids,
26
- max_length=100, # max length of the output sequence
27
- num_beams=5, # Beam search for better results
28
- top_p=0.95, # Top-p sampling for more variety
29
- temperature=0.7, # Temperature controls randomness
30
- no_repeat_ngram_size=2, # Prevent repetition of n-grams
31
- pad_token_id=tokenizer.eos_token_id) # Padding token to avoid padding tokens being part of the output
32
-
33
- # Decode the model's output to a readable string
34
- bot_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
-
36
- return bot_output
37
-
38
- # Create input box for user to type a message
39
- user_input = st.text_input("You: ", "")
40
 
41
- if user_input:
42
- # Generate and display the bot's response
43
- response = generate_response(user_input)
44
- st.write(f"Bot: {response}")
 
15
  st.title("Mental Health Chatbot with T5")
16
 
17
  def generate_response(input_text):
18
+ input_ids = tokenizer.encode(input_text, return_tensors='pt')
 
 
 
 
 
 
19
  outputs = model.generate(input_ids,
20
+ min_length=5,
21
+ max_length=300,
22
+ do_sample=True, num_beams=5, no_repeat_ngram_size=2)
23
+ generated_text = tokenizer.decode(
24
+ outputs[0], skip_special_tokens=True)
25
+ return generated_text
26
+
27
+ prompt = st.chat_input(placeholder="Say Something!",key=None, max_chars=None, disabled=False, on_submit=None, args=None, kwargs=None)
28
+ if prompt:
29
+ with st.chat_message(name="AI",avatar=None):
30
+ st.write(generate_response(prompt))
 
 
 
31