ntphuc149 committed
Commit 634ebde · verified · 1 Parent(s): c4d7d0b

Update app.py

Files changed (1):
  1. app.py (+135 -234)
app.py CHANGED
@@ -3,18 +3,19 @@ import json
 import requests
 import streamlit as st
 
-st.set_page_config(page_title="ViBidLQA - Trợ lý AI văn bản pháp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="expanded")
+st.set_page_config(page_title="ViBidLQA - Trợ lý AI hỗ trợ hỏi đáp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="collapsed")
 
 routing_response_module = st.secrets["ViBidLQA_Routing_Module"]
 retrieval_module = st.secrets["ViBidLQA_Retrieval_Module"]
 ext_QA_module = st.secrets["ViBidLQA_EQA_Module"]
 abs_QA_module = st.secrets["ViBidLQA_AQA_Module"]
 
+
 url_api_question_classify_model = f"{routing_response_module}/query_classify"
 url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
-url_api_introduce_system_model = f"{routing_response_module}/about_me"
+url_api_introduce_system_modelff = f"{routing_response_module}/about_me"
 url_api_retrieval_model = f"{retrieval_module}/search"
-url_api_extraction_model = f"{ext_QA_module}/answer"
+url_api_reranker_model = f"{reranker_module}/rerank"
 url_api_generation_model = f"{abs_QA_module}/answer"
 
 with open("./static/styles.css") as f:
@@ -30,14 +31,6 @@ st.markdown(f"""
 """, unsafe_allow_html=True)
 st.markdown("<h2 style='text-align: center;'>ViBidLQA</h2>", unsafe_allow_html=True)
 
-answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
-
-if answering_method == 'Generation':
-    print('Switched to generative model...')
-
-if answering_method == 'Extraction':
-    print('Switched to extraction model...')
-
 def classify_question(question):
     data = {
         "question": question
@@ -85,16 +78,36 @@ def retrieve_context(question, top_k=10):
 
     if response.status_code == 200:
         results = response.json()["results"]
-        print(f"Retrieved bidding legal context: {results[0]['text']}")
-        return results[0]["text"]
+        return results
     else:
-        return f"Lỗi: {response.status_code} - {response.text}"
+        return f"Lỗi tại Retrieval Module: {response.status_code} - {response.text}"
+
+def rerank_context(url_rerank_module, question, relevant_docs, top_k=5):
+    data = {
+        "question": question,
+        "relevant_docs": relevant_docs,
+        "top_k": top_k
+    }
+
+    response = requests.post(url_rerank_module, json=data)
+
+    if response.status_code == 200:
+        results = response.json()["reranked_docs"]
+        return results
+    else:
+        return f"Lỗi tại Rerank module: {response.status_code} - {response.text}"
 
 def get_abstractive_answer(question):
-    context = retrieve_context(question=question)
+    retrieved_context = retrieve_context(question=question)
+    retrieved_context = [item['text'] for item in retrieved_context]
+
+    reranked_context = rerank_context(url_rerank_module=url_api_reranker_model,
+                                      question=question,
+                                      relevant_docs=retrieved_context,
+                                      top_k=5)[0]
 
     data = {
-        "context": context,
+        "context": reranked_context,
         "question": question
     }
 
@@ -111,35 +124,15 @@ def generate_text_effect(answer):
         time.sleep(0.03)
         yield " ".join(words[:i+1])
 
-def get_extractive_answer(question, stride=20, max_length=256, n_best=50, max_answer_length=512):
-    context = retrieve_context(question=question)
-
-    data = {
-        "context": context,
-        "question": question,
-        "stride": stride,
-        "max_length": max_length,
-        "n_best": n_best,
-        "max_answer_length": max_answer_length
-    }
-
-    response = requests.post(url_api_extraction_model, json=data)
-
-    if response.status_code == 200:
-        result = response.json()
-        return result["best_answer"]
-    else:
-        return f"Lỗi: {response.status_code} - {response.text}"
-
 for message in st.session_state.messages:
     if message['role'] == 'assistant':
         avatar_class = "assistant-avatar"
         message_class = "assistant-message"
         avatar = './app/static/ai.jpg'
     else:
-        avatar_class = "user-avatar"
+        avatar_class = ""
         message_class = "user-message"
-        avatar = './app/static/human.png'
+        avatar = ''
     st.markdown(f"""
         <div class="{message_class}">
             <img src="{avatar}" class="{avatar_class}" />
@@ -150,8 +143,7 @@ for message in st.session_state.messages:
 if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
     st.markdown(f"""
         <div class="user-message">
-            <img src="./app/static/human.png" class="user-avatar" />
-            <div class="stMarkdown">{prompt}</div>
+            <div class="stMarkdown">{prompt}</div>
         </div>
     """, unsafe_allow_html=True)
     st.session_state.messages.append({'role': 'user', 'content': prompt})
@@ -159,208 +151,117 @@ if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
     message_placeholder = st.empty()
 
     full_response = ""
-    if answering_method == 'Generation':
-        classify_result = classify_question(question=prompt).json()
-
-        print(f"The type of user query: {classify_result}")
-
-        if classify_result == "BIDDING_RELATED":
-            abs_answer = get_abstractive_answer(question=prompt)
-
-            if isinstance(abs_answer, str):
-                full_response = abs_answer
-                message_placeholder.markdown(f"""
-                    <div class="assistant-message">
-                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                        <div class="stMarkdown">{full_response}</div>
-                    </div>
-                """, unsafe_allow_html=True)
-            else:
-                full_response = ""
-                for line in abs_answer.iter_lines():
-                    if line:
-                        line = line.decode('utf-8')
-                        if line.startswith('data: '):
-                            data_str = line[6:]
-                            if data_str == '[DONE]':
-                                break
-
-                            try:
-                                data = json.loads(data_str)
-                                token = data.get('token', '')
-                                full_response += token
-
-                                message_placeholder.markdown(f"""
-                                    <div class="assistant-message">
-                                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                                        <div class="stMarkdown">{full_response}●</div>
-                                    </div>
-                                """, unsafe_allow_html=True)
-
-                            except json.JSONDecodeError:
-                                pass
-
-        elif classify_result == "ABOUT_CHATBOT":
-            answer = introduce_system(question=prompt)
-
-            if isinstance(answer, str):
-                full_response = answer
-                message_placeholder.markdown(f"""
-                    <div class="assistant-message">
-                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                        <div class="stMarkdown">{full_response}</div>
-                    </div>
-                """, unsafe_allow_html=True)
-            else:
-                full_response = ""
-                for line in answer.iter_lines():
-                    if line:
-                        line = line.decode('utf-8')
-                        if line.startswith('data: '):
-                            data_str = line[6:]
-                            if data_str == '[DONE]':
-                                break
-
-                            try:
-                                data = json.loads(data_str)
-                                token = data.get('token', '')
-                                full_response += token
-
-                                message_placeholder.markdown(f"""
-                                    <div class="assistant-message">
-                                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                                        <div class="stMarkdown">{full_response}●</div>
-                                    </div>
-                                """, unsafe_allow_html=True)
-
-                            except json.JSONDecodeError:
-                                pass
-
-        else:
-            answer = response_unrelated_question(question=prompt)
-
-            if isinstance(answer, str):
-                full_response = answer
-                message_placeholder.markdown(f"""
-                    <div class="assistant-message">
-                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                        <div class="stMarkdown">{full_response}</div>
-                    </div>
-                """, unsafe_allow_html=True)
-            else:
-                full_response = ""
-                for line in answer.iter_lines():
-                    if line:
-                        line = line.decode('utf-8')
-                        if line.startswith('data: '):
-                            data_str = line[6:]
-                            if data_str == '[DONE]':
-                                break
-
-                            try:
-                                data = json.loads(data_str)
-                                token = data.get('token', '')
-                                full_response += token
-
-                                message_placeholder.markdown(f"""
-                                    <div class="assistant-message">
-                                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                                        <div class="stMarkdown">{full_response}●</div>
-                                    </div>
-                                """, unsafe_allow_html=True)
-
-                            except json.JSONDecodeError:
-                                pass
-
-    else:
-        classify_result = classify_question(question=prompt).json()
-
-        print(f"The type of user query: {classify_result}")
-
-        if classify_result == "BIDDING_RELATED":
-            ext_answer = get_extractive_answer(question=prompt)
-
-            for word in generate_text_effect(ext_answer):
-                full_response = word
-
-                message_placeholder.markdown(f"""
-                    <div class="assistant-message">
-                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                        <div class="stMarkdown">{full_response}●</div>
-                    </div>
-                """, unsafe_allow_html=True)
-        elif classify_result == "ABOUT_CHATBOT":
-            answer = introduce_system(question=prompt)
-
-            if isinstance(answer, str):
-                full_response = answer
-                message_placeholder.markdown(f"""
-                    <div class="assistant-message">
-                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                        <div class="stMarkdown">{full_response}</div>
-                    </div>
-                """, unsafe_allow_html=True)
-            else:
-                full_response = ""
-                for line in answer.iter_lines():
-                    if line:
-                        line = line.decode('utf-8')
-                        if line.startswith('data: '):
-                            data_str = line[6:]
-                            if data_str == '[DONE]':
-                                break
-
-                            try:
-                                data = json.loads(data_str)
-                                token = data.get('token', '')
-                                full_response += token
-
-                                message_placeholder.markdown(f"""
-                                    <div class="assistant-message">
-                                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                                        <div class="stMarkdown">{full_response}●</div>
-                                    </div>
-                                """, unsafe_allow_html=True)
-
-                            except json.JSONDecodeError:
-                                pass
-
-        else:
-            answer = response_unrelated_question(question=prompt)
-
-            if isinstance(answer, str):
-                full_response = answer
-                message_placeholder.markdown(f"""
-                    <div class="assistant-message">
-                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                        <div class="stMarkdown">{full_response}</div>
-                    </div>
-                """, unsafe_allow_html=True)
-            else:
-                full_response = ""
-                for line in answer.iter_lines():
-                    if line:
-                        line = line.decode('utf-8')
-                        if line.startswith('data: '):
-                            data_str = line[6:]
-                            if data_str == '[DONE]':
-                                break
-
-                            try:
-                                data = json.loads(data_str)
-                                token = data.get('token', '')
-                                full_response += token
-
-                                message_placeholder.markdown(f"""
-                                    <div class="assistant-message">
-                                        <img src="./app/static/ai.jpg" class="assistant-avatar" />
-                                        <div class="stMarkdown">{full_response}●</div>
-                                    </div>
-                                """, unsafe_allow_html=True)
-
-                            except json.JSONDecodeError:
-                                pass
-
+    classify_result = classify_question(question=prompt).json()
+
+    print(f"The type of user query: {classify_result}")
+
+    if classify_result == "BIDDING_RELATED":
+        abs_answer = get_abstractive_answer(question=prompt)
+
+        if isinstance(abs_answer, str):
+            full_response = abs_answer
+            message_placeholder.markdown(f"""
+                <div class="assistant-message">
+                    <img src="./app/static/ai.jpg" class="assistant-avatar" />
+                    <div class="stMarkdown">{full_response}</div>
+                </div>
+            """, unsafe_allow_html=True)
+        else:
+            full_response = ""
+            for line in abs_answer.iter_lines():
+                if line:
+                    line = line.decode('utf-8')
+                    if line.startswith('data: '):
+                        data_str = line[6:]
+                        if data_str == '[DONE]':
+                            break
+
+                        try:
+                            data = json.loads(data_str)
+                            token = data.get('token', '')
+                            full_response += token
+
+                            message_placeholder.markdown(f"""
+                                <div class="assistant-message">
+                                    <img src="./app/static/ai.jpg" class="assistant-avatar" />
+                                    <div class="stMarkdown">{full_response}●</div>
+                                </div>
+                            """, unsafe_allow_html=True)
+
+                        except json.JSONDecodeError:
+                            pass
+
+    elif classify_result == "ABOUT_CHATBOT":
+        answer = introduce_system(question=prompt)
+
+        if isinstance(answer, str):
+            full_response = answer
+            message_placeholder.markdown(f"""
+                <div class="assistant-message">
+                    <img src="./app/static/ai.jpg" class="assistant-avatar" />
+                    <div class="stMarkdown">{full_response}</div>
+                </div>
+            """, unsafe_allow_html=True)
+        else:
+            full_response = ""
+            for line in answer.iter_lines():
+                if line:
+                    line = line.decode('utf-8')
+                    if line.startswith('data: '):
+                        data_str = line[6:]
+                        if data_str == '[DONE]':
+                            break
+
+                        try:
+                            data = json.loads(data_str)
+                            token = data.get('token', '')
+                            full_response += token
+
+                            message_placeholder.markdown(f"""
+                                <div class="assistant-message">
+                                    <img src="./app/static/ai.jpg" class="assistant-avatar" />
+                                    <div class="stMarkdown">{full_response}●</div>
+                                </div>
+                            """, unsafe_allow_html=True)
+
+                        except json.JSONDecodeError:
+                            pass
+
+    else:
+        answer = response_unrelated_question(question=prompt)
+
+        if isinstance(answer, str):
+            full_response = answer
+            message_placeholder.markdown(f"""
+                <div class="assistant-message">
+                    <img src="./app/static/ai.jpg" class="assistant-avatar" />
+                    <div class="stMarkdown">{full_response}</div>
+                </div>
+            """, unsafe_allow_html=True)
+        else:
+            full_response = ""
+            for line in answer.iter_lines():
+                if line:
+                    line = line.decode('utf-8')
+                    if line.startswith('data: '):
+                        data_str = line[6:]
+                        if data_str == '[DONE]':
+                            break
+
+                        try:
+                            data = json.loads(data_str)
+                            token = data.get('token', '')
+                            full_response += token
+
+                            message_placeholder.markdown(f"""
+                                <div class="assistant-message">
+                                    <img src="./app/static/ai.jpg" class="assistant-avatar" />
+                                    <div class="stMarkdown">{full_response}●</div>
+                                </div>
+                            """, unsafe_allow_html=True)
+
+                        except json.JSONDecodeError:
+                            pass
 
     message_placeholder.markdown(f"""
         <div class="assistant-message">
 