hertogateis committed on
Commit
348a356
·
verified ·
1 Parent(s): 5957ac9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -15
app.py CHANGED
@@ -1,19 +1,9 @@
1
  import streamlit as st
2
- from transformers import T5ForConditionalGeneration, T5Tokenizer
3
- import torch
4
-
5
- # Load pre-trained model and tokenizer from the "KhantKyaw/T5-small_new_chatbot"
6
- model_name = "KhantKyaw/T5-small_new_chatbot" # Use the fine-tuned model
7
- model = T5ForConditionalGeneration.from_pretrained(model_name)
8
  tokenizer = T5Tokenizer.from_pretrained(model_name)
9
-
10
- # Set device to GPU if available for faster inference, otherwise fallback to CPU
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- model.to(device)
13
-
14
- # Streamlit Interface
15
- st.title("Mental Health Chatbot with T5")
16
-
17
  def generate_response(input_text):
18
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
19
  outputs = model.generate(input_ids,
@@ -24,8 +14,8 @@ def generate_response(input_text):
24
  outputs[0], skip_special_tokens=True)
25
  return generated_text
26
 
 
27
  prompt = st.chat_input(placeholder="Say Something!",key=None, max_chars=None, disabled=False, on_submit=None, args=None, kwargs=None)
28
  if prompt:
29
  with st.chat_message(name="AI",avatar=None):
30
  st.write(generate_response(prompt))
31
-
 
1
  import streamlit as st
2
+ import transformers
3
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
+ model_name = "t5-small"
 
 
 
5
  tokenizer = T5Tokenizer.from_pretrained(model_name)
6
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
 
 
 
 
 
 
 
7
  def generate_response(input_text):
8
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
9
  outputs = model.generate(input_ids,
 
14
  outputs[0], skip_special_tokens=True)
15
  return generated_text
16
 
17
+
18
  prompt = st.chat_input(placeholder="Say Something!",key=None, max_chars=None, disabled=False, on_submit=None, args=None, kwargs=None)
19
  if prompt:
20
  with st.chat_message(name="AI",avatar=None):
21
  st.write(generate_response(prompt))