JustKiddo committed
Commit dc14176 · verified · 1 Parent(s): 09bbd11

Update app.py

Files changed (1):
  1. app.py +125 -161

app.py CHANGED
@@ -1,181 +1,145 @@
+import os
 import streamlit as st
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
-import time
-
-# Custom CSS for the chat interface
-def local_css():
-    st.markdown("""
-    <style>
-    .chat-container {
-        padding: 10px;
-        border-radius: 5px;
-        margin-bottom: 10px;
-        display: flex;
-        flex-direction: column;
-    }
-
-    .user-message {
-        background-color: #e3f2fd;
-        padding: 10px;
-        border-radius: 15px;
-        margin: 5px;
-        margin-left: 20%;
-        margin-right: 5px;
-        align-self: flex-end;
-        max-width: 70%;
-    }
-
-    .bot-message {
-        background-color: #f5f5f5;
-        padding: 10px;
-        border-radius: 15px;
-        margin: 5px;
-        margin-right: 20%;
-        margin-left: 5px;
-        align-self: flex-start;
-        max-width: 70%;
-    }
-
-    .thinking-animation {
-        display: flex;
-        align-items: center;
-        margin-left: 10px;
-    }
-
-    .dot {
-        width: 8px;
-        height: 8px;
-        margin: 0 3px;
-        background: #888;
-        border-radius: 50%;
-        animation: bounce 0.8s infinite;
-    }
-
-    .dot:nth-child(2) { animation-delay: 0.2s; }
-    .dot:nth-child(3) { animation-delay: 0.4s; }
-
-    @keyframes bounce {
-        0%, 100% { transform: translateY(0); }
-        50% { transform: translateY(-5px); }
-    }
-    </style>
-    """, unsafe_allow_html=True)
-
-# Load model and tokenizer
-@st.cache_resource
-def load_model():
-    # Using VietAI's Vietnamese GPT model
-    model_name = "google-t5/t5-base"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-    return model, tokenizer
-
-def generate_response(prompt, model, tokenizer, max_length=4169):
-    # Prepare input
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-
-    # Generate response
-    with torch.no_grad():
-        outputs = model.generate(
-            inputs.input_ids,
-            max_length=max_length,
-            num_return_sequences=1,
-            temperature=0.7,
-            top_k=50,
-            top_p=0.95,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id,
-            attention_mask=inputs.attention_mask
-        )
-
-    # Decode response
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    # Remove the input prompt from the response
-    response = response[len(prompt):].strip()
-    return response
-
-def init_session_state():
-    if 'messages' not in st.session_state:
-        st.session_state.messages = []
-    if 'thinking' not in st.session_state:
-        st.session_state.thinking = False
-
-def display_chat_history():
-    for message in st.session_state.messages:
-        if message['role'] == 'user':
-            st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
-        else:
-            st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
-
-def main():
-    st.set_page_config(
-        page_title="IOGPT",
-        page_icon="🤖",
-        layout="wide"
-    )
-
-    local_css()
-    init_session_state()
-
-    # Load model
-    model, tokenizer = load_model()
-
-    # Chat interface
-    st.title("IOGPT 🤖")
-    st.markdown("Xin chào! Tôi là trợ lý IOGPT. Hãy hỏi tôi bất cứ điều gì!")
-
-    # Chat history container
-    chat_container = st.container()
-
-    # Input container
-    with st.container():
-        col1, col2 = st.columns([6, 1])
-        with col1:
-            user_input = st.text_input(
-                "Nhập tin nhắn của bạn...",
-                key="user_input",
-                label_visibility="hidden"
-            )
-        with col2:
-            send_button = st.button("Gửi")
-
-    if user_input and send_button:
-        # Add user message
-        st.session_state.messages.append({"role": "user", "content": user_input})
-
-        # Show thinking animation
-        st.session_state.thinking = True
-
-        # Prepare conversation history
-        conversation_history = "\n".join([
-            f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
-            for msg in st.session_state.messages[-3:]  # Last 3 messages for context
-        ])
-
-        # Generate response
-        prompt = f"{conversation_history}\nAssistant:"
-        bot_response = generate_response(prompt, model, tokenizer)
-
-        # Add bot response
-        st.session_state.messages.append({"role": "assistant", "content": bot_response})
-        st.session_state.thinking = False
-
-        # Clear input and rerun
-        st.rerun()
-
-    # Display chat history
-    with chat_container:
-        display_chat_history()
-
-        if st.session_state.thinking:
-            st.markdown("""
-            <div class="thinking-animation">
-                <div class="dot"></div>
-                <div class="dot"></div>
-                <div class="dot"></div>
-            </div>
-            """, unsafe_allow_html=True)
+from transformers import AutoTokenizer, AutoModel
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
+# Get the port from Heroku environment, default to 8501 for local development
+PORT = int(os.environ.get('PORT', 8501))
+
+class LazyLoadModel:
+    def __init__(self, model_name='intfloat/multilingual-e5-small'):
+        self.model_name = model_name
+        self._tokenizer = None
+        self._model = None
+
+    @property
+    def tokenizer(self):
+        if self._tokenizer is None:
+            print("Loading tokenizer...")
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        return self._tokenizer
+
+    @property
+    def model(self):
+        if self._model is None:
+            print("Loading model...")
+            # Use float16 to reduce memory and potentially speed up loading
+            self._model = AutoModel.from_pretrained(self.model_name, torch_dtype=torch.float16)
+        return self._model
+
+class VietnameseChatbot:
+    def __init__(self):
+        """
+        Initialize the Vietnamese chatbot with lazy-loaded model
+        """
+        self.model_loader = LazyLoadModel()
+
+        # Very minimal conversation data to reduce startup time
+        self.conversation_data = [
+            {"query": "Xin chào", "response": "Chào bạn!"},
+            {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI."},
+        ]
+
+    def embed_text(self, text):
+        """
+        Generate embeddings for input text
+        """
+        try:
+            # Tokenize and generate embeddings
+            inputs = self.model_loader.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+
+            with torch.no_grad():
+                model_output = self.model_loader.model(**inputs)
+
+            # Mean pooling
+            embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
+            return embeddings.numpy()
+        except Exception as e:
+            print(f"Embedding error: {e}")
+            return None
+
+    def mean_pooling(self, model_output, attention_mask):
+        """
+        Perform mean pooling on model output
+        """
+        token_embeddings = model_output[0]
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    def get_response(self, user_query):
+        """
+        Find the most similar response from conversation data
+        """
+        try:
+            # Embed user query
+            query_embedding = self.embed_text(user_query)
+
+            if query_embedding is None:
+                return "Xin lỗi, đã có lỗi xảy ra."
+
+            # Embed conversation data
+            conversation_embeddings = np.array([
+                self.embed_text(item['query'])[0] for item in self.conversation_data
+            ])
+
+            # Calculate cosine similarities
+            similarities = cosine_similarity(query_embedding, conversation_embeddings)[0]
+
+            # Find most similar response
+            best_match_index = np.argmax(similarities)
+
+            # Return response if similarity is above threshold
+            if similarities[best_match_index] > 0.5:
+                return self.conversation_data[best_match_index]['response']
+
+            return "Xin lỗi, tôi không hiểu câu hỏi của bạn."
+        except Exception as e:
+            print(f"Response generation error: {e}")
+            return "Đã xảy ra lỗi. Xin vui lòng thử lại."
+
+def main():
+    # Server configuration to use Heroku-assigned port
+    if 'PORT' in os.environ:
+        #st.set_option('server.port', PORT)
+        print(f"Server starting on port {PORT}")
+
+    st.title("🤖 Trợ Lý AI Tiếng Việt")
+
+    # Initialize chatbot
+    chatbot = VietnameseChatbot()
+
+    # Chat history in session state
+    if 'messages' not in st.session_state:
+        st.session_state.messages = []
+
+    # Display chat messages
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # User input
+    if prompt := st.chat_input("Hãy nói gì đó..."):
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+
+        # Display user message
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        # Get chatbot response
+        response = chatbot.get_response(prompt)
+
+        # Display chatbot response
+        with st.chat_message("assistant"):
+            st.markdown(response)
+
+        # Add assistant message to chat history
+        st.session_state.messages.append({"role": "assistant", "content": response})
+
+# Logging for Heroku diagnostics
+print("Chatbot application is initializing...")

 if __name__ == "__main__":
     main()
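
The retrieval path in the new app.py comes down to two steps: masked mean pooling over the model's token embeddings, then a cosine-similarity lookup against the embedded canned queries. A minimal standalone sketch of those two steps follows; it substitutes random tensors for intfloat/multilingual-e5-small output, so the shapes and the 0.5 threshold mirror the code above while the tensors themselves are purely illustrative.

    import torch
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    def mean_pooling(token_embeddings, attention_mask):
        # Expand the mask over the hidden dimension, zero out padding tokens,
        # then average over the sequence axis (the same math as mean_pooling above).
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    # Dummy batch: one sequence of 4 tokens (the last is padding), hidden size 8.
    token_embeddings = torch.randn(1, 4, 8)
    attention_mask = torch.tensor([[1, 1, 1, 0]])
    query_vec = mean_pooling(token_embeddings, attention_mask).numpy()

    # Stand-ins for embed_text() applied to the two canned queries.
    corpus_vecs = np.random.randn(2, 8)
    sims = cosine_similarity(query_vec, corpus_vecs)[0]
    best = int(np.argmax(sims))
    print(f"best match: {best}, similarity: {sims[best]:.3f}")
    # get_response() returns the canned answer only when sims[best] > 0.5;
    # otherwise it falls back to "Xin lỗi, tôi không hiểu câu hỏi của bạn."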