IntelliChat-v1 / app.py
ntphuc149's picture
Update app.py
c4d7d0b verified
raw
history blame
15 kB
import html
import json
import time

import requests
import streamlit as st
# Page-level settings: browser tab title/icon, centred layout, sidebar open.
st.set_page_config(
    page_title="ViBidLQA - Trợ lý AI văn bản pháp luật Việt Nam",
    page_icon="./app/static/ai.jpg",
    layout="centered",
    initial_sidebar_state="expanded",
)

# Base addresses of the backing services, read from Streamlit secrets.
routing_response_module = st.secrets["ViBidLQA_Routing_Module"]
retrieval_module = st.secrets["ViBidLQA_Retrieval_Module"]
ext_QA_module = st.secrets["ViBidLQA_EQA_Module"]
abs_QA_module = st.secrets["ViBidLQA_AQA_Module"]

# Routing service endpoints: query classification and the two canned-reply
# style endpoints (unrelated questions, questions about the chatbot itself).
url_api_question_classify_model = f"{routing_response_module}/query_classify"
url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
url_api_introduce_system_model = f"{routing_response_module}/about_me"

# Context retrieval endpoint.
url_api_retrieval_model = f"{retrieval_module}/search"

# Question-answering endpoints: extractive and abstractive (generative).
url_api_extraction_model = f"{ext_QA_module}/answer"
url_api_generation_model = f"{abs_QA_module}/answer"
# Inline the app's custom stylesheet into the page. The encoding is given
# explicitly so that decoding the CSS file does not depend on the platform's
# default locale encoding.
with open("./static/styles.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)
# Seed the conversation with the assistant's greeting the first time the
# session runs; later reruns keep the accumulated history.
if 'messages' not in st.session_state:
    st.session_state.messages = [
        {
            'role': 'assistant',
            'content': "Xin chào. Tôi là trợ lý AI văn bản luật Đấu thầu Việt Nam được phát triển bởi Nguyễn Trường Phúc và các cộng sự. Rất vui khi được hỗ trợ bạn trong các vấn đề pháp lý tại Việt Nam!",
        }
    ]

# Logo banner followed by the centred application title.
st.markdown("""
<div class=logo_area>
<img src="./app/static/ai.jpg"/>
</div>
""", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>ViBidLQA</h2>", unsafe_allow_html=True)

# Sidebar selector for the answering backend; 'Extraction' is preselected.
answering_method = st.sidebar.selectbox(
    label='Chọn mô hình trả lời câu hỏi:',
    options=['Extraction', 'Generation'],
    index=0,
)

# Log the currently selected backend to stdout.
if answering_method == 'Generation':
    print('Switched to generative model...')
elif answering_method == 'Extraction':
    print('Switched to extraction model...')
def classify_question(question):
    """Send *question* to the routing service's classification endpoint.

    Returns:
        The ``requests`` response object when the service answers with
        HTTP 200 (the response is also printed to stdout); otherwise a
        ``"Lỗi: <status> - <body>"`` error string.
    """
    reply = requests.post(url_api_question_classify_model, json={"question": question})
    if reply.status_code != 200:
        return f"Lỗi: {reply.status_code} - {reply.text}"
    print(reply)
    return reply
def introduce_system(question):
    """Ask the routing service's ``about_me`` endpoint to answer *question*.

    The request is made with ``stream=True`` so the caller can consume the
    body incrementally.

    Returns:
        The streaming ``requests`` response on HTTP 200; otherwise a
        ``"Lỗi: <status> - <body>"`` error string.
    """
    reply = requests.post(
        url_api_introduce_system_model,
        json={"question": question},
        stream=True,
    )
    if reply.status_code != 200:
        return f"Lỗi: {reply.status_code} - {reply.text}"
    return reply
def response_unrelated_question(question):
    """Request a reply for a *question* classified as off-topic.

    Posts to the routing service's ``response_unrelated_question`` endpoint
    with ``stream=True`` so the caller can consume the body incrementally.

    Returns:
        The streaming ``requests`` response on HTTP 200; otherwise a
        ``"Lỗi: <status> - <body>"`` error string.
    """
    reply = requests.post(
        url_api_unrelated_question_response_model,
        json={"question": question},
        stream=True,
    )
    if reply.status_code != 200:
        return f"Lỗi: {reply.status_code} - {reply.text}"
    return reply
def retrieve_context(question, top_k=10):
    """Fetch the top-ranked context passage for *question*.

    Args:
        question: Text sent to the retrieval service as ``query``.
        top_k: Number of results requested from the service. Only the first
            returned result is used.

    Returns:
        The ``text`` field of the first entry in the response's ``results``
        list on HTTP 200 (also printed to stdout); otherwise a
        ``"Lỗi: <status> - <body>"`` error string.

    Raises:
        IndexError: If the service returns an empty ``results`` list.
        KeyError: If ``results`` or ``text`` is missing from the payload.
    """
    payload = {"query": question, "top_k": top_k}
    reply = requests.post(url_api_retrieval_model, json=payload)
    if reply.status_code != 200:
        return f"Lỗi: {reply.status_code} - {reply.text}"

    best_passage = reply.json()["results"][0]["text"]
    print(f"Retrieved bidding legal context: {best_passage}")
    return best_passage
def get_abstractive_answer(question):
    """Retrieve context for *question* and query the generation service.

    Whatever ``retrieve_context`` returns (including its error string) is
    forwarded unchanged as the ``context`` field. The request is made with
    ``stream=True`` so the caller can consume the answer incrementally.

    Returns:
        The streaming ``requests`` response on HTTP 200; otherwise a
        ``"Lỗi: <status> - <body>"`` error string.
    """
    payload = {
        "context": retrieve_context(question=question),
        "question": question,
    }
    reply = requests.post(url_api_generation_model, json=payload, stream=True)
    if reply.status_code != 200:
        return f"Lỗi: {reply.status_code} - {reply.text}"
    return reply
def generate_text_effect(answer, delay=0.03):
    """Yield progressively longer prefixes of *answer*, one word at a time.

    Used to animate a finished answer as if it were being typed.

    Args:
        answer: Text to reveal. It is split on whitespace, so runs of
            whitespace collapse to single spaces in the yielded prefixes.
        delay: Seconds to sleep before each prefix is yielded. Defaults to
            0.03; pass 0 to disable the pause.

    Yields:
        The first word, then the first two words joined by a space, and so
        on up to the whole text. Nothing is yielded for empty or
        whitespace-only input.
    """
    shown = []
    for word in answer.split():
        shown.append(word)
        time.sleep(delay)
        yield " ".join(shown)
def get_extractive_answer(question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    """Retrieve context for *question* and query the extraction service.

    Args:
        question: The user's question.
        stride, max_length, n_best, max_answer_length: Forwarded unchanged
            to the extraction service under the same field names.

    Returns:
        The ``best_answer`` field of the service's JSON reply on HTTP 200;
        otherwise a ``"Lỗi: <status> - <body>"`` error string.
    """
    payload = {
        # The retrieved passage (or retrieve_context's error string) is sent as-is.
        "context": retrieve_context(question=question),
        "question": question,
        "stride": stride,
        "max_length": max_length,
        "n_best": n_best,
        "max_answer_length": max_answer_length,
    }
    reply = requests.post(url_api_extraction_model, json=payload)
    if reply.status_code != 200:
        return f"Lỗi: {reply.status_code} - {reply.text}"
    return reply.json()["best_answer"]
# Replay the stored conversation as styled chat bubbles on every rerun.
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.jpg'
        # NOTE(review): assistant text is inserted into the HTML unescaped,
        # as before; confirm the backend output is trusted.
        content = message['content']
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
        # User text is untrusted and the bubble is rendered with
        # unsafe_allow_html=True, so escape it to prevent HTML injection.
        content = html.escape(message['content'])
    st.markdown(f"""
<div class="{message_class}">
<img src="{avatar}" class="{avatar_class}" />
<div class="stMarkdown">{content}</div>
</div>
""", unsafe_allow_html=True)
def _render_assistant(placeholder, text):
    """Draw *text* as the assistant's chat bubble inside *placeholder*."""
    placeholder.markdown(f"""
<div class="assistant-message">
<img src="./app/static/ai.jpg" class="assistant-avatar" />
<div class="stMarkdown">{text}</div>
</div>
""", unsafe_allow_html=True)


def _stream_tokens(response):
    """Yield the ``token`` field of each ``data: `` line in a streamed reply.

    Blank lines, lines without the ``data: `` prefix and payloads that are
    not valid JSON are skipped; a ``[DONE]`` payload ends the stream. The
    response is closed when iteration finishes so the connection is released
    even though the body may not have been read to the end.
    """
    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode('utf-8')
            if not line.startswith('data: '):
                continue
            payload = line[6:]  # strip the 'data: ' prefix
            if payload == '[DONE]':
                break
            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                # Ignore malformed chunks and keep streaming.
                continue
            yield data.get('token', '')
    finally:
        response.close()


def _show_streamed_answer(placeholder, answer):
    """Display *answer* in *placeholder* and return the final text.

    *answer* is either an error string (shown as-is) or a streaming response
    whose tokens are appended one by one, with a trailing ``●`` cursor while
    the stream is in progress.
    """
    if isinstance(answer, str):
        _render_assistant(placeholder, answer)
        return answer

    text = ""
    for token in _stream_tokens(answer):
        text += token
        _render_assistant(placeholder, f"{text}●")
    return text


# Handle a newly submitted question: echo it, classify it, route it to the
# matching backend, render the reply, and record both turns in the history.
if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    # The prompt is untrusted and rendered with unsafe_allow_html=True, so it
    # is escaped for display. The raw text is what gets stored in the history.
    st.markdown(f"""
<div class="user-message">
<img src="./app/static/human.png" class="user-avatar" />
<div class="stMarkdown">{html.escape(prompt)}</div>
</div>
""", unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    message_placeholder = st.empty()
    full_response = ""

    classify_response = classify_question(question=prompt)
    if isinstance(classify_response, str):
        # classify_question returns an error string on a non-200 reply;
        # show it as the answer instead of calling .json() on a str.
        full_response = classify_response
    else:
        classify_result = classify_response.json()
        print(f"The type of user query: {classify_result}")

        if classify_result == "BIDDING_RELATED":
            if answering_method == 'Generation':
                full_response = _show_streamed_answer(
                    message_placeholder,
                    get_abstractive_answer(question=prompt),
                )
            else:
                # The extractive answer arrives whole; animate it word by word.
                ext_answer = get_extractive_answer(question=prompt)
                for partial in generate_text_effect(ext_answer):
                    full_response = partial
                    _render_assistant(message_placeholder, f"{full_response}●")
        elif classify_result == "ABOUT_CHATBOT":
            full_response = _show_streamed_answer(
                message_placeholder,
                introduce_system(question=prompt),
            )
        else:
            full_response = _show_streamed_answer(
                message_placeholder,
                response_unrelated_question(question=prompt),
            )

    # Final render without the typing cursor.
    message_placeholder.markdown(f"""
<div class="assistant-message">
<img src="./app/static/ai.jpg" class="assistant-avatar" />
<div class="stMarkdown">
{full_response}
</div>
</div>
""", unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})