# SmallBot / app.py (3.26 kB)
# Last commit e92705a "Update app.py" by hertogateis (verified).
import html
import random

import streamlit as st
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Load the pre-trained T5 model and tokenizer.
# "t5-base" or "t5-large" give better quality but slower responses.
model_name = "t5-small"


@st.cache_resource
def _load_model(name):
    """Load the T5 model and tokenizer for *name* on the best available device.

    Streamlit re-executes this whole script on every user interaction. Caching
    the result with ``st.cache_resource`` means the weights are loaded once per
    server process instead of on every rerun.

    Args:
        name: Hugging Face model identifier, e.g. ``"t5-small"``.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the model already lives on
        ``device``.
    """
    # Use the GPU when available for faster inference, otherwise fall back to CPU.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mdl = T5ForConditionalGeneration.from_pretrained(name)
    mdl.to(dev)
    tok = T5Tokenizer.from_pretrained(name)
    return mdl, tok, dev


model, tokenizer, device = _load_model(model_name)
# Per-session state, created once per browser session (Streamlit reruns this
# script on every interaction, so existing values must not be overwritten):
#   'history'      - prompt fragments concatenated and fed to the model each turn
#   'conversation' - human-readable transcript rendered in the UI
for _state_key in ('history', 'conversation'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []
# Candidate system prompts that set the bot's tone; one of them is picked at
# random when a new conversation starts.
system_prompts = [
    "You are a helpful assistant. Respond in a polite, friendly, and informative manner.",
    "You are a conversational chatbot. Provide friendly, engaging, and empathetic responses.",
    "You are an informative assistant. Respond clearly and concisely to any questions asked.",
    "You are a fun, casual chatbot. Keep the conversation light-hearted and interesting.",
]


def get_system_prompt():
    """Return one entry of ``system_prompts``, chosen uniformly at random."""
    chosen = random.choice(system_prompts)
    return chosen
def generate_response(input_text, max_input_tokens=512):
    """Generate the bot's reply to *input_text* and record the exchange.

    On the first turn of a session a random system prompt is chosen and stored
    at the head of the history. The whole history plus the new user turn is then
    encoded as one prompt and decoded with beam search using the module-level
    ``model``, ``tokenizer`` and ``device``.

    Args:
        input_text: The user's message.
        max_input_tokens: Upper bound on the number of prompt tokens passed to
            the encoder. When the accumulated history is longer, the oldest
            tokens are dropped so the newest user turn is always kept (T5 was
            pre-trained on inputs of at most 512 tokens).

    Returns:
        The decoded model output with special tokens removed.

    Side effects:
        Appends to ``st.session_state['history']`` (model prompt fragments) and
        ``st.session_state['conversation']`` (display transcript).
    """
    history = st.session_state['history']
    conversation = st.session_state['conversation']

    # First turn of the session: seed the history with a system prompt.
    if not history:
        system_prompt = get_system_prompt()
        conversation.append(f"System: {system_prompt}")
        history.append(f"conversation: {system_prompt} ")

    user_input = f"conversation: {input_text} "
    full_input = "".join(history) + user_input

    input_ids = tokenizer.encode(full_input, return_tensors="pt")
    # The history grows without bound; keep only the most recent tokens so the
    # prompt stays within the model's trained input length. Slicing from the
    # end preserves the latest user turn and the trailing </s> token.
    if input_ids.shape[-1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]
    input_ids = input_ids.to(device)

    # Deterministic beam search. top_p/temperature are deliberately not passed:
    # they only take effect with do_sample=True and are otherwise ignored.
    outputs = model.generate(
        input_ids,
        max_length=1000,
        num_beams=5,
        pad_token_id=tokenizer.pad_token_id,
    )
    bot_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Record the exchange both as model context and as a display transcript.
    history.append(user_input)
    history.append(f"bot: {bot_output} ")
    conversation.append(f"You: {input_text}")
    conversation.append(f"Bot: {bot_output}")
    return bot_output
# Streamlit Interface
st.title("Chat with T5")

# Display the conversation transcript. Messages contain user- and
# model-generated text, so they are HTML-escaped before being embedded in the
# raw HTML that unsafe_allow_html renders.
for message in st.session_state['conversation']:
    st.markdown(
        f"<p style='color:gray; padding:5px;'>{html.escape(message)}</p>",
        unsafe_allow_html=True,
    )

# Create input box for user
user_input = st.text_input("You: ", "")

# st.text_input keeps its value across reruns, so a rerun that is not caused by
# new input would otherwise answer the same text again and duplicate it in the
# history. Only answer input that has not been answered yet.
if user_input and user_input != st.session_state.get('last_input'):
    # Generate and display the bot's response
    response = generate_response(user_input)
    # Mark as answered only after generation succeeded, so a failed attempt can
    # be retried with the same text.
    st.session_state['last_input'] = user_input
    st.write(f"Bot: {response}")