"""ViEduChat: Streamlit chat front-end for a Vietnamese education QA pipeline.

Each user prompt goes through three remote services in turn -- retrieval,
reranking and answer generation -- and the generated answer is streamed
token by token into the chat window.
"""
import html
import json
import os
import time
from urllib.parse import urlencode, urlparse, parse_qs

import requests
import streamlit as st

st.set_page_config(
    page_title="ViEduChat - Trợ lý AI giáo dục Việt Nam",
    page_icon="./app/static/ai.jpg",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# ==== MODULE URL ====
routing_response_module = st.secrets["ViEduQA_Routing_Module"]
retrieval_module = st.secrets["ViEduQA_Retrieval_Module"]
reranker_module = st.secrets["ViEduQA_Rerank_Module"]
abs_QA_module = st.secrets["ViEduQA_QA_Module"]

url_api_question_classify_model = f"{routing_response_module}/query_classify"
url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
url_api_introduce_system_model = f"{routing_response_module}/about_me"
url_api_retrieval_model = f"{retrieval_module}/search"
url_api_reranker_model = f"{reranker_module}/rerank"
url_api_generation_model = f"{abs_QA_module}/answer"
url_api_extract_reference_model = f"{routing_response_module}/extract_references_unstream"

# (connect, read) timeouts in seconds. Without a timeout a stalled service
# blocks the Streamlit script run indefinitely.
REQUEST_TIMEOUT = (10, 120)

# Shown when the reranker returns no documents to answer from.
NO_CONTEXT_MESSAGE = "Xin lỗi, tôi không tìm thấy tài liệu liên quan đến câu hỏi của bạn."

# ========== STREAMLIT UI ==========
with open("./static/styles.css", encoding="utf-8") as f:
    # The stylesheet has to be read and wrapped in a <style> tag to take
    # effect; rendering an empty string left the opened file unused.
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

if 'messages' not in st.session_state:
    st.session_state.messages = [{
        'role': 'assistant',
        'content': "Xin chào. Tôi là trợ lý AI giáo dục Việt Nam được phát triển bởi Đào Thị Ngọc Ánh. Rất vui khi được hỗ trợ bạn trong học tập!",
    }]

# NOTE(review): this call renders only a newline; any markup it was meant to
# carry is not present in this file -- confirm whether it can be removed.
st.markdown("\n", unsafe_allow_html=True)
st.markdown("\n\nViEduChat\n\n", unsafe_allow_html=True)


def _post(url, payload, *, stream=False, error_prefix="Lỗi"):
    """POST *payload* as JSON and return the response, or an error string.

    Returns the ``requests.Response`` when the service answers HTTP 200.
    For any other status code, or when the request itself fails (connection
    error, timeout), returns a string starting with *error_prefix* so callers
    can tell failure from success with ``isinstance(result, str)``.
    """
    try:
        response = requests.post(url, json=payload, stream=stream, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as exc:
        return f"{error_prefix}: {exc}"
    if response.status_code == 200:
        return response
    return f"{error_prefix}: {response.status_code} - {response.text}"


def classify_question(question):
    """Ask the routing module to classify *question*.

    Returns the raw response on success, or an error string.
    """
    return _post(url_api_question_classify_model, {"question": question})


def introduce_system(question):
    """Request a streamed self-introduction answer from the routing module.

    Returns the streaming response on success, or an error string.
    """
    return _post(url_api_introduce_system_model, {"question": question}, stream=True)


def response_unrelated_question(question):
    """Request a streamed reply for a question outside the education domain.

    Returns the streaming response on success, or an error string.
    """
    return _post(url_api_unrelated_question_response_model, {"question": question}, stream=True)


def retrieve_context(question, top_k=10):
    """Fetch up to *top_k* candidate documents for *question*.

    Returns the ``"results"`` value of the service's JSON reply on success,
    or an error string.
    """
    response = _post(
        url_api_retrieval_model,
        {"query": question, "top_k": top_k},
        error_prefix="Lỗi tại Retrieval Module",
    )
    if isinstance(response, str):
        return response
    return response.json()["results"]


def rerank_context(url_rerank_module, question, relevant_docs, top_k=5):
    """Rerank *relevant_docs* against *question* and keep the best *top_k*.

    Returns the ``"reranked_docs"`` value of the service's JSON reply on
    success, or an error string.
    """
    response = _post(
        url_rerank_module,
        {"question": question, "relevant_docs": relevant_docs, "top_k": top_k},
        error_prefix="Lỗi tại Rerank module",
    )
    if isinstance(response, str):
        return response
    return response.json()["reranked_docs"]


def get_abstractive_answer(context, question):
    """Request a streamed answer to *question* grounded in *context*.

    Returns the streaming response on success, or an error string.
    """
    return _post(url_api_generation_model, {"context": context, "question": question}, stream=True)


# Reference extraction is currently disabled.
# def get_references(context, question, answer):
#     data = {
#         "context": context,
#         "question": question,
#         "answer": answer
#     }
#     response = requests.post(url_api_extract_reference_model, json=data)
#     if response.status_code == 200:
#         return response.json()["refs"]
#     else:
#         return f"Lỗi tại module Reference Extractor: {response.status_code} - {response.text}"


def generate_text_effect(answer):
    """Yield *answer* one word longer at a time to imitate typing."""
    words = answer.split()
    for end in range(1, len(words) + 1):
        time.sleep(0.03)
        yield " ".join(words[:end])


def render_message(target, text, cursor=""):
    """Render *text* through ``target.markdown``.

    *target* is either the ``st`` module or a placeholder from ``st.empty()``.
    *cursor* is appended to the text while an answer is still streaming.
    """
    target.markdown(f"\n{text}{cursor}\n", unsafe_allow_html=True)


def stream_answer(response, placeholder):
    """Consume a ``data: {...}`` event stream and render it as it grows.

    Each ``data:`` line is parsed as JSON and its ``token`` value appended to
    the answer; ``data: [DONE]`` ends the stream and lines that are not valid
    JSON are skipped. Returns the accumulated text. If the connection fails
    before any token arrives, an error string is returned instead.
    """
    text = ""
    # The context manager closes the connection even when the loop exits
    # early on [DONE].
    with response:
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode('utf-8')
                if not line.startswith('data: '):
                    continue
                payload = line[6:]
                if payload == '[DONE]':
                    break
                try:
                    token = json.loads(payload).get('token', '')
                except json.JSONDecodeError:
                    continue
                text += token
                render_message(placeholder, text, cursor="●")
        except requests.RequestException as exc:
            # Keep whatever was already streamed; report only a total failure.
            if not text:
                text = f"Lỗi: {exc}"
    return text


def answer_education_question(question, placeholder):
    """Run retrieval, reranking and generation for *question*.

    Returns the final answer text. When any stage fails, its error string is
    returned instead, so a failure is never passed on as if it were data.
    """
    retrieved = retrieve_context(question=question, top_k=10)
    if isinstance(retrieved, str):
        return retrieved
    passages = [item['text'] for item in retrieved]

    reranked = rerank_context(
        url_rerank_module=url_api_reranker_model,
        question=question,
        relevant_docs=passages,
        top_k=5,
    )
    if isinstance(reranked, str):
        return reranked
    if not reranked:
        return NO_CONTEXT_MESSAGE

    # Only the top-ranked document is passed to the generator as context.
    answer = get_abstractive_answer(context=reranked[0], question=question)
    if isinstance(answer, str):
        return answer
    return stream_answer(answer, placeholder)


for message in st.session_state.messages:
    # NOTE(review): avatar_class, message_class and avatar are not referenced
    # by the template below; presumably the HTML wrapper that used them is
    # missing from this file -- confirm and restore or remove.
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.jpg'
        content = message['content']
    else:
        avatar_class = ""
        message_class = "user-message"
        avatar = ''
        # User text is untrusted and rendered with unsafe_allow_html=True.
        content = html.escape(message['content'], quote=False)
    render_message(st, content)

if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    render_message(st, html.escape(prompt, quote=False))
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    message_placeholder = st.empty()

    # Routing by query type is currently disabled: every prompt goes through
    # the education pipeline. The disabled flow was:
    # classify_result = classify_question(question=prompt).json()
    # print(f"The type of user query: {classify_result}")
    # if classify_result == "EDUCATION_RELATED":
    #     full_response = answer_education_question(prompt, message_placeholder)
    # elif classify_result == "ABOUT_CHATBOT":
    #     answer = introduce_system(question=prompt)
    #     full_response = answer if isinstance(answer, str) else stream_answer(answer, message_placeholder)
    # else:
    #     answer = response_unrelated_question(question=prompt)
    #     full_response = answer if isinstance(answer, str) else stream_answer(answer, message_placeholder)
    full_response = answer_education_question(prompt, message_placeholder)

    # Reference display is currently disabled.
    # refs = st.expander("Tài liệu tham khảo", expanded=False)
    # refs_list = get_references(context=reranked_context, question=prompt, answer=full_response)
    # print(refs_list)
    # refs.write(f"{refs_list}")

    # Final render without the streaming cursor.
    render_message(message_placeholder, full_response)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})