Update app.py
Browse files
app.py
CHANGED
@@ -3,68 +3,94 @@ import json
|
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
|
|
|
6 |
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
with open("./static/styles.css") as f:
|
10 |
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
11 |
-
|
12 |
if 'messages' not in st.session_state:
|
13 |
-
st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI văn bản luật Đấu thầu Việt Nam được phát triển bởi Nguyễn Trường Phúc
|
14 |
|
15 |
st.markdown(f"""
|
16 |
<div class=logo_area>
|
17 |
<img src="./app/static/ai.jpg"/>
|
18 |
</div>
|
19 |
""", unsafe_allow_html=True)
|
20 |
-
st.markdown("<h2 style='text-align: center;'>
|
21 |
-
|
22 |
-
url_api_retrieval_model = st.secrets["ViBidLQA_Retrieval_Module"]
|
23 |
-
url_api_extraction_model = st.secrets["ViBidLQA_EQA_Module"]
|
24 |
-
url_api_generation_model = st.secrets["ViBidLQA_AQA_Module"]
|
25 |
|
26 |
answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
|
27 |
|
28 |
if answering_method == 'Generation':
|
29 |
-
print('
|
30 |
-
print('Loading generative model...')
|
31 |
|
32 |
if answering_method == 'Extraction':
|
33 |
-
print('
|
34 |
-
print('Loading extraction model...')
|
35 |
|
36 |
-
def
|
37 |
data = {
|
38 |
-
"
|
39 |
-
"top_k": top_k
|
40 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
if response.status_code == 200:
|
45 |
-
|
46 |
-
print(f"Văn bản pháp luật được truy hồi: {results[0]['text']}")
|
47 |
-
print("="*100)
|
48 |
-
return results[0]["text"]
|
49 |
else:
|
50 |
return f"Lỗi: {response.status_code} - {response.text}"
|
51 |
-
|
52 |
-
def get_abstractive_answer(question):
|
53 |
-
context = retrieve_context(question=question)
|
54 |
|
|
|
55 |
data = {
|
56 |
-
"context": context,
|
57 |
"question": question
|
58 |
}
|
59 |
|
60 |
-
response = requests.post(
|
|
|
61 |
if response.status_code == 200:
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
else:
|
65 |
return f"Lỗi: {response.status_code} - {response.text}"
|
66 |
|
67 |
-
def
|
68 |
context = retrieve_context(question=question)
|
69 |
|
70 |
data = {
|
@@ -72,9 +98,8 @@ def get_abstractive_answer_stream(question):
|
|
72 |
"question": question
|
73 |
}
|
74 |
|
75 |
-
# Sử dụng requests với stream=True
|
76 |
response = requests.post(url_api_generation_model, json=data, stream=True)
|
77 |
-
|
78 |
if response.status_code == 200:
|
79 |
return response
|
80 |
else:
|
@@ -132,55 +157,210 @@ if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho b
|
|
132 |
st.session_state.messages.append({'role': 'user', 'content': prompt})
|
133 |
|
134 |
message_placeholder = st.empty()
|
135 |
-
|
136 |
|
137 |
full_response = ""
|
138 |
if answering_method == 'Generation':
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
-
if isinstance(response_stream, str):
|
142 |
-
full_response = response_stream
|
143 |
-
message_placeholder.markdown(f"""
|
144 |
-
<div class="assistant-message">
|
145 |
-
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
146 |
-
<div class="stMarkdown">{full_response}</div>
|
147 |
-
</div>
|
148 |
-
""", unsafe_allow_html=True)
|
149 |
else:
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
""", unsafe_allow_html=True)
|
170 |
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
else:
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
<div class="
|
182 |
-
|
183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
message_placeholder.markdown(f"""
|
186 |
<div class="assistant-message">
|
|
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
|
6 |
+
# --- Page configuration ------------------------------------------------------
st.set_page_config(
    page_title="ViBidLQA - Trợ lý AI văn bản pháp luật Việt Nam",
    page_icon="./app/static/ai.jpg",
    layout="centered",
    initial_sidebar_state="expanded",
)

# --- Backend modules (base URLs come from Streamlit secrets) -----------------
routing_response_module = st.secrets["ViBidLQA_Routing_Module"]
retrieval_module = st.secrets["ViBidLQA_Retrieval_Module"]
ext_QA_module = st.secrets["ViBidLQA_EQA_Module"]
abs_QA_module = st.secrets["ViBidLQA_AQA_Module"]

# Concrete endpoints derived from the module base URLs.
url_api_question_classify_model = f"{routing_response_module}/query_classify"
url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
url_api_introduce_system_model = f"{routing_response_module}/about_me"
url_api_retrieval_model = f"{retrieval_module}/search"
url_api_extraction_model = f"{ext_QA_module}/answer"
url_api_generation_model = f"{abs_QA_module}/answer"

# Inject the application stylesheet into the page.
with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Seed the conversation with the assistant's greeting on the first page load.
if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI văn bản luật Đấu thầu Việt Nam được phát triển bởi Nguyễn Trường Phúc và các cộng sự. Rất vui khi được hỗ trợ bạn trong các vấn đề pháp lý tại Việt Nam!"}]

# Logo and centered title.
st.markdown("""
<div class=logo_area>
<img src="./app/static/ai.jpg"/>
</div>
""", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>ViBidLQA</h2>", unsafe_allow_html=True)

# Sidebar: pick the answering pipeline (extractive vs. generative).
answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)

# The two options are mutually exclusive, so an if/elif chain is equivalent
# to the original pair of independent ifs.
if answering_method == 'Generation':
    print('Switched to generative model...')
elif answering_method == 'Extraction':
    print('Switched to extraction model...')
|
|
|
40 |
|
41 |
+
def classify_question(question):
    """Ask the routing module to label the type of the user's question.

    Returns the raw ``requests.Response`` on HTTP 200 (callers read the
    label via ``.json()``); on any other status, returns a Vietnamese
    error string ("Lỗi: <status> - <body>").
    """
    payload = {"question": question}
    response = requests.post(url_api_question_classify_model, json=payload)

    # Guard clause: surface HTTP failures as the file's standard error string.
    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"

    print(response)
    return response
|
53 |
|
54 |
+
def introduce_system(question):
    """Stream a self-introduction reply from the routing module's /about_me.

    Returns the streaming ``requests.Response`` on HTTP 200, otherwise a
    Vietnamese error string ("Lỗi: <status> - <body>").
    """
    payload = {"question": question}
    response = requests.post(url_api_introduce_system_model, json=payload, stream=True)

    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"
    return response
|
|
|
|
|
|
|
65 |
|
66 |
+
def response_unrelated_question(question):
    """Stream a polite reply for questions outside the bidding-law domain.

    Returns the streaming ``requests.Response`` on HTTP 200, otherwise a
    Vietnamese error string ("Lỗi: <status> - <body>").
    """
    payload = {"question": question}
    response = requests.post(url_api_unrelated_question_response_model, json=payload, stream=True)

    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"
    return response
|
77 |
+
|
78 |
+
def retrieve_context(question, top_k=10):
    """Fetch the most relevant legal passage for *question*.

    Posts to the retrieval module's /search endpoint and returns the text
    of the top-ranked passage. On an HTTP error — or when the index returns
    an empty result list — a Vietnamese error string is returned instead,
    matching the error style of the other API helpers in this file.

    Args:
        question: The user's query text.
        top_k: Number of candidate passages requested from the service.
    """
    data = {
        "query": question,
        "top_k": top_k,
    }

    response = requests.post(url_api_retrieval_model, json=data)

    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"

    results = response.json()["results"]
    # Fix: the original did results[0] unconditionally, raising IndexError
    # when the service returns an empty result set.
    if not results:
        return "Lỗi: không tìm thấy văn bản pháp luật phù hợp."

    print(f"Retrieved bidding legal context: {results[0]['text']}")
    return results[0]["text"]
|
92 |
|
93 |
+
def get_abstractive_answer(question):
    """Generate an abstractive answer grounded in retrieved legal context.

    Retrieves the best-matching passage for *question*, then posts
    context + question to the generative QA module with streaming enabled.
    Returns the streaming ``requests.Response`` on HTTP 200, otherwise a
    Vietnamese error string ("Lỗi: <status> - <body>").
    """
    context = retrieve_context(question=question)

    payload = {
        "context": context,
        "question": question,
    }

    response = requests.post(url_api_generation_model, json=payload, stream=True)

    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"
    return response
|
|
157 |
st.session_state.messages.append({'role': 'user', 'content': prompt})
|
158 |
|
159 |
message_placeholder = st.empty()
|
|
|
160 |
|
161 |
full_response = ""
|
162 |
if answering_method == 'Generation':
|
163 |
+
classify_result = classify_question(question=prompt).json()
|
164 |
+
|
165 |
+
print(f"The type of user query: {classify_result}")
|
166 |
+
|
167 |
+
if classify_result == "BIDDING_RELATED":
|
168 |
+
abs_answer = get_abstractive_answer(question=prompt)
|
169 |
+
|
170 |
+
if isinstance(abs_answer, str):
|
171 |
+
full_response = abs_answer
|
172 |
+
message_placeholder.markdown(f"""
|
173 |
+
<div class="assistant-message">
|
174 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
175 |
+
<div class="stMarkdown">{full_response}</div>
|
176 |
+
</div>
|
177 |
+
""", unsafe_allow_html=True)
|
178 |
+
else:
|
179 |
+
full_response = ""
|
180 |
+
for line in abs_answer.iter_lines():
|
181 |
+
if line:
|
182 |
+
line = line.decode('utf-8')
|
183 |
+
if line.startswith('data: '):
|
184 |
+
data_str = line[6:]
|
185 |
+
if data_str == '[DONE]':
|
186 |
+
break
|
187 |
+
|
188 |
+
try:
|
189 |
+
data = json.loads(data_str)
|
190 |
+
token = data.get('token', '')
|
191 |
+
full_response += token
|
192 |
+
|
193 |
+
message_placeholder.markdown(f"""
|
194 |
+
<div class="assistant-message">
|
195 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
196 |
+
<div class="stMarkdown">{full_response}●</div>
|
197 |
+
</div>
|
198 |
+
""", unsafe_allow_html=True)
|
199 |
+
|
200 |
+
except json.JSONDecodeError:
|
201 |
+
pass
|
202 |
+
|
203 |
+
elif classify_result == "ABOUT_CHATBOT":
|
204 |
+
answer = introduce_system(question=prompt)
|
205 |
+
|
206 |
+
if isinstance(answer, str):
|
207 |
+
full_response = answer
|
208 |
+
message_placeholder.markdown(f"""
|
209 |
+
<div class="assistant-message">
|
210 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
211 |
+
<div class="stMarkdown">{full_response}</div>
|
212 |
+
</div>
|
213 |
+
""", unsafe_allow_html=True)
|
214 |
+
else:
|
215 |
+
full_response = ""
|
216 |
+
for line in answer.iter_lines():
|
217 |
+
if line:
|
218 |
+
line = line.decode('utf-8')
|
219 |
+
if line.startswith('data: '):
|
220 |
+
data_str = line[6:]
|
221 |
+
if data_str == '[DONE]':
|
222 |
+
break
|
223 |
+
|
224 |
+
try:
|
225 |
+
data = json.loads(data_str)
|
226 |
+
token = data.get('token', '')
|
227 |
+
full_response += token
|
228 |
+
|
229 |
+
message_placeholder.markdown(f"""
|
230 |
+
<div class="assistant-message">
|
231 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
232 |
+
<div class="stMarkdown">{full_response}●</div>
|
233 |
+
</div>
|
234 |
+
""", unsafe_allow_html=True)
|
235 |
+
|
236 |
+
except json.JSONDecodeError:
|
237 |
+
pass
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
else:
|
240 |
+
answer = response_unrelated_question(question=prompt)
|
241 |
+
|
242 |
+
if isinstance(answer, str):
|
243 |
+
full_response = answer
|
244 |
+
message_placeholder.markdown(f"""
|
245 |
+
<div class="assistant-message">
|
246 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
247 |
+
<div class="stMarkdown">{full_response}</div>
|
248 |
+
</div>
|
249 |
+
""", unsafe_allow_html=True)
|
250 |
+
else:
|
251 |
+
full_response = ""
|
252 |
+
for line in answer.iter_lines():
|
253 |
+
if line:
|
254 |
+
line = line.decode('utf-8')
|
255 |
+
if line.startswith('data: '):
|
256 |
+
data_str = line[6:]
|
257 |
+
if data_str == '[DONE]':
|
258 |
+
break
|
|
|
259 |
|
260 |
+
try:
|
261 |
+
data = json.loads(data_str)
|
262 |
+
token = data.get('token', '')
|
263 |
+
full_response += token
|
264 |
+
|
265 |
+
message_placeholder.markdown(f"""
|
266 |
+
<div class="assistant-message">
|
267 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
268 |
+
<div class="stMarkdown">{full_response}●</div>
|
269 |
+
</div>
|
270 |
+
""", unsafe_allow_html=True)
|
271 |
+
|
272 |
+
except json.JSONDecodeError:
|
273 |
+
pass
|
274 |
+
|
275 |
else:
|
276 |
+
classify_result = classify_question(question=prompt).json()
|
277 |
+
|
278 |
+
print(f"The type of user query: {classify_result}")
|
279 |
+
|
280 |
+
if classify_result == "BIDDING_RELATED":
|
281 |
+
ext_answer = get_extractive_answer(question=prompt)
|
282 |
+
|
283 |
+
for word in generate_text_effect(ext_answer):
|
284 |
+
full_response = word
|
285 |
+
|
286 |
+
message_placeholder.markdown(f"""
|
287 |
+
<div class="assistant-message">
|
288 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
289 |
+
<div class="stMarkdown">{full_response}●</div>
|
290 |
+
</div>
|
291 |
+
""", unsafe_allow_html=True)
|
292 |
+
elif classify_result == "ABOUT_CHATBOT":
|
293 |
+
answer = introduce_system(question=prompt)
|
294 |
+
|
295 |
+
if isinstance(answer, str):
|
296 |
+
full_response = answer
|
297 |
+
message_placeholder.markdown(f"""
|
298 |
+
<div class="assistant-message">
|
299 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
300 |
+
<div class="stMarkdown">{full_response}</div>
|
301 |
+
</div>
|
302 |
+
""", unsafe_allow_html=True)
|
303 |
+
else:
|
304 |
+
full_response = ""
|
305 |
+
for line in answer.iter_lines():
|
306 |
+
if line:
|
307 |
+
line = line.decode('utf-8')
|
308 |
+
if line.startswith('data: '):
|
309 |
+
data_str = line[6:]
|
310 |
+
if data_str == '[DONE]':
|
311 |
+
break
|
312 |
+
|
313 |
+
try:
|
314 |
+
data = json.loads(data_str)
|
315 |
+
token = data.get('token', '')
|
316 |
+
full_response += token
|
317 |
+
|
318 |
+
message_placeholder.markdown(f"""
|
319 |
+
<div class="assistant-message">
|
320 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
321 |
+
<div class="stMarkdown">{full_response}●</div>
|
322 |
+
</div>
|
323 |
+
""", unsafe_allow_html=True)
|
324 |
+
|
325 |
+
except json.JSONDecodeError:
|
326 |
+
pass
|
327 |
+
|
328 |
+
else:
|
329 |
+
answer = response_unrelated_question(question=prompt)
|
330 |
|
331 |
+
if isinstance(answer, str):
|
332 |
+
full_response = answer
|
333 |
+
message_placeholder.markdown(f"""
|
334 |
+
<div class="assistant-message">
|
335 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
336 |
+
<div class="stMarkdown">{full_response}</div>
|
337 |
+
</div>
|
338 |
+
""", unsafe_allow_html=True)
|
339 |
+
else:
|
340 |
+
full_response = ""
|
341 |
+
for line in answer.iter_lines():
|
342 |
+
if line:
|
343 |
+
line = line.decode('utf-8')
|
344 |
+
if line.startswith('data: '):
|
345 |
+
data_str = line[6:]
|
346 |
+
if data_str == '[DONE]':
|
347 |
+
break
|
348 |
+
|
349 |
+
try:
|
350 |
+
data = json.loads(data_str)
|
351 |
+
token = data.get('token', '')
|
352 |
+
full_response += token
|
353 |
+
|
354 |
+
message_placeholder.markdown(f"""
|
355 |
+
<div class="assistant-message">
|
356 |
+
<img src="./app/static/ai.jpg" class="assistant-avatar" />
|
357 |
+
<div class="stMarkdown">{full_response}●</div>
|
358 |
+
</div>
|
359 |
+
""", unsafe_allow_html=True)
|
360 |
+
|
361 |
+
except json.JSONDecodeError:
|
362 |
+
pass
|
363 |
+
|
364 |
|
365 |
message_placeholder.markdown(f"""
|
366 |
<div class="assistant-message">
|