ntphuc149 committed on
Commit
c268c45
·
verified ·
1 Parent(s): f912d04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -141
app.py CHANGED
@@ -1,142 +1,195 @@
1
- import time
2
- import requests
3
- import streamlit as st
4
-
5
- st.set_page_config(page_title="ViBidLawQA - Trợ lý AI hỗ trợ hỏi đáp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="expanded")
6
-
7
- with open("./static/styles.css") as f:
8
- st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
9
-
10
- if 'messages' not in st.session_state:
11
- st.session_state.messages = []
12
-
13
- st.markdown(f"""
14
- <div class=logo_area>
15
- <img src="./app/static/ai.jpg"/>
16
- </div>
17
- """, unsafe_allow_html=True)
18
- st.markdown("<h2 style='text-align: center;'>ViBidLQA Bot</h2>", unsafe_allow_html=True)
19
-
20
- url_api_extraction_model = st.sidebar.text_input(label="URL API Extraction model:")
21
- url_api_generation_model = st.sidebar.text_input(label="URL API Generation model:")
22
-
23
- answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn hình trả lời câu hỏi:', index=0)
24
- context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=500)
25
-
26
- if answering_method == 'Generation':
27
- print('Switching to generative model...')
28
- print('Loading generative model...')
29
-
30
- if answering_method == 'Extraction':
31
- print('Switching to extraction model...')
32
- print('Loading extraction model...')
33
-
34
- def get_abstractive_answer(context, question):
35
- data = {
36
- "context": context,
37
- "question": question
38
- }
39
-
40
- response = requests.post(url_api_generation_model, json=data)
41
- if response.status_code == 200:
42
- result = response.json()
43
- return result["answer"]
44
- else:
45
- return f"Lỗi: {response.status_code} - {response.text}"
46
-
47
- def generate_text_effect(answer):
48
- words = answer.split()
49
- for i in range(len(words)):
50
- time.sleep(0.03)
51
- yield " ".join(words[:i+1])
52
-
53
- def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512):
54
- data = {
55
- "context": context,
56
- "question": question,
57
- "stride": stride,
58
- "max_length": max_length,
59
- "n_best": n_best,
60
- "max_answer_length": max_answer_length
61
- }
62
-
63
- response = requests.post(url_api_extraction_model, json=data)
64
-
65
- if response.status_code == 200:
66
- result = response.json()
67
- return result["best_answer"]
68
- else:
69
- return f"Lỗi: {response.status_code} - {response.text}"
70
-
71
- for message in st.session_state.messages:
72
- if message['role'] == 'assistant':
73
- avatar_class = "assistant-avatar"
74
- message_class = "assistant-message"
75
- avatar = './app/static/ai.jpg'
76
- else:
77
- avatar_class = "user-avatar"
78
- message_class = "user-message"
79
- avatar = './app/static/human.png'
80
- st.markdown(f"""
81
- <div class="{message_class}">
82
- <img src="{avatar}" class="{avatar_class}" />
83
- <div class="stMarkdown">{message['content']}</div>
84
- </div>
85
- """, unsafe_allow_html=True)
86
-
87
- if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
88
- st.markdown(f"""
89
- <div class="user-message">
90
- <img src="./app/static/human.png" class="user-avatar" />
91
- <div class="stMarkdown">{prompt}</div>
92
- </div>
93
- """, unsafe_allow_html=True)
94
- st.session_state.messages.append({'role': 'user', 'content': prompt})
95
-
96
- message_placeholder = st.empty()
97
-
98
- for _ in range(2):
99
- for dots in ["", "●●", "●●●"]:
100
- time.sleep(0.2)
101
- message_placeholder.markdown(f"""
102
- <div class="assistant-message">
103
- <img src="./app/static/ai.jpg" class="assistant-avatar" />
104
- <div class="stMarkdown">{dots}</div>
105
- </div>
106
- """, unsafe_allow_html=True)
107
-
108
- full_response = ""
109
- if answering_method == 'Generation':
110
- abs_answer = get_abstractive_answer(context=context, question=prompt)
111
- for word in generate_text_effect(abs_answer):
112
- full_response = word
113
-
114
- message_placeholder.markdown(f"""
115
- <div class="assistant-message">
116
- <img src="./app/static/ai.jpg" class="assistant-avatar" />
117
- <div class="stMarkdown">{full_response}●</div>
118
- </div>
119
- """, unsafe_allow_html=True)
120
-
121
- else:
122
- ext_answer = get_extractive_answer(context=context, question=prompt)
123
- for word in generate_text_effect(ext_answer):
124
- full_response = word
125
-
126
- message_placeholder.markdown(f"""
127
- <div class="assistant-message">
128
- <img src="./app/static/ai.jpg" class="assistant-avatar" />
129
- <div class="stMarkdown">{full_response}●</div>
130
- </div>
131
- """, unsafe_allow_html=True)
132
-
133
- message_placeholder.markdown(f"""
134
- <div class="assistant-message">
135
- <img src="./app/static/ai.jpg" class="assistant-avatar" />
136
- <div class="stMarkdown">
137
- {full_response}
138
- </div>
139
- </div>
140
- """, unsafe_allow_html=True)
141
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  st.session_state.messages.append({'role': 'assistant', 'content': full_response})
 
1
+ import time
2
+ import json
3
+ import requests
4
+ import streamlit as st
5
+
6
+
7
# Page chrome: title, icon, wide layout, sidebar expanded by default.
st.set_page_config(page_title="ViBidLQA - Trợ lý AI hỗ trợ hỏi đáp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="wide", initial_sidebar_state="expanded")

# Inject the project CSS inline so it applies inside Streamlit's HTML rendering.
# NOTE(review): CSS is read from ./static but images below use ./app/static —
# confirm both paths resolve from the working directory the app is launched in.
with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Seed the conversation with a greeting exactly once per browser session
# (Streamlit reruns this whole script on every interaction; session_state persists).
if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI văn bản luật Đấu thầu Việt Nam được phát triển bởi Nguyễn Trường Phúc. Rất vui khi được hỗ trợ bạn trong các vấn đề pháp lý tại Việt Nam!"}]

# Header: logo image plus a centered title.
st.markdown(f"""
<div class=logo_area>
    <img src="./app/static/ai.jpg"/>
</div>
""", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>The ViBidLQA System </h2>", unsafe_allow_html=True)

# Backend endpoints are supplied at runtime from the sidebar — nothing hard-coded.
url_api_retrieval_model = st.sidebar.text_input(label="URL API Retrieval model:")
url_api_extraction_model = st.sidebar.text_input(label="URL API Extraction model:")
url_api_generation_model = st.sidebar.text_input(label="URL API Generation model:")

# 'Extraction' = span-extraction QA; 'Generation' = abstractive QA (default: Extraction).
answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
27
+
28
# Console-side logging of the selected QA backend (no UI effect; the selectbox
# only ever yields one of the two values, so the branches are exclusive).
if answering_method == 'Generation':
    print('Switching to generative model...')
    print('Loading generative model...')
elif answering_method == 'Extraction':
    print('Switching to extraction model...')
    print('Loading extraction model...')
35
+
36
def retrieve_context(question, top_k=10):
    """Fetch the most relevant legal passage for *question* from the retrieval API.

    Posts ``{"query": question, "top_k": top_k}`` to ``url_api_retrieval_model``
    and returns the text of the top-ranked hit. On failure (non-200 status or an
    empty hit list) returns a human-readable error string, matching the error
    convention of the other API helpers in this file.
    """
    data = {
        "query": question,
        "top_k": top_k,
    }

    # timeout keeps the Streamlit worker from hanging forever on a dead API.
    response = requests.post(url_api_retrieval_model, json=data, timeout=60)

    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"

    results = response.json()["results"]
    # Guard: the original code raised IndexError when the API returned no hits.
    if not results:
        return "Lỗi: không tìm thấy văn bản pháp luật phù hợp."

    # Debug trace of the retrieved passage (console only).
    print(f"Văn bản pháp luật được truy hồi: {results[0]['text']}")
    print("="*100)
    return results[0]["text"]
51
+
52
def get_abstractive_answer(question):
    """Answer *question* with the generative (abstractive) model, non-streaming.

    Retrieves a supporting legal passage first, then posts
    ``{"context": ..., "question": ...}`` to ``url_api_generation_model``.
    Returns the model's answer string, or an error string on a non-200 status.
    """
    context = retrieve_context(question=question)
    # NOTE(review): if retrieval failed, `context` is an error string and is
    # still sent to the model — consider short-circuiting here.

    data = {
        "context": context,
        "question": question
    }

    # timeout prevents the UI from blocking indefinitely on a stalled backend.
    response = requests.post(url_api_generation_model, json=data, timeout=120)
    if response.status_code == 200:
        result = response.json()
        return result["answer"]
    return f"Lỗi: {response.status_code} - {response.text}"
66
+
67
def get_abstractive_answer_stream(question):
    """Start a streaming answer request against the generation API.

    Retrieves a supporting passage, then opens a streaming POST
    (``stream=True``). Returns the open ``requests.Response`` on success so the
    caller can iterate server-sent-event lines, or an error string on a non-200
    status — callers must distinguish the two with ``isinstance``.
    """
    context = retrieve_context(question=question)
    # NOTE(review): a failed retrieval yields an error string that is still
    # forwarded as context — consider short-circuiting here.

    data = {
        "context": context,
        "question": question
    }

    # stream=True defers downloading the body; timeout bounds connect/read stalls.
    response = requests.post(url_api_generation_model, json=data, stream=True, timeout=120)

    if response.status_code == 200:
        # Hand the live response back so the caller can consume the stream.
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
83
+
84
def generate_text_effect(answer):
    """Yield growing word-prefixes of *answer* to simulate live typing.

    Each step sleeps briefly so the Streamlit placeholder repaints visibly
    between yields.
    """
    words = answer.split()
    for count in range(1, len(words) + 1):
        time.sleep(0.03)
        yield " ".join(words[:count])
89
+
90
def get_extractive_answer(question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    """Answer *question* with the extractive (span-based) QA model.

    Retrieves a supporting passage, then posts it together with the decoding
    hyper-parameters (stride, max_length, n_best, max_answer_length) to
    ``url_api_extraction_model``. Returns the API's ``best_answer`` string,
    or an error string on a non-200 status.
    """
    context = retrieve_context(question=question)
    # NOTE(review): a failed retrieval yields an error string that is still
    # forwarded as context — consider short-circuiting here.

    data = {
        "context": context,
        "question": question,
        "stride": stride,
        "max_length": max_length,
        "n_best": n_best,
        "max_answer_length": max_answer_length
    }

    # timeout prevents the UI from blocking indefinitely on a stalled backend.
    response = requests.post(url_api_extraction_model, json=data, timeout=120)

    if response.status_code == 200:
        result = response.json()
        return result["best_answer"]
    return f"Lỗi: {response.status_code} - {response.text}"
109
+
110
# Re-render the whole chat transcript on every script run (Streamlit reruns the
# file per interaction; the history itself lives in session_state).
for message in st.session_state.messages:
    is_assistant = message['role'] == 'assistant'
    message_class = "assistant-message" if is_assistant else "user-message"
    avatar_class = "assistant-avatar" if is_assistant else "user-avatar"
    avatar = './app/static/ai.jpg' if is_assistant else './app/static/human.png'
    st.markdown(f"""
    <div class="{message_class}">
        <img src="{avatar}" class="{avatar_class}" />
        <div class="stMarkdown">{message['content']}</div>
    </div>
    """, unsafe_allow_html=True)
125
+
126
# One chat turn: runs only when the user submits a new message (walrus binds it).
if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    # Echo the user's message immediately, then persist it to the transcript.
    st.markdown(f"""
    <div class="user-message">
        <img src="./app/static/human.png" class="user-avatar" />
        <div class="stMarkdown">{prompt}</div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # Single placeholder that gets overwritten repeatedly as the answer grows.
    message_placeholder = st.empty()

    full_response = ""
    if answering_method == 'Generation':
        response_stream = get_abstractive_answer_stream(question=prompt)

        # The helper returns a str on failure and a Response object on success.
        if isinstance(response_stream, str):
            full_response = response_stream
            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.jpg" class="assistant-avatar" />
                <div class="stMarkdown">{full_response}</div>
            </div>
            """, unsafe_allow_html=True)
        else:
            # Server-sent-events style stream: each payload line is "data: {json}".
            full_response = ""
            for line in response_stream.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data_str = line[6:]
                        if data_str == '[DONE]':  # end-of-stream sentinel
                            break

                        try:
                            data = json.loads(data_str)
                            token = data.get('token', '')
                            full_response += token

                            # Repaint with a trailing ● as a "still typing" cursor.
                            message_placeholder.markdown(f"""
                            <div class="assistant-message">
                                <img src="./app/static/ai.jpg" class="assistant-avatar" />
                                <div class="stMarkdown">{full_response}●</div>
                            </div>
                            """, unsafe_allow_html=True)

                        except json.JSONDecodeError:
                            # Malformed chunks are skipped; the stream continues.
                            pass
    else:
        # Extractive path: fetch the full answer, then animate it word by word.
        ext_answer = get_extractive_answer(question=prompt)
        for word in generate_text_effect(ext_answer):
            full_response = word

            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.jpg" class="assistant-avatar" />
                <div class="stMarkdown">{full_response}●</div>
            </div>
            """, unsafe_allow_html=True)

    # Final repaint without the typing cursor, then persist the assistant reply.
    message_placeholder.markdown(f"""
    <div class="assistant-message">
        <img src="./app/static/ai.jpg" class="assistant-avatar" />
        <div class="stMarkdown">
            {full_response}
        </div>
    </div>
    """, unsafe_allow_html=True)

    st.session_state.messages.append({'role': 'assistant', 'content': full_response})