johntheajs commited on
Commit
14a2982
·
verified ·
1 Parent(s): f1fdc2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -3,31 +3,33 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  # Load the model and tokenizer
6
- model_id = "google/gemma-7b" # Replace with "google/gemma-7b-it" for instruction tuning
7
  tokenizer = AutoTokenizer.from_pretrained(model_id)
8
  model = AutoModelForCausalLM.from_pretrained(model_id)
9
 
10
  # Function to generate responses based on user messages
11
  def generate_response(messages):
12
- input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
13
- outputs = model.generate(input_ids, max_new_tokens=100)
14
  generated_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
  return generated_response
16
 
 
17
  st.title("Gemma Chatbot")
18
  messages = []
 
19
  user_input = st.text_input("You:", "")
20
  if st.button("Send"):
21
  if user_input:
22
- messages.append({"role": "user", "content": user_input})
23
  bot_response = generate_response(messages)
24
- messages.append({"role": "assistant", "content": bot_response})
25
  else:
26
- st.warning("Please enter a message to process.")
27
 
28
  # Display conversation
29
- for message in messages:
30
- if message["role"] == "user":
31
- st.text_input("You:", value=message["content"], disabled=True)
32
- elif message["role"] == "assistant":
33
- st.text_area("Gemma:", value=message["content"], disabled=True)
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  # Load the model and tokenizer
6
+ model_id = "google/gemma-7b"
7
  tokenizer = AutoTokenizer.from_pretrained(model_id)
8
  model = AutoModelForCausalLM.from_pretrained(model_id)
9
 
10
  # Function to generate responses based on user messages
11
  def generate_response(messages):
12
+ input_ids = tokenizer.encode(messages, return_tensors="pt").to(model.device)
13
+ outputs = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
14
  generated_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
  return generated_response
16
 
17
+ # Streamlit app
18
  st.title("Gemma Chatbot")
19
  messages = []
20
+
21
  user_input = st.text_input("You:", "")
22
  if st.button("Send"):
23
  if user_input:
24
+ messages.append(user_input)
25
  bot_response = generate_response(messages)
26
+ messages.append(bot_response)
27
  else:
28
+ st.warning("Please enter a message.")
29
 
30
  # Display conversation
31
+ for i, message in enumerate(messages):
32
+ if i % 2 == 0:
33
+ st.text_input("You:", value=message, disabled=True)
34
+ else:
35
+ st.text_area("Gemma:", value=message, disabled=True)