File size: 5,304 Bytes
f912d04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import html
import time

import requests
import streamlit as st

st.set_page_config(page_title="ViBidLawQA - Trợ lý AI hỗ trợ hỏi đáp luật Việt Nam", page_icon="./app/static/ai.jpg", layout="centered", initial_sidebar_state="expanded")

# Inject the app's custom stylesheet. The encoding is given explicitly so that
# reading the file does not depend on the host's locale default (e.g. cp1252).
with open("./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Streamlit re-executes this script on every interaction; session_state keeps the
# chat history across reruns. Entries are dicts of the form
# {'role': 'user' | 'assistant', 'content': <text>}.
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Logo and title header. Plain (non-f) string: the HTML has no placeholders.
st.markdown("""

<div class=logo_area>

    <img src="./app/static/ai.jpg"/>

</div>

""", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>ViBidLQA Bot</h2>", unsafe_allow_html=True)

# Sidebar configuration: endpoints of the two QA services that the answer
# functions POST to, the answering mode, and the legal text used as context.
url_api_extraction_model = st.sidebar.text_input(label="URL API Extraction model:")
url_api_generation_model = st.sidebar.text_input(label="URL API Generation model:")

answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=500)

# NOTE(review): console-only log lines, printed on every rerun. Nothing is loaded
# here; answers come from the HTTP services whose URLs are configured above.
if answering_method == 'Generation':
    print('Switching to generative model...')
    print('Loading generative model...')
elif answering_method == 'Extraction':
    print('Switching to extraction model...')
    print('Loading extraction model...')

def get_abstractive_answer(context, question, timeout=120):
    """Ask the generative (abstractive) QA service to answer *question* from *context*.

    POSTs ``{"context": ..., "question": ...}`` as JSON to the URL entered in the
    sidebar field "URL API Generation model" and returns the ``"answer"`` field of
    the JSON response.

    Args:
        context: Legal text the answer should be based on.
        question: The user's question.
        timeout: Seconds to wait for the service before giving up. The default is
            generous because model inference may be slow.

    Returns:
        The service's answer on success. Otherwise, a Vietnamese error string
        starting with ``"Lỗi: "``, which the caller displays as the bot's reply.
    """
    if not url_api_generation_model:
        return "Lỗi: chưa nhập URL API Generation model."

    data = {
        "context": context,
        "question": question
    }

    try:
        response = requests.post(url_api_generation_model, json=data, timeout=timeout)
    except requests.RequestException as err:
        # Connection refused, DNS failure, timeout, malformed URL, ...: report it
        # in the chat instead of crashing the Streamlit script.
        return f"Lỗi: {err}"

    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"

    try:
        return response.json()["answer"]
    except (ValueError, KeyError, TypeError) as err:
        # Body is not JSON (ValueError), or is JSON without an "answer" field.
        return f"Lỗi: phản hồi không hợp lệ - {err!r}"

def generate_text_effect(answer, delay=0.03):
    """Yield progressively longer prefixes of *answer* to simulate typing.

    The answer is split on whitespace. For k = 1..number_of_words, the k-th
    yielded string is the first k words joined by single spaces. The generator
    sleeps *delay* seconds before each yield. Nothing is yielded for ``None``, an
    empty string, or a whitespace-only string.

    Args:
        answer: Text to reveal word by word (may be ``None``).
        delay: Pause in seconds before each yielded prefix.

    Yields:
        The growing prefix of the answer.
    """
    if not answer:
        return

    shown = ""
    for word in answer.split():
        # Extend the previous prefix instead of re-joining all words each step.
        shown = f"{shown} {word}" if shown else word
        time.sleep(delay)
        yield shown

def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512, timeout=120):
    """Ask the extractive QA service for the best answer to *question* in *context*.

    POSTs the question, the context and the decoding parameters as JSON to the URL
    entered in the sidebar field "URL API Extraction model" and returns the
    ``"best_answer"`` field of the JSON response.

    Args:
        context: Legal text to search for the answer.
        question: The user's question.
        stride, max_length, n_best, max_answer_length: Forwarded unchanged to the
            service. NOTE(review): presumably sliding-window / span-decoding
            settings of the extraction model; confirm against the API.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The service's best answer on success. Otherwise, a Vietnamese error string
        starting with ``"Lỗi: "``, which the caller displays as the bot's reply.
    """
    if not url_api_extraction_model:
        return "Lỗi: chưa nhập URL API Extraction model."

    data = {
        "context": context,
        "question": question,
        "stride": stride,
        "max_length": max_length,
        "n_best": n_best,
        "max_answer_length": max_answer_length,
    }

    try:
        response = requests.post(url_api_extraction_model, json=data, timeout=timeout)
    except requests.RequestException as err:
        # Connection refused, DNS failure, timeout, malformed URL, ...: report it
        # in the chat instead of crashing the Streamlit script.
        return f"Lỗi: {err}"

    if response.status_code != 200:
        return f"Lỗi: {response.status_code} - {response.text}"

    try:
        return response.json()["best_answer"]
    except (ValueError, KeyError, TypeError) as err:
        # Body is not JSON (ValueError), or is JSON without a "best_answer" field.
        return f"Lỗi: phản hồi không hợp lệ - {err!r}"

# Replay the whole conversation: Streamlit re-executes the script on every
# interaction, so every stored message is rendered again on each run.
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.jpg'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
    # The text is interpolated into raw HTML (unsafe_allow_html=True), so escape it.
    # User prompts and service error strings (which embed response.text and may
    # contain markup) must be displayed as text, not interpreted as HTML.
    safe_content = html.escape(message['content'])
    st.markdown(f"""

    <div class="{message_class}">

        <img src="{avatar}" class="{avatar_class}" />

        <div class="stMarkdown">{safe_content}</div>

    </div>

    """, unsafe_allow_html=True)

if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    # Echo the user's question right away. Every piece of text below is
    # HTML-escaped because it is interpolated into raw HTML rendered with
    # unsafe_allow_html=True. The raw text is what goes into the history.
    st.markdown(f"""

    <div class="user-message">

        <img src="./app/static/human.png" class="user-avatar" />

        <div class="stMarkdown">{html.escape(prompt)}</div>

    </div>

    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # A single placeholder is overwritten frame by frame: first the typing
    # indicator, then the streamed answer, then the final answer.
    message_placeholder = st.empty()

    # "Typing" indicator: cycle ●, ●●, ●●● twice (about 1.2 s) before the request is sent.
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            message_placeholder.markdown(f"""

            <div class="assistant-message">

                <img src="./app/static/ai.jpg" class="assistant-avatar" />

                <div class="stMarkdown">{dots}</div>

            </div>

            """, unsafe_allow_html=True)

    # Only the backend differs between the two modes; rendering is shared.
    if answering_method == 'Generation':
        answer = get_abstractive_answer(context=context, question=prompt)
    else:
        answer = get_extractive_answer(context=context, question=prompt)

    full_response = ""
    for partial in generate_text_effect(answer):
        full_response = partial
        # Trailing ● acts as a cursor while the answer is still being revealed.
        message_placeholder.markdown(f"""

        <div class="assistant-message">

            <img src="./app/static/ai.jpg" class="assistant-avatar" />

            <div class="stMarkdown">{html.escape(full_response)}●</div>

        </div>

        """, unsafe_allow_html=True)

    # Final frame without the cursor.
    message_placeholder.markdown(f"""

    <div class="assistant-message">

        <img src="./app/static/ai.jpg" class="assistant-avatar" />

            <div class="stMarkdown">

                {html.escape(full_response)}

            </div>

    </div>

    """, unsafe_allow_html=True)

    st.session_state.messages.append({'role': 'assistant', 'content': full_response})