# Spaces: Running  (status banner captured along with the source; not part of the program)
import html
import time

import requests
import streamlit as st
# Page-level configuration; Streamlit requires this before other st.* output.
# NOTE(review): page_icon uses the browser-relative static URL also used by the
# <img> tags below, while the CSS is read from ./static on disk — confirm
# Streamlit resolves this icon path as intended.
st.set_page_config(page_title="ViBidLawQA - Trợ lý AI hỗ trợ hỏi đáp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="expanded")

# Inject the app's custom CSS. Read with an explicit encoding so the file is
# decoded the same way regardless of the host's locale default.
with open("./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Chat history lives in session_state so it survives Streamlit reruns;
# create it only on the first run of a session.
if 'messages' not in st.session_state:
    st.session_state.messages = []
# Centered page header: the bot avatar image (referenced by its ./app/static/
# URL, presumably via Streamlit static file serving) followed by the app title.
_logo_html = """
<div class=logo_area>
<img src="./app/static/ai.jpg"/>
</div>
"""
_title_html = "<h2 style='text-align: center;'>ViBidLQA Bot</h2>"
for _header_snippet in (_logo_html, _title_html):
    st.markdown(_header_snippet, unsafe_allow_html=True)
# Sidebar controls: endpoints for the two remote QA back ends, the answering
# mode, and the legal text that questions are answered against.
url_api_extraction_model = st.sidebar.text_input(label="URL API Extraction model:")
url_api_generation_model = st.sidebar.text_input(label="URL API Generation model:")
answering_method = st.sidebar.selectbox(
    label='Chọn mô hình trả lời câu hỏi:',
    options=['Extraction', 'Generation'],
    index=0,
)
context = st.sidebar.text_area(
    label='Nội dung văn bản pháp luật Việt Nam:',
    placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...',
    height=500,
)

# Console trace of the selected mode. Streamlit re-executes the script on every
# interaction, so these lines print on each rerun. Nothing is actually loaded
# here: both back ends are HTTP APIs called later.
_mode_console_lines = {
    'Generation': ('Switching to generative model...', 'Loading generative model...'),
    'Extraction': ('Switching to extraction model...', 'Loading extraction model...'),
}
for _console_line in _mode_console_lines.get(answering_method, ()):
    print(_console_line)
def get_abstractive_answer(context, question, timeout=120):
    """Ask the generation (abstractive) QA API to answer *question* from *context*.

    POSTs ``{"context": ..., "question": ...}`` as JSON to the URL entered in
    the sidebar (``url_api_generation_model``) and returns the ``"answer"``
    field of the JSON reply.

    Args:
        context: Legal text the answer should be based on.
        question: The user's question.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled API cannot block the Streamlit session forever.

    Returns:
        The answer on success; otherwise a string starting with ``"Lỗi: "``
        describing the failure. Failures are returned, not raised, because the
        caller displays whatever this returns as the assistant's reply.
    """
    if not url_api_generation_model:
        return "Lỗi: Vui lòng nhập URL API Generation model."
    data = {
        "context": context,
        "question": question,
    }
    try:
        response = requests.post(url_api_generation_model, json=data, timeout=timeout)
    except requests.exceptions.RequestException as err:
        # Invalid URL, connection refused, timeout, ...
        return f"Lỗi: không gọi được API - {err}"
    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"
    try:
        return response.json()["answer"]
    except (ValueError, KeyError, TypeError) as err:
        # ValueError: body is not JSON; KeyError/TypeError: unexpected JSON shape.
        return f"Lỗi: phản hồi API không hợp lệ ({err!r})"
def generate_text_effect(answer, delay=0.03):
    """Yield growing prefixes of *answer*, one word at a time, for a typing effect.

    The answer is split on runs of whitespace, and the words are re-joined with
    single spaces, so newlines and repeated spaces are not preserved. An empty
    or all-whitespace answer yields nothing.

    Args:
        answer: Text to reveal.
        delay: Seconds to sleep before each yielded prefix. The default of 0.03
            is the previous hard-coded value; pass 0 to disable the pause.

    Yields:
        ``"w1"``, ``"w1 w2"``, ... up to the whole whitespace-normalised answer.
    """
    shown = ""
    for word in answer.split():
        time.sleep(delay)
        # split() never returns empty words, so `shown` is non-empty after the first.
        shown = f"{shown} {word}" if shown else word
        yield shown
def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512, timeout=120):
    """Ask the extraction QA API for the best answer span to *question* in *context*.

    POSTs the context, the question and the span-search settings as JSON to the
    URL entered in the sidebar (``url_api_extraction_model``) and returns the
    ``"best_answer"`` field of the JSON reply.

    Args:
        context: Legal text to extract the answer from.
        question: The user's question.
        stride, max_length, n_best, max_answer_length: Forwarded unchanged in
            the request body. Presumably these are the tokenizer window overlap,
            the window length, the number of candidate spans and the max span
            length; confirm against the API.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled API cannot block the Streamlit session forever.

    Returns:
        The best answer on success; otherwise a string starting with ``"Lỗi: "``
        describing the failure. Failures are returned, not raised, because the
        caller displays whatever this returns as the assistant's reply.
    """
    if not url_api_extraction_model:
        return "Lỗi: Vui lòng nhập URL API Extraction model."
    data = {
        "context": context,
        "question": question,
        "stride": stride,
        "max_length": max_length,
        "n_best": n_best,
        "max_answer_length": max_answer_length,
    }
    try:
        response = requests.post(url_api_extraction_model, json=data, timeout=timeout)
    except requests.exceptions.RequestException as err:
        # Invalid URL, connection refused, timeout, ...
        return f"Lỗi: không gọi được API - {err}"
    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"
    try:
        return response.json()["best_answer"]
    except (ValueError, KeyError, TypeError) as err:
        # ValueError: body is not JSON; KeyError/TypeError: unexpected JSON shape.
        return f"Lỗi: phản hồi API không hợp lệ ({err!r})"
def _render_bubble(role, content, target=None):
    """Render one chat message as an HTML bubble.

    Args:
        role: 'assistant' selects the AI avatar and assistant CSS classes;
            any other value is rendered as the user.
        content: Raw message text. It is HTML-escaped before being embedded,
            because it comes from the user's prompt or from the QA APIs
            (including raw API error bodies, which may themselves be HTML
            pages) and is rendered with unsafe_allow_html=True.
        target: Anything with a Streamlit ``markdown`` method, e.g. an
            ``st.empty()`` placeholder to update in place; defaults to ``st``.
    """
    if role == 'assistant':
        message_class, avatar, avatar_class = 'assistant-message', './app/static/ai.jpg', 'assistant-avatar'
    else:
        message_class, avatar, avatar_class = 'user-message', './app/static/human.png', 'user-avatar'
    # Escape first, then turn line breaks into <br>: a blank line inside the
    # text would otherwise end Markdown's raw-HTML block and break the bubble.
    body = html.escape(content).replace("\n", "<br>")
    if target is None:
        target = st
    target.markdown(
        f'<div class="{message_class}">'
        f'<img src="{avatar}" class="{avatar_class}" />'
        f'<div class="stMarkdown">{body}</div>'
        '</div>',
        unsafe_allow_html=True,
    )


# Replay the stored conversation; Streamlit redraws the page on every rerun.
for message in st.session_state.messages:
    _render_bubble(message['role'], message['content'])

if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    _render_bubble('user', prompt)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # A single placeholder is reused for the "typing" dots, the word-by-word
    # reveal and the final answer, so they replace each other in place.
    message_placeholder = st.empty()
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            _render_bubble('assistant', dots, message_placeholder)

    if answering_method == 'Generation':
        answer = get_abstractive_answer(context=context, question=prompt)
    else:
        answer = get_extractive_answer(context=context, question=prompt)

    full_response = ""
    for partial in generate_text_effect(answer):
        full_response = partial
        # Trailing dot acts as a cursor while the answer is being revealed.
        _render_bubble('assistant', full_response + "●", message_placeholder)

    _render_bubble('assistant', full_response, message_placeholder)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})