|
import html
import json
import os
import time
from urllib.parse import urlencode, urlparse, parse_qs

import requests
import streamlit as st
|
|
|
# Browser-tab title and icon, page layout and initial sidebar state.
st.set_page_config(
    page_title="ViEduChat - Trợ lý AI giáo dục Việt Nam",
    page_icon="./app/static/ai.jpg",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Base URLs of the backend services, read from Streamlit secrets.
routing_response_module = st.secrets["ViEduQA_Routing_Module"]
retrieval_module = st.secrets["ViEduQA_Retrieval_Module"]
reranker_module = st.secrets["ViEduQA_Rerank_Module"]
abs_QA_module = st.secrets["ViEduQA_QA_Module"]

# Endpoints served by the routing / response module.
url_api_question_classify_model = f"{routing_response_module}/query_classify"
url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
url_api_introduce_system_model = f"{routing_response_module}/about_me"
url_api_extract_reference_model = f"{routing_response_module}/extract_references_unstream"

# Endpoints of the retrieval, reranking and answer-generation services.
url_api_retrieval_model = f"{retrieval_module}/search"
url_api_reranker_model = f"{reranker_module}/rerank"
url_api_generation_model = f"{abs_QA_module}/answer"
|
|
|
|
|
|
|
# Inject the app stylesheet. The file is read as UTF-8 explicitly so the
# result does not depend on the server's locale encoding.
with open("./static/styles.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

# Seed the chat history with the assistant's greeting on the first run of a
# session; later reruns keep the existing history.
if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI giáo dục Việt Nam được phát triển bởi Đào Thị Ngọc Ánh. Rất vui khi được hỗ trợ bạn trong học tập!"}]

# Page header: logo image followed by the centred app title.
st.markdown(f"""

<div class=logo_area>

    <img src="./app/static/ai.jpg"/>

</div>

""", unsafe_allow_html=True)

st.markdown("<h2 style='text-align: center;'>ViEduChat</h2>", unsafe_allow_html=True)
|
|
|
def classify_question(question, *, timeout=(10, None)):
    """Ask the routing module to classify *question*.

    Args:
        question: The user's question text, sent as ``{"question": ...}``.
        timeout: Passed to ``requests.post``. The default is a 10 s connect
            timeout and no read timeout, so an unreachable service fails fast
            while a slow model is not cut off.

    Returns:
        The ``requests.Response`` on HTTP 200; otherwise an error string
        (``"Lỗi: <status> - <body>"`` for HTTP errors, ``"Lỗi: <exception>"``
        when the request itself fails).
    """
    try:
        response = requests.post(
            url_api_question_classify_model,
            json={"question": question},
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Network failure: report it with the same error-string convention as
        # an HTTP error instead of raising into the Streamlit script.
        return f"Lỗi: {exc}"

    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
|
|
|
def introduce_system(question, *, timeout=(10, None)):
    """Request the routing module's streamed self-introduction for *question*.

    Args:
        question: The user's question text, sent as ``{"question": ...}``.
        timeout: Passed to ``requests.post``. The default is a 10 s connect
            timeout and no read timeout, so an unreachable service fails fast
            while a slow generation is not cut off.

    Returns:
        The open streaming ``requests.Response`` on HTTP 200 (the caller is
        responsible for consuming and closing it); otherwise an error string.
    """
    try:
        response = requests.post(
            url_api_introduce_system_model,
            json={"question": question},
            stream=True,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Network failure: report it with the same error-string convention as
        # an HTTP error instead of raising into the Streamlit script.
        return f"Lỗi: {exc}"

    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
|
|
|
def response_unrelated_question(question, *, timeout=(10, None)):
    """Request the routing module's streamed reply to an off-topic *question*.

    Args:
        question: The user's question text, sent as ``{"question": ...}``.
        timeout: Passed to ``requests.post``. The default is a 10 s connect
            timeout and no read timeout, so an unreachable service fails fast
            while a slow generation is not cut off.

    Returns:
        The open streaming ``requests.Response`` on HTTP 200 (the caller is
        responsible for consuming and closing it); otherwise an error string.
    """
    try:
        response = requests.post(
            url_api_unrelated_question_response_model,
            json={"question": question},
            stream=True,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Network failure: report it with the same error-string convention as
        # an HTTP error instead of raising into the Streamlit script.
        return f"Lỗi: {exc}"

    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
|
|
|
def retrieve_context(question, top_k=10, *, timeout=(10, None)):
    """Fetch the *top_k* passages most relevant to *question* from the retrieval module.

    Args:
        question: The user's question, sent as ``"query"``.
        top_k: Number of passages to request.
        timeout: Passed to ``requests.post``. The default is a 10 s connect
            timeout and no read timeout.

    Returns:
        The ``"results"`` value of the JSON response on HTTP 200 (the chat
        handler reads ``item['text']`` from each entry); otherwise an error
        string prefixed with ``"Lỗi tại Retrieval Module: "``.
    """
    payload = {"query": question, "top_k": top_k}
    try:
        response = requests.post(url_api_retrieval_model, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Network failure: same error-string convention as an HTTP error.
        return f"Lỗi tại Retrieval Module: {exc}"

    if response.status_code == 200:
        return response.json()["results"]
    return f"Lỗi tại Retrieval Module: {response.status_code} - {response.text}"
|
|
|
def rerank_context(url_rerank_module, question, relevant_docs, top_k=5, *, timeout=(10, None)):
    """Ask the rerank service at *url_rerank_module* to reorder *relevant_docs* for *question*.

    Args:
        url_rerank_module: Full URL of the rerank endpoint.
        question: The user's question.
        relevant_docs: Candidate passages to rerank.
        top_k: Number of passages to keep.
        timeout: Passed to ``requests.post``. The default is a 10 s connect
            timeout and no read timeout.

    Returns:
        The ``"reranked_docs"`` value of the JSON response on HTTP 200;
        otherwise an error string prefixed with ``"Lỗi tại Rerank module: "``.
    """
    payload = {
        "question": question,
        "relevant_docs": relevant_docs,
        "top_k": top_k,
    }
    try:
        response = requests.post(url_rerank_module, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Network failure: same error-string convention as an HTTP error.
        return f"Lỗi tại Rerank module: {exc}"

    if response.status_code == 200:
        return response.json()["reranked_docs"]
    return f"Lỗi tại Rerank module: {response.status_code} - {response.text}"
|
|
|
def get_abstractive_answer(context, question, *, timeout=(10, None)):
    """Request a streamed generated answer to *question* grounded in *context*.

    Args:
        context: Supporting passage(s) sent as ``"context"``.
        question: The user's question.
        timeout: Passed to ``requests.post``. The default is a 10 s connect
            timeout and no read timeout, so an unreachable service fails fast
            while a slow generation is not cut off.

    Returns:
        The open streaming ``requests.Response`` on HTTP 200 (the caller is
        responsible for consuming and closing it); otherwise an error string.
        The chat handler already displays string results as the answer, so
        network failures now surface there instead of crashing the script.
    """
    payload = {"context": context, "question": question}
    try:
        response = requests.post(
            url_api_generation_model,
            json=payload,
            stream=True,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        return f"Lỗi: {exc}"

    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_text_effect(answer, *, delay=0.03):
    """Yield *answer* as a growing string, one word at a time, for a typing effect.

    Args:
        answer: Text to reveal. It is split on any whitespace, so runs of
            spaces or newlines collapse to single spaces in the output.
        delay: Seconds to sleep before each yielded prefix (default 0.03).

    Yields:
        The first word, then the first two words joined by a space, and so on
        until the whole (whitespace-normalised) answer has been yielded.
    """
    shown = ""
    for word in answer.split():
        time.sleep(delay)
        # split() never yields empty words, so `shown` is only empty before the first one.
        shown = f"{shown} {word}" if shown else word
        yield shown
|
|
|
# Re-draw the whole conversation on every script run.
for message in st.session_state.messages:
    content = message['content']
    if message['role'] == 'assistant':
        # Assistant turns: avatar plus the stored text inserted as-is.
        bubble = (
            '\n<div class="assistant-message">\n\n'
            '<img src="./app/static/ai.jpg" class="assistant-avatar" />\n\n'
            f'<div class="stMarkdown">{content}</div>\n\n'
            '</div>\n'
        )
    else:
        # User turns: the text is user input placed inside raw HTML, so it is
        # escaped (otherwise "<", common in maths questions, is parsed as
        # markup). No avatar <img> is emitted; the old code produced an
        # empty-src image here, unlike the live user bubble.
        bubble = (
            '\n<div class="user-message">\n\n'
            f'<div class="stMarkdown">{html.escape(content)}</div>\n\n'
            '</div>\n'
        )
    # Template lines are not indented: in Markdown a line indented four or
    # more spaces after a blank line starts an indented code block.
    st.markdown(bubble, unsafe_allow_html=True)
|
|
|
def _assistant_bubble(text, *, own_line=False):
    """Return the HTML for one assistant chat bubble showing *text* as-is.

    With ``own_line=False`` the text sits on the same line as its
    ``<div class="stMarkdown">`` tag; this layout is used while streaming and
    for error messages. With ``own_line=True`` the text sits on its own line
    between blank lines; this layout is used for the final answer
    (presumably so Markdown in the answer gets rendered).
    """
    if own_line:
        content = f'<div class="stMarkdown">\n\n{text}\n\n</div>'
    else:
        content = f'<div class="stMarkdown">{text}</div>'
    # Lines are deliberately not indented: in Markdown a line indented four or
    # more spaces after a blank line starts an indented code block.
    return (
        '\n<div class="assistant-message">\n\n'
        '<img src="./app/static/ai.jpg" class="assistant-avatar" />\n\n'
        f'{content}\n\n'
        '</div>\n'
    )


def _user_bubble(text):
    """Return the HTML for one user chat bubble, with *text* HTML-escaped.

    The text is user input placed inside raw HTML, so characters such as
    ``<`` or ``&`` (common in maths questions) must be escaped or they would
    be parsed as markup.
    """
    return (
        '\n<div class="user-message">\n\n'
        f'<div class="stMarkdown">{html.escape(text)}</div>\n\n'
        '</div>\n'
    )


def _stream_answer(response, placeholder):
    """Read answer tokens from a streaming *response*, redrawing *placeholder* as they arrive.

    Lines of the form ``data: {"token": ...}`` are accumulated until
    ``data: [DONE]``. Blank lines, lines without the ``data: `` prefix, and
    payloads that are not valid JSON are skipped. The response is always
    closed.

    Returns:
        The concatenated answer text.
    """
    full_response = ""
    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode('utf-8')
            if not line.startswith('data: '):
                continue
            data_str = line[len('data: '):]
            if data_str == '[DONE]':
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue  # malformed event payload: skip it
            full_response += data.get('token', '')
            # The trailing "●" acts as a typing cursor while tokens keep arriving.
            placeholder.markdown(_assistant_bubble(full_response + '●'), unsafe_allow_html=True)
    finally:
        # With stream=True the connection is only released once the body is
        # fully read or the response is closed; release it on every exit path.
        response.close()
    return full_response


def _answer_question(question, placeholder):
    """Run retrieve -> rerank -> generate for *question*.

    Returns:
        The generated answer, or the error / fallback message to show instead
        when a stage fails or yields no usable context.
    """
    retrieved = retrieve_context(question=question, top_k=10)
    if isinstance(retrieved, str):
        # retrieve_context reports failures as a message string.
        return retrieved
    passages = [item['text'] for item in retrieved]

    reranked = rerank_context(url_rerank_module=url_api_reranker_model,
                              question=question,
                              relevant_docs=passages,
                              top_k=5)
    if isinstance(reranked, str):
        # Without this check, [0] below would send the message's first
        # character to the generator as "context".
        return reranked
    if not reranked:
        return "Không tìm thấy ngữ cảnh phù hợp để trả lời câu hỏi."

    # Only the top-ranked entry is passed to the generator as context.
    abs_answer = get_abstractive_answer(context=reranked[0], question=question)
    if isinstance(abs_answer, str):
        # get_abstractive_answer reports failures as a message string.
        return abs_answer
    return _stream_answer(abs_answer, placeholder)


if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    st.markdown(_user_bubble(prompt), unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # Single slot that is redrawn while streaming and then with the final answer.
    message_placeholder = st.empty()
    full_response = _answer_question(prompt, message_placeholder)

    message_placeholder.markdown(_assistant_bubble(full_response, own_line=True), unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})