import gc
import time

import torch
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering

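# ViBidLawQA: Streamlit chat app for Vietnamese legal question answering.
# The user pastes a legal document as context in the sidebar and chooses either an
# extractive or a generative (abstractive) model to answer questions about it.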
st.set_page_config(page_title="ViBidLawQA - Hệ thống hỏi đáp trực tuyến luật Việt Nam",
                   page_icon="./app/static/ai.png",
                   layout="centered",
                   initial_sidebar_state="expanded")

with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Chat history, persisted across Streamlit reruns.
if 'messages' not in st.session_state:
    st.session_state.messages = []

st.markdown(f"""
|
|
<div class=logo_area>
|
|
<img src="./app/static/ai.png"/>
|
|
</div>
|
|
""", unsafe_allow_html=True)
|
|
st.markdown("<h2 style='text-align: center;'>ViBidLawQA_v2</h2>", unsafe_allow_html=True)
|
|
|
|
# Sidebar controls: pick the answering model and paste the Vietnamese legal text to query.
answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=500)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

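# Only one model is kept in memory at a time: switching methods frees the other
# model (and the CUDA cache, when available) before loading weights from ./models.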
if answering_method == 'Generation' and 'aqa_model' not in st.session_state:
    if 'eqa_model' in st.session_state and 'eqa_tokenizer' in st.session_state:
        del st.session_state.eqa_model
        del st.session_state.eqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to generative model...')
    print('Loading generative model...')
    st.session_state.aqa_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path='./models/AQA_model').to(device)
    st.session_state.aqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/AQA_model')

if answering_method == 'Extraction' and 'eqa_model' not in st.session_state:
    if 'aqa_model' in st.session_state and 'aqa_tokenizer' in st.session_state:
        del st.session_state.aqa_model
        del st.session_state.aqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to extraction model...')
    print('Loading extraction model...')
    st.session_state.eqa_model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path='./models/EQA_model').to(device)
    st.session_state.eqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/EQA_model')


def get_abstractive_answer(context, question, max_length=1024, max_target_length=512):
    """Generate a free-form answer to the question from the given context with the seq2seq model."""
    inputs = st.session_state.aqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            padding='max_length',
                                            return_tensors='pt')
    outputs = st.session_state.aqa_model.generate(inputs=inputs['input_ids'].to(device),
                                                  attention_mask=inputs['attention_mask'].to(device),
                                                  max_length=max_target_length)
    answer = st.session_state.aqa_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    if not answer.endswith('.'):
        answer += '.'

    return answer

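# Example call (illustrative values only):
#   get_abstractive_answer(context='Điều 1. Phạm vi điều chỉnh ...', question='Luật này quy định về vấn đề gì?')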

def generate_text_effect(answer):
    """Yield the answer one word at a time to simulate a typing effect."""
    words = answer.split()
    for i in range(len(words)):
        time.sleep(0.05)
        yield " ".join(words[:i + 1])


def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    """Extract the best answer span from the context with the extractive QA model."""
    inputs = st.session_state.eqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            stride=stride,
                                            return_overflowing_tokens=True,
                                            return_offsets_mapping=True,
                                            padding='max_length')

    # Keep character offsets only for context tokens so answers are never taken from the question.
    for i in range(len(inputs['input_ids'])):
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs['offset_mapping'][i]
        inputs['offset_mapping'][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    input_ids = torch.tensor(inputs["input_ids"]).to(device)
    attention_mask = torch.tensor(inputs["attention_mask"]).to(device)

    with torch.no_grad():
        outputs = st.session_state.eqa_model(input_ids=input_ids, attention_mask=attention_mask)

    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()

    # Score the n_best start/end candidates in every overflowing chunk and keep valid spans.
    answers = []
    for i in range(len(inputs["input_ids"])):
        start_logit = start_logits[i]
        end_logit = end_logits[i]
        offsets = inputs["offset_mapping"][i]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip spans that fall outside the context, are reversed, or are too long.
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                    continue

                answer = {
                    "text": context[offsets[start_index][0] : offsets[end_index][1]],
                    "logit_score": start_logit[start_index] + end_logit[end_index],
                }
                answers.append(answer)

    # Return the highest-scoring span, or an empty string if no valid span was found.
    if len(answers) > 0:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        return best_answer["text"]
    else:
        return ""

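# Everything below drives the chat UI: replay the saved history, read new input,
# show a typing animation, then stream the selected model's answer.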
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
    st.markdown(f"""
        <div class="{message_class}">
            <img src="{avatar}" class="{avatar_class}" />
            <div class="stMarkdown">{message['content']}</div>
        </div>
        """, unsafe_allow_html=True)

if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    st.markdown(f"""
        <div class="user-message">
            <img src="./app/static/human.png" class="user-avatar" />
            <div class="stMarkdown">{prompt}</div>
        </div>
        """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    message_placeholder = st.empty()

    # Brief "typing" animation while the answer is prepared.
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            message_placeholder.markdown(f"""
                <div class="assistant-message">
                    <img src="./app/static/ai.png" class="assistant-avatar" />
                    <div class="stMarkdown">{dots}</div>
                </div>
                """, unsafe_allow_html=True)

    # Answer with the selected model and stream it word by word with a trailing cursor.
    full_response = ""
    if answering_method == 'Generation':
        abs_answer = get_abstractive_answer(context=context, question=prompt)
        for word in generate_text_effect(abs_answer):
            full_response = word

            message_placeholder.markdown(f"""
                <div class="assistant-message">
                    <img src="./app/static/ai.png" class="assistant-avatar" />
                    <div class="stMarkdown">{full_response}●</div>
                </div>
                """, unsafe_allow_html=True)

    else:
        ext_answer = get_extractive_answer(context=context, question=prompt)
        for word in generate_text_effect(ext_answer):
            full_response = word

            message_placeholder.markdown(f"""
                <div class="assistant-message">
                    <img src="./app/static/ai.png" class="assistant-avatar" />
                    <div class="stMarkdown">{full_response}●</div>
                </div>
                """, unsafe_allow_html=True)

    # Replace the streaming cursor with the final answer and store it in the history.
    message_placeholder.markdown(f"""
        <div class="assistant-message">
            <img src="./app/static/ai.png" class="assistant-avatar" />
            <div class="stMarkdown">{full_response}</div>
        </div>
        """, unsafe_allow_html=True)

    st.session_state.messages.append({'role': 'assistant', 'content': full_response})