File size: 7,160 Bytes
a4c3bcc
 
 
 
2bc9c40
a4c3bcc
2bc9c40
a4c3bcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bc9c40
a4c3bcc
2bc9c40
a4c3bcc
2bc9c40
 
a4c3bcc
2bc9c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
2bc9c40
 
 
 
 
 
a4c3bcc
2bc9c40
 
 
 
 
 
 
 
 
 
 
a4c3bcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bc9c40
a4c3bcc
 
 
 
 
 
 
2bc9c40
 
a4c3bcc
 
 
 
2bc9c40
a4c3bcc
 
 
2bc9c40
a4c3bcc
2bc9c40
a4c3bcc
 
2bc9c40
a4c3bcc
 
 
 
 
2bc9c40
 
a4c3bcc
2bc9c40
 
a4c3bcc
2bc9c40
 
a4c3bcc
 
 
2bc9c40
a4c3bcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import streamlit as st
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import time
import random

# [Previous CSS styles remain the same]
def local_css():
    st.markdown("""
        <style>
        .chat-container {
            padding: 10px;
            border-radius: 5px;
            margin-bottom: 10px;
            display: flex;
            flex-direction: column;
        }
        
        .user-message {
            background-color: #e3f2fd;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-left: 20%;
            margin-right: 5px;
            align-self: flex-end;
            max-width: 70%;
        }
        
        .bot-message {
            background-color: #f5f5f5;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-right: 20%;
            margin-left: 5px;
            align-self: flex-start;
            max-width: 70%;
        }
        
        .chat-input {
            position: fixed;
            bottom: 0;
            width: 100%;
            padding: 20px;
            background-color: white;
        }
        
        .thinking-animation {
            display: flex;
            align-items: center;
            margin-left: 10px;
        }
        
        .dot {
            width: 8px;
            height: 8px;
            margin: 0 3px;
            background: #888;
            border-radius: 50%;
            animation: bounce 0.8s infinite;
        }
        
        .dot:nth-child(2) { animation-delay: 0.2s; }
        .dot:nth-child(3) { animation-delay: 0.4s; }
        
        @keyframes bounce {
            0%, 100% { transform: translateY(0); }
            50% { transform: translateY(-5px); }
        }
        </style>
    """, unsafe_allow_html=True)

@st.cache_resource
def load_model():
    model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
    tokenizer = BertTokenizer.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
    return model, tokenizer

def predict(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(predictions, dim=1).item()
        confidence = predictions[0][predicted_class].item()
    
    return predicted_class, confidence

def get_bot_response(text, predicted_class, confidence):
    # Define response templates based on classes and confidence levels
    responses = {
        0: {  # Example for class 0 (positive sentiment)
            'high_conf': [
                "Tôi cảm nhận được sự tích cực trong câu nói của bạn. Xin chia sẻ thêm nhé!",
                "Thật vui khi nghe điều đó. Bạn có thể kể thêm không?",
                "Tuyệt vời! Tôi rất đồng ý với bạn về điều này."
            ],
            'low_conf': [
                "Có vẻ như đây là điều tích cực. Đúng không nhỉ?",
                "Tôi nghĩ đây là một góc nhìn thú vị đấy.",
                "Nghe có vẻ tốt đấy, bạn nghĩ sao?"
            ]
        },
        1: {  # Example for class 1 (negative sentiment)
            'high_conf': [
                "Tôi hiểu đây là điều khó khăn với bạn. Hãy chia sẻ thêm nhé.",
                "Tôi rất tiếc khi nghe điều này. Bạn cần tôi giúp gì không?",
                "Đúng là một tình huống khó khăn. Chúng ta cùng tìm giải pháp nhé."
            ],
            'low_conf': [
                "Có vẻ như bạn đang gặp khó khăn. Tôi có hiểu đúng không?",
                "Tôi không chắc mình hiểu hết, bạn có thể giải thích thêm được không?",
                "Hãy chia sẻ thêm để tôi có thể hiểu rõ hơn nhé."
            ]
        }
    }

    # Add more classes based on your model's output
    
    # Determine confidence level
    confidence_threshold = 0.8
    conf_level = 'high_conf' if confidence > confidence_threshold else 'low_conf'
    
    # Get appropriate response list
    try:
        response_list = responses[predicted_class][conf_level]
        response = random.choice(response_list)
    except KeyError:
        response = "Xin lỗi, tôi không chắc chắn về điều này. Bạn có thể giải thích rõ hơn được không?"

    # Add context from user's input
    context_response = f"{response}"
    
    return context_response

def init_session_state():
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    if 'thinking' not in st.session_state:
        st.session_state.thinking = False

def display_chat_history():
    for message in st.session_state.messages:
        if message['role'] == 'user':
            st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)

def main():
    st.set_page_config(page_title="Vietnamese Chatbot", page_icon="🤖", layout="wide")
    local_css()
    init_session_state()
    
    # Load model
    model, tokenizer = load_model()
    
    # Chat interface
    st.title("Chatbot Tiếng Việt 🤖")
    st.markdown("Xin chào! Tôi có thể giúp gì cho bạn?")
    
    # Chat history container
    chat_container = st.container()
    
    # Input container
    with st.container():
        col1, col2 = st.columns([6, 1])
        with col1:
            user_input = st.text_input("Nhập tin nhắn của bạn...", key="user_input", label_visibility="hidden")
        with col2:
            send_button = st.button("Gửi")
    
    if user_input and send_button:
        # Add user message
        st.session_state.messages.append({"role": "user", "content": user_input})
        
        # Show thinking animation
        st.session_state.thinking = True
        
        # Get prediction
        predicted_class, confidence = predict(user_input, model, tokenizer)
        
        # Generate response
        bot_response = get_bot_response(user_input, predicted_class, confidence)
        
        # Add bot response
        time.sleep(0.5)  # Brief delay for natural feeling
        st.session_state.messages.append({"role": "assistant", "content": bot_response})
        st.session_state.thinking = False
        
        # Clear input and rerun
        st.rerun()
    
    # Display chat history
    with chat_container:
        display_chat_history()
        
        if st.session_state.thinking:
            st.markdown("""
                <div class="thinking-animation">
                    <div class="dot"></div>
                    <div class="dot"></div>
                    <div class="dot"></div>
                </div>
            """, unsafe_allow_html=True)

if __name__ == "__main__":
    main()