"""Streamlit chat UI for ViBidLQA, a Vietnamese bidding-law question-answering assistant.

Each question is first sent to the routing service's ``/query_classify`` endpoint.
Depending on the returned label, it is then answered by one of:

- the extractive or generative QA service (chosen in the sidebar), for
  ``BIDDING_RELATED`` questions;
- the ``/about_me`` endpoint, for ``ABOUT_CHATBOT`` questions;
- the ``/response_unrelated_question`` endpoint, for anything else.

Service base URLs are read from ``st.secrets``.
"""
import html
import json
import time

import requests
import streamlit as st

st.set_page_config(
    page_title="ViBidLQA - Trợ lý AI văn bản pháp luật Việt Nam",
    page_icon="./app/static/ai.jpg",
    layout="centered",
    initial_sidebar_state="expanded",
)

routing_response_module = st.secrets["ViBidLQA_Routing_Module"]
retrieval_module = st.secrets["ViBidLQA_Retrieval_Module"]
ext_QA_module = st.secrets["ViBidLQA_EQA_Module"]
abs_QA_module = st.secrets["ViBidLQA_AQA_Module"]

url_api_question_classify_model = f"{routing_response_module}/query_classify"
url_api_unrelated_question_response_model = f"{routing_response_module}/response_unrelated_question"
url_api_introduce_system_model = f"{routing_response_module}/about_me"
url_api_retrieval_model = f"{retrieval_module}/search"
url_api_extraction_model = f"{ext_QA_module}/answer"
url_api_generation_model = f"{abs_QA_module}/answer"

# (connect, read) timeout in seconds for every backend call. In `requests` the read
# timeout bounds the wait between received bytes, not the total transfer time, so
# long token streams are not cut off while tokens keep arriving.
REQUEST_TIMEOUT = (10, 120)

# Appended to a partially rendered answer while more text is still arriving.
STREAMING_CURSOR = "●"

# Avatar image and CSS class names per chat role.
# NOTE(review): nothing in this file references these any more; they look like the
# inputs of an HTML message wrapper that is missing from this copy. Confirm before deleting.
_ROLE_STYLES = {
    'assistant': {'avatar': './app/static/ai.jpg', 'avatar_class': "assistant-avatar", 'message_class': "assistant-message"},
    'user': {'avatar': './app/static/human.png', 'avatar_class': "user-avatar", 'message_class': "user-message"},
}

# Inject the app stylesheet. The file was previously opened but never read.
with open("./static/styles.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'assistant', 'content': "Xin chào. Tôi là trợ lý AI văn bản luật Đấu thầu Việt Nam được phát triển bởi Nguyễn Trường Phúc và các cộng sự. Rất vui khi được hỗ trợ bạn trong các vấn đề pháp lý tại Việt Nam!"}]

# NOTE(review): this renders only a newline; its markup appears to be missing. Confirm.
st.markdown("\n", unsafe_allow_html=True)
st.markdown("\n\nViBidLQA\n\n", unsafe_allow_html=True)

answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
if answering_method == 'Generation':
    print('Switched to generative model...')
if answering_method == 'Extraction':
    print('Switched to extraction model...')


def _post(url, payload, stream=False):
    """POST ``payload`` as JSON to ``url`` with a timeout.

    Returns the ``requests.Response`` on HTTP 200. Any other status, or a network
    failure, is returned as a ``"Lỗi: ..."`` string, the error convention used
    throughout this app.
    """
    try:
        response = requests.post(url, json=payload, stream=stream, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as exc:
        return f"Lỗi: {exc}"
    if response.status_code == 200:
        return response
    return f"Lỗi: {response.status_code} - {response.text}"


def classify_question(question):
    """Ask the routing service to classify ``question``.

    Returns the raw ``Response`` (its JSON body is the label) or a ``"Lỗi: ..."`` string.
    """
    response = _post(url_api_question_classify_model, {"question": question})
    if not isinstance(response, str):
        print(response)
    return response


def introduce_system(question):
    """Request a streamed self-introduction; returns the streaming ``Response`` or an error string."""
    return _post(url_api_introduce_system_model, {"question": question}, stream=True)


def response_unrelated_question(question):
    """Request a streamed reply to an off-topic question; returns the streaming ``Response`` or an error string."""
    return _post(url_api_unrelated_question_response_model, {"question": question}, stream=True)


def _search_context(question, top_k=10):
    """Query the retrieval service.

    Returns ``(context, None)`` with the top hit's ``text`` on success, or
    ``(None, error_message)`` on failure. Callers can therefore tell a real
    context apart from an error string.
    """
    response = _post(url_api_retrieval_model, {"query": question, "top_k": top_k})
    if isinstance(response, str):
        return None, response
    try:
        results = response.json()["results"]
        if not results:
            return None, "Lỗi: không tìm thấy ngữ cảnh liên quan."
        context = results[0]["text"]
    except (ValueError, KeyError, TypeError) as exc:
        return None, f"Lỗi: phản hồi không hợp lệ từ dịch vụ truy xuất ({exc!r})"
    print(f"Retrieved bidding legal context: {context}")
    return context, None


def retrieve_context(question, top_k=10):
    """Return the top retrieved context text for ``question``, or a ``"Lỗi: ..."`` string on failure."""
    context, error = _search_context(question, top_k=top_k)
    return error if context is None else context


def get_abstractive_answer(question):
    """Retrieve context, then request a streamed generative answer.

    Returns the streaming ``Response``, or an error string if retrieval or the
    request fails. A retrieval error is never forwarded to the model as context.
    """
    context, error = _search_context(question)
    if context is None:
        return error
    return _post(url_api_generation_model, {"context": context, "question": question}, stream=True)


def generate_text_effect(answer, delay=0.03):
    """Yield growing prefixes of ``answer``, one word longer each time, to simulate typing.

    Words are split on whitespace and re-joined with single spaces. The generator
    sleeps ``delay`` seconds before each yield.
    """
    shown = ""
    for word in answer.split():
        time.sleep(delay)
        shown = f"{shown} {word}" if shown else word
        yield shown


def get_extractive_answer(question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    """Retrieve context, then return the extractive model's ``best_answer`` text.

    Any failure (retrieval, HTTP, or a malformed response) is returned as a
    ``"Lỗi: ..."`` string.
    """
    context, error = _search_context(question)
    if context is None:
        return error
    data = {
        "context": context,
        "question": question,
        "stride": stride,
        "max_length": max_length,
        "n_best": n_best,
        "max_answer_length": max_answer_length,
    }
    response = _post(url_api_extraction_model, data)
    if isinstance(response, str):
        return response
    try:
        return response.json()["best_answer"]
    except (ValueError, KeyError, TypeError) as exc:
        return f"Lỗi: phản hồi không hợp lệ từ mô hình trích xuất ({exc!r})"


def _render(target, text, streaming=False):
    """Render ``text`` with ``target.markdown`` (``target`` is ``st`` or an ``st.empty()`` slot).

    When ``streaming`` is true, the typing cursor is appended.
    """
    cursor = STREAMING_CURSOR if streaming else ""
    target.markdown(f"\n{text}{cursor}\n", unsafe_allow_html=True)


def _stream_to_placeholder(response, placeholder):
    """Render a ``data: {...}`` line stream token by token into ``placeholder``; return the full text.

    Each event line is ``data: <json>`` carrying a ``token`` field, and
    ``data: [DONE]`` ends the stream. Non-JSON and non-object events are skipped.
    The response is closed on exit, including an early ``break``.
    """
    text = ""
    with response:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode('utf-8')
            if not line.startswith('data: '):
                continue
            payload = line[len('data: '):]
            if payload == '[DONE]':
                break
            try:
                event = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if isinstance(event, dict):
                text += event.get('token', '')
            _render(placeholder, text, streaming=True)
    return text


def _show_service_answer(answer, placeholder):
    """Display a service result and return the displayed text.

    An error string is shown as-is; a streaming ``Response`` is rendered token by token.
    """
    if isinstance(answer, str):
        _render(placeholder, answer)
        return answer
    return _stream_to_placeholder(answer, placeholder)


def _type_out(text, placeholder):
    """Show ``text`` with the word-by-word typing effect and return the final shown text."""
    shown = ""
    for shown in generate_text_effect(text):
        _render(placeholder, shown, streaming=True)
    return shown


def _answer(prompt, method, placeholder):
    """Classify ``prompt``, answer it with the matching service, and return the answer text."""
    classification = classify_question(question=prompt)
    if isinstance(classification, str):
        # The classifier call failed. Show its error instead of crashing on .json().
        _render(placeholder, classification)
        return classification
    try:
        label = classification.json()
    except ValueError as exc:
        error = f"Lỗi: phản hồi không hợp lệ từ bộ phân loại ({exc!r})"
        _render(placeholder, error)
        return error
    print(f"The type of user query: {label}")

    if label == "BIDDING_RELATED":
        if method == 'Generation':
            return _show_service_answer(get_abstractive_answer(question=prompt), placeholder)
        return _type_out(get_extractive_answer(question=prompt), placeholder)
    if label == "ABOUT_CHATBOT":
        return _show_service_answer(introduce_system(question=prompt), placeholder)
    return _show_service_answer(response_unrelated_question(question=prompt), placeholder)


# Replay the chat history. User-typed text is HTML-escaped because it is
# rendered with unsafe_allow_html=True.
for message in st.session_state.messages:
    content = message['content']
    if message['role'] != 'assistant':
        content = html.escape(content, quote=False)
    _render(st, content)

if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    _render(st, html.escape(prompt, quote=False))
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    message_placeholder = st.empty()
    full_response = _answer(prompt, answering_method, message_placeholder)

    # Final render without the typing cursor.
    _render(message_placeholder, full_response)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})