ntphuc149 commited on
Commit
c4d7d0b
·
verified ·
1 Parent(s): 1d4d4a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -71
app.py CHANGED
@@ -3,68 +3,94 @@ import json
3
  import requests
4
  import streamlit as st
5
 
 
6
 
7
- st.set_page_config(page_title="ViBidLQA - Trợ lý AI hỗ trợ hỏi đáp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="expanded")
 
 
 
 
 
 
 
 
 
 
8
 
9
  with open("./static/styles.css") as f:
10
  st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
11
-
12
  if 'messages' not in st.session_state:
13
- st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI văn bản luật Đấu thầu Việt Nam được phát triển bởi Nguyễn Trường Phúc. Rất vui khi được hỗ trợ bạn trong các vấn đề pháp lý tại Việt Nam!"}]
14
 
15
  st.markdown(f"""
16
  <div class=logo_area>
17
  <img src="./app/static/ai.jpg"/>
18
  </div>
19
  """, unsafe_allow_html=True)
20
- st.markdown("<h2 style='text-align: center;'>The ViBidLQA System </h2>", unsafe_allow_html=True)
21
-
22
- url_api_retrieval_model = st.secrets["ViBidLQA_Retrieval_Module"]
23
- url_api_extraction_model = st.secrets["ViBidLQA_EQA_Module"]
24
- url_api_generation_model = st.secrets["ViBidLQA_AQA_Module"]
25
 
26
  answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
27
 
28
  if answering_method == 'Generation':
29
- print('Switching to generative model...')
30
- print('Loading generative model...')
31
 
32
  if answering_method == 'Extraction':
33
- print('Switching to extraction model...')
34
- print('Loading extraction model...')
35
 
36
- def retrieve_context(question, top_k=10):
37
  data = {
38
- "query": question,
39
- "top_k": top_k
40
  }
 
 
 
 
 
 
 
 
41
 
42
- response = requests.post(url_api_retrieval_model, json=data)
 
 
 
 
 
43
 
44
  if response.status_code == 200:
45
- results = response.json()["results"]
46
- print(f"Văn bản pháp luật được truy hồi: {results[0]['text']}")
47
- print("="*100)
48
- return results[0]["text"]
49
  else:
50
  return f"Lỗi: {response.status_code} - {response.text}"
51
-
52
- def get_abstractive_answer(question):
53
- context = retrieve_context(question=question)
54
 
 
55
  data = {
56
- "context": context,
57
  "question": question
58
  }
59
 
60
- response = requests.post(url_api_generation_model, json=data)
 
61
  if response.status_code == 200:
62
- result = response.json()
63
- return result["answer"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  else:
65
  return f"Lỗi: {response.status_code} - {response.text}"
66
 
67
- def get_abstractive_answer_stream(question):
68
  context = retrieve_context(question=question)
69
 
70
  data = {
@@ -72,9 +98,8 @@ def get_abstractive_answer_stream(question):
72
  "question": question
73
  }
74
 
75
- # Sử dụng requests với stream=True
76
  response = requests.post(url_api_generation_model, json=data, stream=True)
77
-
78
  if response.status_code == 200:
79
  return response
80
  else:
@@ -132,55 +157,210 @@ if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho b
132
  st.session_state.messages.append({'role': 'user', 'content': prompt})
133
 
134
  message_placeholder = st.empty()
135
-
136
 
137
  full_response = ""
138
  if answering_method == 'Generation':
139
- response_stream = get_abstractive_answer_stream(question=prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- if isinstance(response_stream, str):
142
- full_response = response_stream
143
- message_placeholder.markdown(f"""
144
- <div class="assistant-message">
145
- <img src="./app/static/ai.jpg" class="assistant-avatar" />
146
- <div class="stMarkdown">{full_response}</div>
147
- </div>
148
- """, unsafe_allow_html=True)
149
  else:
150
- full_response = ""
151
- for line in response_stream.iter_lines():
152
- if line:
153
- line = line.decode('utf-8')
154
- if line.startswith('data: '):
155
- data_str = line[6:]
156
- if data_str == '[DONE]':
157
- break
158
-
159
- try:
160
- data = json.loads(data_str)
161
- token = data.get('token', '')
162
- full_response += token
163
-
164
- message_placeholder.markdown(f"""
165
- <div class="assistant-message">
166
- <img src="./app/static/ai.jpg" class="assistant-avatar" />
167
- <div class="stMarkdown">{full_response}●</div>
168
- </div>
169
- """, unsafe_allow_html=True)
170
 
171
- except json.JSONDecodeError:
172
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  else:
174
- ext_answer = get_extractive_answer(question=prompt)
175
- for word in generate_text_effect(ext_answer):
176
- full_response = word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- message_placeholder.markdown(f"""
179
- <div class="assistant-message">
180
- <img src="./app/static/ai.jpg" class="assistant-avatar" />
181
- <div class="stMarkdown">{full_response}●</div>
182
- </div>
183
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  message_placeholder.markdown(f"""
186
  <div class="assistant-message">
 
3
  import requests
4
  import streamlit as st
5
 
6
+ st.set_page_config(page_title="ViBidLQA - Trợ lý AI văn bản pháp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="expanded")
7
 
8
+ routing_response_module = st.secrets["ViBidLQA_Routing_Module"]
9
+ retrieval_module = st.secrets["ViBidLQA_Retrieval_Module"]
10
+ ext_QA_module = st.secrets["ViBidLQA_EQA_Module"]
11
+ abs_QA_module = st.secrets["ViBidLQA_AQA_Module"]
12
+
13
+ url_api_question_classify_model = f"{routing_response_module}/query_classify"
14
+ url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
15
+ url_api_introduce_system_model = f"{routing_response_module}/about_me"
16
+ url_api_retrieval_model = f"{retrieval_module}/search"
17
+ url_api_extraction_model = f"{ext_QA_module}/answer"
18
+ url_api_generation_model = f"{abs_QA_module}/answer"
19
 
20
  with open("./static/styles.css") as f:
21
  st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
22
+
23
  if 'messages' not in st.session_state:
24
+ st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI văn bản luật Đấu thầu Việt Nam được phát triển bởi Nguyễn Trường Phúc và các cộng sự. Rất vui khi được hỗ trợ bạn trong các vấn đề pháp lý tại Việt Nam!"}]
25
 
26
  st.markdown(f"""
27
  <div class=logo_area>
28
  <img src="./app/static/ai.jpg"/>
29
  </div>
30
  """, unsafe_allow_html=True)
31
+ st.markdown("<h2 style='text-align: center;'>ViBidLQA</h2>", unsafe_allow_html=True)
 
 
 
 
32
 
33
  answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
34
 
35
  if answering_method == 'Generation':
36
+ print('Switched to generative model...')
 
37
 
38
  if answering_method == 'Extraction':
39
+ print('Switched to extraction model...')
 
40
 
41
+ def classify_question(question):
42
  data = {
43
+ "question": question
 
44
  }
45
+
46
+ response = requests.post(url_api_question_classify_model, json=data)
47
+
48
+ if response.status_code == 200:
49
+ print(response)
50
+ return response
51
+ else:
52
+ return f"Lỗi: {response.status_code} - {response.text}"
53
 
54
+ def introduce_system(question):
55
+ data = {
56
+ "question": question
57
+ }
58
+
59
+ response = requests.post(url_api_introduce_system_model, json=data, stream=True)
60
 
61
  if response.status_code == 200:
62
+ return response
 
 
 
63
  else:
64
  return f"Lỗi: {response.status_code} - {response.text}"
 
 
 
65
 
66
+ def response_unrelated_question(question):
67
  data = {
 
68
  "question": question
69
  }
70
 
71
+ response = requests.post(url_api_unrelated_question_response_model, json=data, stream=True)
72
+
73
  if response.status_code == 200:
74
+ return response
75
+ else:
76
+ return f"Lỗi: {response.status_code} - {response.text}"
77
+
78
+ def retrieve_context(question, top_k=10):
79
+ data = {
80
+ "query": question,
81
+ "top_k": top_k
82
+ }
83
+
84
+ response = requests.post(url_api_retrieval_model, json=data)
85
+
86
+ if response.status_code == 200:
87
+ results = response.json()["results"]
88
+ print(f"Retrieved bidding legal context: {results[0]['text']}")
89
+ return results[0]["text"]
90
  else:
91
  return f"Lỗi: {response.status_code} - {response.text}"
92
 
93
+ def get_abstractive_answer(question):
94
  context = retrieve_context(question=question)
95
 
96
  data = {
 
98
  "question": question
99
  }
100
 
 
101
  response = requests.post(url_api_generation_model, json=data, stream=True)
102
+
103
  if response.status_code == 200:
104
  return response
105
  else:
 
157
  st.session_state.messages.append({'role': 'user', 'content': prompt})
158
 
159
  message_placeholder = st.empty()
 
160
 
161
  full_response = ""
162
  if answering_method == 'Generation':
163
+ classify_result = classify_question(question=prompt).json()
164
+
165
+ print(f"The type of user query: {classify_result}")
166
+
167
+ if classify_result == "BIDDING_RELATED":
168
+ abs_answer = get_abstractive_answer(question=prompt)
169
+
170
+ if isinstance(abs_answer, str):
171
+ full_response = abs_answer
172
+ message_placeholder.markdown(f"""
173
+ <div class="assistant-message">
174
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
175
+ <div class="stMarkdown">{full_response}</div>
176
+ </div>
177
+ """, unsafe_allow_html=True)
178
+ else:
179
+ full_response = ""
180
+ for line in abs_answer.iter_lines():
181
+ if line:
182
+ line = line.decode('utf-8')
183
+ if line.startswith('data: '):
184
+ data_str = line[6:]
185
+ if data_str == '[DONE]':
186
+ break
187
+
188
+ try:
189
+ data = json.loads(data_str)
190
+ token = data.get('token', '')
191
+ full_response += token
192
+
193
+ message_placeholder.markdown(f"""
194
+ <div class="assistant-message">
195
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
196
+ <div class="stMarkdown">{full_response}●</div>
197
+ </div>
198
+ """, unsafe_allow_html=True)
199
+
200
+ except json.JSONDecodeError:
201
+ pass
202
+
203
+ elif classify_result == "ABOUT_CHATBOT":
204
+ answer = introduce_system(question=prompt)
205
+
206
+ if isinstance(answer, str):
207
+ full_response = answer
208
+ message_placeholder.markdown(f"""
209
+ <div class="assistant-message">
210
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
211
+ <div class="stMarkdown">{full_response}</div>
212
+ </div>
213
+ """, unsafe_allow_html=True)
214
+ else:
215
+ full_response = ""
216
+ for line in answer.iter_lines():
217
+ if line:
218
+ line = line.decode('utf-8')
219
+ if line.startswith('data: '):
220
+ data_str = line[6:]
221
+ if data_str == '[DONE]':
222
+ break
223
+
224
+ try:
225
+ data = json.loads(data_str)
226
+ token = data.get('token', '')
227
+ full_response += token
228
+
229
+ message_placeholder.markdown(f"""
230
+ <div class="assistant-message">
231
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
232
+ <div class="stMarkdown">{full_response}●</div>
233
+ </div>
234
+ """, unsafe_allow_html=True)
235
+
236
+ except json.JSONDecodeError:
237
+ pass
238
 
 
 
 
 
 
 
 
 
239
  else:
240
+ answer = response_unrelated_question(question=prompt)
241
+
242
+ if isinstance(answer, str):
243
+ full_response = answer
244
+ message_placeholder.markdown(f"""
245
+ <div class="assistant-message">
246
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
247
+ <div class="stMarkdown">{full_response}</div>
248
+ </div>
249
+ """, unsafe_allow_html=True)
250
+ else:
251
+ full_response = ""
252
+ for line in answer.iter_lines():
253
+ if line:
254
+ line = line.decode('utf-8')
255
+ if line.startswith('data: '):
256
+ data_str = line[6:]
257
+ if data_str == '[DONE]':
258
+ break
 
259
 
260
+ try:
261
+ data = json.loads(data_str)
262
+ token = data.get('token', '')
263
+ full_response += token
264
+
265
+ message_placeholder.markdown(f"""
266
+ <div class="assistant-message">
267
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
268
+ <div class="stMarkdown">{full_response}●</div>
269
+ </div>
270
+ """, unsafe_allow_html=True)
271
+
272
+ except json.JSONDecodeError:
273
+ pass
274
+
275
  else:
276
+ classify_result = classify_question(question=prompt).json()
277
+
278
+ print(f"The type of user query: {classify_result}")
279
+
280
+ if classify_result == "BIDDING_RELATED":
281
+ ext_answer = get_extractive_answer(question=prompt)
282
+
283
+ for word in generate_text_effect(ext_answer):
284
+ full_response = word
285
+
286
+ message_placeholder.markdown(f"""
287
+ <div class="assistant-message">
288
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
289
+ <div class="stMarkdown">{full_response}●</div>
290
+ </div>
291
+ """, unsafe_allow_html=True)
292
+ elif classify_result == "ABOUT_CHATBOT":
293
+ answer = introduce_system(question=prompt)
294
+
295
+ if isinstance(answer, str):
296
+ full_response = answer
297
+ message_placeholder.markdown(f"""
298
+ <div class="assistant-message">
299
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
300
+ <div class="stMarkdown">{full_response}</div>
301
+ </div>
302
+ """, unsafe_allow_html=True)
303
+ else:
304
+ full_response = ""
305
+ for line in answer.iter_lines():
306
+ if line:
307
+ line = line.decode('utf-8')
308
+ if line.startswith('data: '):
309
+ data_str = line[6:]
310
+ if data_str == '[DONE]':
311
+ break
312
+
313
+ try:
314
+ data = json.loads(data_str)
315
+ token = data.get('token', '')
316
+ full_response += token
317
+
318
+ message_placeholder.markdown(f"""
319
+ <div class="assistant-message">
320
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
321
+ <div class="stMarkdown">{full_response}●</div>
322
+ </div>
323
+ """, unsafe_allow_html=True)
324
+
325
+ except json.JSONDecodeError:
326
+ pass
327
+
328
+ else:
329
+ answer = response_unrelated_question(question=prompt)
330
 
331
+ if isinstance(answer, str):
332
+ full_response = answer
333
+ message_placeholder.markdown(f"""
334
+ <div class="assistant-message">
335
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
336
+ <div class="stMarkdown">{full_response}</div>
337
+ </div>
338
+ """, unsafe_allow_html=True)
339
+ else:
340
+ full_response = ""
341
+ for line in answer.iter_lines():
342
+ if line:
343
+ line = line.decode('utf-8')
344
+ if line.startswith('data: '):
345
+ data_str = line[6:]
346
+ if data_str == '[DONE]':
347
+ break
348
+
349
+ try:
350
+ data = json.loads(data_str)
351
+ token = data.get('token', '')
352
+ full_response += token
353
+
354
+ message_placeholder.markdown(f"""
355
+ <div class="assistant-message">
356
+ <img src="./app/static/ai.jpg" class="assistant-avatar" />
357
+ <div class="stMarkdown">{full_response}●</div>
358
+ </div>
359
+ """, unsafe_allow_html=True)
360
+
361
+ except json.JSONDecodeError:
362
+ pass
363
+
364
 
365
  message_placeholder.markdown(f"""
366
  <div class="assistant-message">