# ViEduChat / app.py
import json
import time

import requests
import streamlit as st
st.set_page_config(
    page_title="ViEduChat - Trợ lý AI giáo dục Việt Nam",
    page_icon="./app/static/ai.jpg",
    layout="centered",
    initial_sidebar_state="collapsed",
)
# ==== MODULE URL ====
routing_response_module = st.secrets["ViEduQA_Routing_Module"]
retrieval_module = st.secrets["ViEduQA_Retrieval_Module"]
reranker_module = st.secrets["ViEduQA_Rerank_Module"]
abs_QA_module = st.secrets["ViEduQA_QA_Module"]
url_api_question_classify_model = f"{routing_response_module}/query_classify"
url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
url_api_introduce_system_model = f"{routing_response_module}/about_me"
url_api_retrieval_model = f"{retrieval_module}/search"
url_api_reranker_model = f"{reranker_module}/rerank"
url_api_generation_model = f"{abs_QA_module}/answer"
url_api_extract_reference_model = f"{routing_response_module}/extract_references_unstream"
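# Pipeline: the routing module classifies the query, the retrieval module returns candidate
# passages, the reranker orders them, and the QA module streams an abstractive answer
# token by token over server-sent events.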
# ========== STREAMLIT UI ==========
with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI giáo dục Việt Nam được phát triển bởi Đào Thị Ngọc Ánh. Rất vui khi được hỗ trợ bạn trong học tập!"}]
st.markdown("""
    <div class="logo_area">
        <img src="./app/static/ai.jpg"/>
    </div>
    """, unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>ViEduChat</h2>", unsafe_allow_html=True)
def classify_question(question):
    """Ask the routing module to classify the query; the JSON body is a label such as
    EDUCATION_RELATED or ABOUT_CHATBOT (see the commented-out routing below)."""
    data = {"question": question}
    response = requests.post(url_api_question_classify_model, json=data)
    if response.status_code == 200:
        return response
    else:
        return f"Lỗi: {response.status_code} - {response.text}"
def introduce_system(question):
    """Stream a self-introduction of the system from the routing module."""
    data = {"question": question}
    response = requests.post(url_api_introduce_system_model, json=data, stream=True)
    if response.status_code == 200:
        return response
    else:
        return f"Lỗi: {response.status_code} - {response.text}"
def response_unrelated_question(question):
    """Stream a reply for queries that are not related to education."""
    data = {"question": question}
    response = requests.post(url_api_unrelated_question_response_model, json=data, stream=True)
    if response.status_code == 200:
        return response
    else:
        return f"Lỗi: {response.status_code} - {response.text}"
def retrieve_context(question, top_k=10):
    """Fetch the top_k most relevant passages from the retrieval module."""
    data = {"query": question, "top_k": top_k}
    response = requests.post(url_api_retrieval_model, json=data)
    if response.status_code == 200:
        return response.json()["results"]
    else:
        return f"Lỗi tại Retrieval Module: {response.status_code} - {response.text}"
def rerank_context(url_rerank_module, question, relevant_docs, top_k=5):
    """Rerank the retrieved passages and return the reranked_docs list (top_k best matches)."""
    data = {"question": question, "relevant_docs": relevant_docs, "top_k": top_k}
    response = requests.post(url_rerank_module, json=data)
    if response.status_code == 200:
        return response.json()["reranked_docs"]
    else:
        return f"Lỗi tại Rerank module: {response.status_code} - {response.text}"
def get_abstractive_answer(context, question):
    """Stream an abstractive answer generated from the given context."""
    data = {"context": context, "question": question}
    response = requests.post(url_api_generation_model, json=data, stream=True)
    if response.status_code == 200:
        return response
    else:
        return f"Lỗi: {response.status_code} - {response.text}"
# def get_references(context, question, answer):
# data = {
# "context": context,
# "question": question,
# "answer": answer
# }
# response = requests.post(url_api_extract_reference_model, json=data)
# if response.status_code == 200:
# return response.json()["refs"]
# else:
# return f"Lỗi tại module Reference Extractor: {response.status_code} - {response.text}"
def generate_text_effect(answer):
    """Yield the answer word by word with a short delay to simulate typing."""
    words = answer.split()
    for i in range(len(words)):
        time.sleep(0.03)
        yield " ".join(words[:i + 1])
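# A minimal sketch (not wired in) of a shared helper for the SSE streams returned by the
# modules above: the streaming branches below all repeat the same `data: ...` / `[DONE]`
# parsing loop and could delegate to something like this.
def iter_stream_tokens(response):
    """Yield tokens from a streaming response whose lines look like `data: {"token": "..."}`."""
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        line = raw_line.decode('utf-8')
        if not line.startswith('data: '):
            continue
        data_str = line[len('data: '):]
        if data_str == '[DONE]':
            break
        try:
            yield json.loads(data_str).get('token', '')
        except json.JSONDecodeError:
            continue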
# Re-render the conversation history on every script run.
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        message_class = "assistant-message"
        avatar_html = '<img src="./app/static/ai.jpg" class="assistant-avatar" />'
    else:
        message_class = "user-message"
        avatar_html = ''
    st.markdown(f"""
    <div class="{message_class}">
        {avatar_html}
        <div class="stMarkdown">{message['content']}</div>
    </div>
    """, unsafe_allow_html=True)
if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    st.markdown(f"""
    <div class="user-message">
        <div class="stMarkdown">{prompt}</div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    message_placeholder = st.empty()
# full_response = ""
# classify_result = classify_question(question=prompt).json()
# print(f"The type of user query: {classify_result}")
# if classify_result == "EDUCATION_RELATED":
    # Retrieve candidate passages, rerank them, and generate the answer from the top passage.
    retrieved_context = retrieve_context(question=prompt, top_k=10)
    if isinstance(retrieved_context, str):
        # The retrieval module returned an error message; surface it as the response.
        abs_answer = retrieved_context
    else:
        retrieved_context = [item['text'] for item in retrieved_context]
        reranked_context = rerank_context(url_rerank_module=url_api_reranker_model,
                                          question=prompt,
                                          relevant_docs=retrieved_context,
                                          top_k=5)
        abs_answer = (reranked_context if isinstance(reranked_context, str)
                      else get_abstractive_answer(context=reranked_context[0], question=prompt))
    if isinstance(abs_answer, str):
        # One of the modules returned an error message instead of a stream; render it directly.
        full_response = abs_answer
        message_placeholder.markdown(f"""
        <div class="assistant-message">
            <img src="./app/static/ai.jpg" class="assistant-avatar" />
            <div class="stMarkdown">{full_response}</div>
        </div>
        """, unsafe_allow_html=True)
    else:
        # Stream the answer: the QA module emits SSE lines of the form `data: {"token": ...}`.
        full_response = ""
        for line in abs_answer.iter_lines():
            if line:
                line = line.decode('utf-8')
                if line.startswith('data: '):
                    data_str = line[6:]
                    if data_str == '[DONE]':
                        break
                    try:
                        data = json.loads(data_str)
                        token = data.get('token', '')
                        full_response += token
                        message_placeholder.markdown(f"""
                        <div class="assistant-message">
                            <img src="./app/static/ai.jpg" class="assistant-avatar" />
                            <div class="stMarkdown">{full_response}●</div>
                        </div>
                        """, unsafe_allow_html=True)
                    except json.JSONDecodeError:
                        pass
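    # The commented-out branches below mirror this streaming pattern for ABOUT_CHATBOT and
    # unrelated queries; they are disabled together with the classify_question step above.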
# refs = st.expander("Tài liệu tham khảo", expanded=False)
# refs_list = get_references(context=reranked_context, question=prompt, answer=full_response)
# print(refs_list)
# refs.write(f"{refs_list}")
# elif classify_result == "ABOUT_CHATBOT":
# answer = introduce_system(question=prompt)
# if isinstance(answer, str):
# full_response = answer
# message_placeholder.markdown(f"""
# <div class="assistant-message">
# <img src="./app/static/ai.jpg" class="assistant-avatar" />
# <div class="stMarkdown">{full_response}</div>
# </div>
# """, unsafe_allow_html=True)
# else:
# full_response = ""
# for line in answer.iter_lines():
# if line:
# line = line.decode('utf-8')
# if line.startswith('data: '):
# data_str = line[6:]
# if data_str == '[DONE]':
# break
# try:
# data = json.loads(data_str)
# token = data.get('token', '')
# full_response += token
# message_placeholder.markdown(f"""
# <div class="assistant-message">
# <img src="./app/static/ai.jpg" class="assistant-avatar" />
# <div class="stMarkdown">{full_response}●</div>
# </div>
# """, unsafe_allow_html=True)
# except json.JSONDecodeError:
# pass
# else:
# answer = response_unrelated_question(question=prompt)
# if isinstance(answer, str):
# full_response = answer
# message_placeholder.markdown(f"""
# <div class="assistant-message">
# <img src="./app/static/ai.jpg" class="assistant-avatar" />
# <div class="stMarkdown">{full_response}</div>
# </div>
# """, unsafe_allow_html=True)
# else:
# full_response = ""
# for line in answer.iter_lines():
# if line:
# line = line.decode('utf-8')
# if line.startswith('data: '):
# data_str = line[6:]
# if data_str == '[DONE]':
# break
# try:
# data = json.loads(data_str)
# token = data.get('token', '')
# full_response += token
# message_placeholder.markdown(f"""
# <div class="assistant-message">
# <img src="./app/static/ai.jpg" class="assistant-avatar" />
# <div class="stMarkdown">{full_response}●</div>
# </div>
# """, unsafe_allow_html=True)
# except json.JSONDecodeError:
# pass
    # Replace the streaming placeholder with the final answer and store it in the history.
    message_placeholder.markdown(f"""
    <div class="assistant-message">
        <img src="./app/static/ai.jpg" class="assistant-avatar" />
        <div class="stMarkdown">
            {full_response}
        </div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})
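# For a local run: `streamlit run app.py`, with the four module URLs above defined as
# entries in .streamlit/secrets.toml (the keys read via st.secrets).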