Spaces:
Running
Running
import html
import time
import uuid

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Configure the browser tab (title and icon), use the wide layout, and start
# with the sidebar expanded so the chat list is visible immediately.
st.set_page_config(
    layout="wide",
    page_title="ChatBot",
    initial_sidebar_state="expanded",
    page_icon="💬",
)
# Seed per-session state on the first run only; later reruns keep whatever
# values are already stored.
#   chat_history    -> {chat_id: {"title": str, "messages": list}}
#   current_chat_id -> id of the active chat, or None before one exists
#   messages        -> message dicts of the active chat
for _key, _default in (
    ("chat_history", {}),
    ("current_chat_id", None),
    ("messages", []),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Load the tokenizer and model once per server process (cached below).
@st.cache_resource
def load_model():
    """Load and cache the BlenderBot tokenizer and model.

    facebook/blenderbot-400M-distill is an encoder-decoder (seq2seq) model,
    so it is loaded with AutoModelForSeq2SeqLM. Loading it through
    AutoModelForCausalLM instantiates only the decoder stack and drops the
    encoder weights, so the model no longer answers as it was trained to.

    st.cache_resource keeps the returned objects across Streamlit reruns;
    without it the whole model would be reloaded on every user interaction.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    model_name = "facebook/blenderbot-400M-distill"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


tokenizer, model = load_model()
# Function to generate a reply with the module-level tokenizer and model.
def generate_response(prompt):
    """Generate one sampled chatbot reply to ``prompt``.

    Relies on the module-level ``tokenizer`` and ``model`` created by
    ``load_model()``.

    Args:
        prompt: The user's message text.

    Returns:
        str: The decoded reply with special tokens removed and surrounding
        whitespace stripped.
    """
    # truncation=True caps the input at the tokenizer's model_max_length;
    # prompts longer than the model's maximum input length would otherwise
    # fail inside generate().
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Pass the mask explicitly so generate() does not have to infer it.
            attention_mask=inputs.attention_mask,
            max_length=100,
            num_return_sequences=1,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
# Page-wide styling: light background, chat-bubble layout and colours,
# round avatars, sidebar chat-entry styling, and full-width buttons.
# Injected as a raw <style> tag, which requires unsafe_allow_html=True.
_CUSTOM_CSS = """
<style>
.main {
    background-color: #f9f9f9;
}
.stTextInput>div>div>input {
    background-color: white;
}
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    display: flex;
    flex-direction: row;
    align-items: flex-start;
}
.chat-message.user {
    background-color: #f0f0f0;
}
.chat-message.bot {
    background-color: #e6f7ff;
}
.chat-message .avatar {
    width: 40px;
    height: 40px;
    border-radius: 50%;
    object-fit: cover;
    margin-right: 1rem;
}
.chat-message .message {
    flex-grow: 1;
}
.sidebar-chat {
    padding: 0.5rem;
    border-radius: 0.5rem;
    margin-bottom: 0.5rem;
    cursor: pointer;
}
.sidebar-chat:hover {
    background-color: #f0f0f0;
}
.sidebar-chat.active {
    background-color: #e6f7ff;
    font-weight: bold;
}
.stButton>button {
    width: 100%;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Sidebar: a "new chat" button followed by one button per saved conversation.
with st.sidebar:
    st.title("💬 Chats")

    # Start a fresh, empty conversation and make it the active one.
    if st.button("+ New Chat"):
        new_chat_id = str(uuid.uuid4())
        st.session_state.current_chat_id = new_chat_id
        st.session_state.chat_history[new_chat_id] = {
            "title": f"Chat {len(st.session_state.chat_history) + 1}",
            "messages": [],
        }
        st.session_state.messages = []
        st.rerun()

    st.markdown("---")

    # List saved conversations. A keyed st.button is used because it reports
    # real clicks: st.markdown returns a (truthy) element handle and rendered
    # HTML is not clickable, so testing its return value would switch chats
    # and call st.rerun() on every single run.
    for chat_id, chat_data in st.session_state.chat_history.items():
        is_active = chat_id == st.session_state.current_chat_id
        clicked = st.button(
            chat_data["title"],
            key=f"chat_{chat_id}",  # unique even when titles repeat
            type="primary" if is_active else "secondary",
        )
        # Clicking the already-active chat needs no switch or rerun.
        if clicked and not is_active:
            st.session_state.current_chat_id = chat_id
            st.session_state.messages = chat_data["messages"]
            st.rerun()
# Main chat area heading.
st.title("ChatBot")

# Very first run: no active chat and no saved chats yet. Create a default
# conversation so the rest of the page always has an active chat.
if not (st.session_state.current_chat_id or st.session_state.chat_history):
    first_chat_id = str(uuid.uuid4())
    st.session_state.chat_history[first_chat_id] = {
        "title": "New Chat",
        "messages": [],
    }
    st.session_state.current_chat_id = first_chat_id
# Display the active conversation as styled chat bubbles.
if st.session_state.current_chat_id:
    for message in st.session_state.messages:
        # Anything that is not a user message is rendered as the bot.
        if message["role"] == "user":
            role_class, avatar_seed, avatar_alt = "user", "user", "User Avatar"
        else:
            role_class, avatar_seed, avatar_alt = "bot", "bot", "Bot Avatar"
        # The text is interpolated into raw HTML rendered with
        # unsafe_allow_html=True, so it must be escaped; otherwise user input
        # (or model output) containing tags would be injected into the page.
        # Line breaks become <br>, so multi-line text keeps its breaks and a
        # blank line cannot end the Markdown HTML block mid-bubble.
        safe_text = "<br>".join(html.escape(message["content"]).splitlines())
        # Built as a single line so Markdown treats it as one HTML block.
        st.markdown(
            f'<div class="chat-message {role_class}">'
            f'<img class="avatar" src="https://api.dicebear.com/7.x/bottts/svg?seed={avatar_seed}" alt="{avatar_alt}">'
            f'<div class="message">{safe_text}</div>'
            "</div>",
            unsafe_allow_html=True,
        )
# Chat input: store the user's message in the active chat, name the chat after
# its first message, then rerun so the reply step below sees the new message.
prompt = st.chat_input("Type your message here...")
if prompt and st.session_state.current_chat_id:
    st.session_state.messages.append({"role": "user", "content": prompt})
    active_chat = st.session_state.chat_history[st.session_state.current_chat_id]
    active_chat["messages"] = st.session_state.messages
    if len(st.session_state.messages) == 1:
        # First message: its first 20 characters (plus "...") become the title.
        active_chat["title"] = prompt if len(prompt) <= 20 else f"{prompt[:20]}..."
    st.rerun()
# Reply step: when the newest message is from the user, generate the bot's
# answer, save it to both the live list and the chat's stored history, then
# rerun so the display loop above renders it.
pending = st.session_state.messages[-1] if st.session_state.messages else None
if pending is not None and pending["role"] == "user":
    with st.spinner("Thinking..."):
        time.sleep(0.5)  # short artificial pause before generating
        reply = generate_response(pending["content"])
        st.session_state.messages.append({"role": "assistant", "content": reply})
        st.session_state.chat_history[st.session_state.current_chat_id]["messages"] = (
            st.session_state.messages
        )
    st.rerun()