Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import BertForSequenceClassification, BertTokenizer | |
import torch | |
import time | |
import random | |
# [Previous CSS styles remain the same] | |
def local_css(): | |
st.markdown(""" | |
<style> | |
.chat-container { | |
padding: 10px; | |
border-radius: 5px; | |
margin-bottom: 10px; | |
display: flex; | |
flex-direction: column; | |
} | |
.user-message { | |
background-color: #e3f2fd; | |
padding: 10px; | |
border-radius: 15px; | |
margin: 5px; | |
margin-left: 20%; | |
margin-right: 5px; | |
align-self: flex-end; | |
max-width: 70%; | |
} | |
.bot-message { | |
background-color: #f5f5f5; | |
padding: 10px; | |
border-radius: 15px; | |
margin: 5px; | |
margin-right: 20%; | |
margin-left: 5px; | |
align-self: flex-start; | |
max-width: 70%; | |
} | |
.chat-input { | |
position: fixed; | |
bottom: 0; | |
width: 100%; | |
padding: 20px; | |
background-color: white; | |
} | |
.thinking-animation { | |
display: flex; | |
align-items: center; | |
margin-left: 10px; | |
} | |
.dot { | |
width: 8px; | |
height: 8px; | |
margin: 0 3px; | |
background: #888; | |
border-radius: 50%; | |
animation: bounce 0.8s infinite; | |
} | |
.dot:nth-child(2) { animation-delay: 0.2s; } | |
.dot:nth-child(3) { animation-delay: 0.4s; } | |
@keyframes bounce { | |
0%, 100% { transform: translateY(0); } | |
50% { transform: translateY(-5px); } | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
def load_model(): | |
model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased") | |
tokenizer = BertTokenizer.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased") | |
return model, tokenizer | |
def predict(text, model, tokenizer): | |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
predicted_class = torch.argmax(predictions, dim=1).item() | |
confidence = predictions[0][predicted_class].item() | |
return predicted_class, confidence | |
def get_bot_response(text, predicted_class, confidence): | |
# Define response templates based on classes and confidence levels | |
responses = { | |
0: { # Example for class 0 (positive sentiment) | |
'high_conf': [ | |
"Tôi cảm nhận được sự tích cực trong câu nói của bạn. Xin chia sẻ thêm nhé!", | |
"Thật vui khi nghe điều đó. Bạn có thể kể thêm không?", | |
"Tuyệt vời! Tôi rất đồng ý với bạn về điều này." | |
], | |
'low_conf': [ | |
"Có vẻ như đây là điều tích cực. Đúng không nhỉ?", | |
"Tôi nghĩ đây là một góc nhìn thú vị đấy.", | |
"Nghe có vẻ tốt đấy, bạn nghĩ sao?" | |
] | |
}, | |
1: { # Example for class 1 (negative sentiment) | |
'high_conf': [ | |
"Tôi hiểu đây là điều khó khăn với bạn. Hãy chia sẻ thêm nhé.", | |
"Tôi rất tiếc khi nghe điều này. Bạn cần tôi giúp gì không?", | |
"Đúng là một tình huống khó khăn. Chúng ta cùng tìm giải pháp nhé." | |
], | |
'low_conf': [ | |
"Có vẻ như bạn đang gặp khó khăn. Tôi có hiểu đúng không?", | |
"Tôi không chắc mình hiểu hết, bạn có thể giải thích thêm được không?", | |
"Hãy chia sẻ thêm để tôi có thể hiểu rõ hơn nhé." | |
] | |
} | |
} | |
# Add more classes based on your model's output | |
# Determine confidence level | |
confidence_threshold = 0.8 | |
conf_level = 'high_conf' if confidence > confidence_threshold else 'low_conf' | |
# Get appropriate response list | |
try: | |
response_list = responses[predicted_class][conf_level] | |
response = random.choice(response_list) | |
except KeyError: | |
response = "Xin lỗi, tôi không chắc chắn về điều này. Bạn có thể giải thích rõ hơn được không?" | |
# Add context from user's input | |
context_response = f"{response}" | |
return context_response | |
def init_session_state(): | |
if 'messages' not in st.session_state: | |
st.session_state.messages = [] | |
if 'thinking' not in st.session_state: | |
st.session_state.thinking = False | |
def display_chat_history(): | |
for message in st.session_state.messages: | |
if message['role'] == 'user': | |
st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True) | |
else: | |
st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True) | |
def main(): | |
st.set_page_config(page_title="Vietnamese Chatbot", page_icon="🤖", layout="wide") | |
local_css() | |
init_session_state() | |
# Load model | |
model, tokenizer = load_model() | |
# Chat interface | |
st.title("Chatbot Tiếng Việt 🤖") | |
st.markdown("Xin chào! Tôi có thể giúp gì cho bạn?") | |
# Chat history container | |
chat_container = st.container() | |
# Input container | |
with st.container(): | |
col1, col2 = st.columns([6, 1]) | |
with col1: | |
user_input = st.text_input("Nhập tin nhắn của bạn...", key="user_input", label_visibility="hidden") | |
with col2: | |
send_button = st.button("Gửi") | |
if user_input and send_button: | |
# Add user message | |
st.session_state.messages.append({"role": "user", "content": user_input}) | |
# Show thinking animation | |
st.session_state.thinking = True | |
# Get prediction | |
predicted_class, confidence = predict(user_input, model, tokenizer) | |
# Generate response | |
bot_response = get_bot_response(user_input, predicted_class, confidence) | |
# Add bot response | |
time.sleep(0.5) # Brief delay for natural feeling | |
st.session_state.messages.append({"role": "assistant", "content": bot_response}) | |
st.session_state.thinking = False | |
# Clear input and rerun | |
st.rerun() | |
# Display chat history | |
with chat_container: | |
display_chat_history() | |
if st.session_state.thinking: | |
st.markdown(""" | |
<div class="thinking-animation"> | |
<div class="dot"></div> | |
<div class="dot"></div> | |
<div class="dot"></div> | |
</div> | |
""", unsafe_allow_html=True) | |
if __name__ == "__main__": | |
main() |