Spaces:
Sleeping
Sleeping
File size: 5,873 Bytes
0d89bdf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import html
import time
import uuid

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Page configuration: browser-tab title and icon, wide layout, sidebar open.
_PAGE_CONFIG = dict(
    page_title="ChatBot",
    page_icon="💬",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Seed each session-state key only when it is missing, so values stored by
# earlier reruns are preserved.
for _state_key, _initial_value in (
    ("chat_history", {}),  # chat_id -> {"title": ..., "messages": [...]}
    ("current_chat_id", None),
    ("messages", []),  # messages of the chat shown in the main area
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Load model and tokenizer (cached by st.cache_resource so reruns reuse them).
@st.cache_resource
def load_model(model_name="facebook/blenderbot-400M-distill"):
    """Load the tokenizer and model for *model_name*.

    BlenderBot is an encoder-decoder (sequence-to-sequence) model, so it is
    loaded with ``AutoModelForSeq2SeqLM``. Loading it through
    ``AutoModelForCausalLM`` yields the decoder-only ``BlenderbotForCausalLM``,
    which drops the encoder weights and generates poor replies.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


tokenizer, model = load_model()
# Function to generate response
def generate_response(prompt, *, max_length=100, temperature=0.7, top_p=0.9):
    """Generate one sampled reply to *prompt* with the module-level model.

    Args:
        prompt: The user's message.
        max_length: Upper bound on generated tokens, passed to ``generate``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded reply with special tokens and surrounding whitespace removed.
    """
    # truncation=True clips prompts longer than the tokenizer's
    # model_max_length; without it an over-long prompt can exceed the model's
    # position embeddings and raise during generation.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,  # input_ids *and* attention_mask
            max_length=max_length,
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
    # BlenderBot replies typically start with a space; strip it for display.
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
# Custom CSS for the chat-message / avatar / sidebar-chat classes and
# full-width Streamlit buttons, injected once per run as raw HTML.
_CUSTOM_CSS = """
<style>
.main {
background-color: #f9f9f9;
}
.stTextInput>div>div>input {
background-color: white;
}
.chat-message {
padding: 1rem;
border-radius: 0.5rem;
margin-bottom: 1rem;
display: flex;
flex-direction: row;
align-items: flex-start;
}
.chat-message.user {
background-color: #f0f0f0;
}
.chat-message.bot {
background-color: #e6f7ff;
}
.chat-message .avatar {
width: 40px;
height: 40px;
border-radius: 50%;
object-fit: cover;
margin-right: 1rem;
}
.chat-message .message {
flex-grow: 1;
}
.sidebar-chat {
padding: 0.5rem;
border-radius: 0.5rem;
margin-bottom: 0.5rem;
cursor: pointer;
}
.sidebar-chat:hover {
background-color: #f0f0f0;
}
.sidebar-chat.active {
background-color: #e6f7ff;
font-weight: bold;
}
.stButton>button {
width: 100%;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Sidebar for chat history
with st.sidebar:
    st.title("💬 Chats")

    # New chat button: create an empty conversation and make it current.
    if st.button("+ New Chat"):
        new_chat_id = str(uuid.uuid4())
        st.session_state.current_chat_id = new_chat_id
        st.session_state.chat_history[new_chat_id] = {
            "title": f"Chat {len(st.session_state.chat_history) + 1}",
            "messages": [],
        }
        st.session_state.messages = []
        st.rerun()

    st.markdown("---")

    # One button per saved chat. st.markdown only renders content and is not a
    # click handler (its return value is always truthy, which previously
    # force-selected the first chat and triggered st.rerun() on every run).
    # A keyed st.button reports a real click; the key keeps buttons distinct
    # even when two chats share a title, and type="primary" marks the active chat.
    for chat_id, chat_data in st.session_state.chat_history.items():
        is_active = chat_id == st.session_state.current_chat_id
        if st.button(
            chat_data["title"],
            key=f"chat_{chat_id}",
            type="primary" if is_active else "secondary",
        ):
            st.session_state.current_chat_id = chat_id
            st.session_state.messages = chat_data["messages"]
            st.rerun()
# Main chat interface
st.title("ChatBot")

# On a brand-new session (no active chat and no saved chats), create a first
# conversation so the chat input always has somewhere to write.
needs_first_chat = not (
    st.session_state.current_chat_id or st.session_state.chat_history
)
if needs_first_chat:
    first_chat_id = str(uuid.uuid4())
    st.session_state.chat_history[first_chat_id] = {"title": "New Chat", "messages": []}
    st.session_state.current_chat_id = first_chat_id
# Display chat messages
if st.session_state.current_chat_id:
    for message in st.session_state.messages:
        # Pick the CSS modifier, avatar seed and alt text for the speaker.
        if message["role"] == "user":
            css_class, avatar_seed, avatar_alt = "user", "user", "User Avatar"
        else:
            css_class, avatar_seed, avatar_alt = "bot", "bot", "Bot Avatar"
        # The content comes from the user or the model and is rendered with
        # unsafe_allow_html=True, so escape it to stop HTML injection and
        # keep "<" / "&" literal. Newlines become <br> because a blank line
        # would end the markdown HTML block and break the bubble layout.
        safe_content = html.escape(message["content"]).replace("\n", "<br>")
        st.markdown(
            f'<div class="chat-message {css_class}">'
            f'<img class="avatar" src="https://api.dicebear.com/7.x/bottts/svg?seed={avatar_seed}" alt="{avatar_alt}">'
            f'<div class="message">{safe_content}</div>'
            "</div>",
            unsafe_allow_html=True,
        )
# Chat input: record the user's message in the active chat, then rerun so it
# is rendered before the bot reply is generated.
prompt = st.chat_input("Type your message here...")
active_chat_id = st.session_state.current_chat_id
if prompt and active_chat_id:
    st.session_state.messages.append({"role": "user", "content": prompt})
    active_chat = st.session_state.chat_history[active_chat_id]
    active_chat["messages"] = st.session_state.messages
    # The first message of a chat also becomes its (at most 20-char + "...") title.
    if len(st.session_state.messages) == 1:
        active_chat["title"] = prompt if len(prompt) <= 20 else f"{prompt[:20]}..."
    st.rerun()
# Generate and display bot response for the last user message.
# Runs only when there is an active chat whose latest message is still
# unanswered; the current_chat_id guard avoids a KeyError on chat_history[None].
if (
    st.session_state.current_chat_id
    and st.session_state.messages
    and st.session_state.messages[-1]["role"] == "user"
):
    # No artificial delay: model inference already keeps the spinner visible,
    # and time.sleep would only block the script thread.
    with st.spinner("Thinking..."):
        try:
            response = generate_response(st.session_state.messages[-1]["content"])
        except Exception as exc:  # UI boundary: report instead of a raw traceback
            st.error(f"Sorry, something went wrong while generating a reply: {exc}")
            st.stop()
    # Add bot response to chat and keep the saved history in sync.
    st.session_state.messages.append({"role": "assistant", "content": response})
    active_chat = st.session_state.chat_history[st.session_state.current_chat_id]
    active_chat["messages"] = st.session_state.messages
    st.rerun()
|