File size: 3,171 Bytes
c6dc269
 
6c96c7d
34e67a1
c6a4668
 
acf8351
10379c1
14c4d74
 
9bbf8c6
 
 
c6dc269
9bbf8c6
acf8351
6c96c7d
6d02d8f
c6a4668
 
6c96c7d
cc0f753
 
41debeb
b13ea5d
 
 
 
 
 
 
 
 
 
41debeb
 
b13ea5d
 
acf8351
 
b13ea5d
acf8351
 
 
 
 
3b9304d
c6dc269
23171a7
4ade980
 
86088c2
 
9b8124b
86088c2
4ade980
 
c6dc269
23171a7
c6dc269
 
 
23171a7
c6dc269
 
 
23171a7
c6dc269
31a3643
c6dc269
23171a7
0850865
c6dc269
23171a7
c6dc269
31a3643
c6dc269
23171a7
c6dc269
 
 
23171a7
c6dc269
 
31a3643
c6dc269
31a3643
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import datetime
import html

import pandas as pd
import streamlit as st
import torch
from transformers import pipeline
#from transformers import TapasTokenizer, TapexTokenizer, BartForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering

# Load the spreadsheet, replacing blank cells with 0, then cast every value to
# str because the TAPAS tokenizer requires a table of text cells.
df = pd.read_excel('discrepantes.xlsx').fillna(0)
table_data = df.astype(str)
print(table_data.head())

@st.cache_resource
def _load_tapas(model_name):
    """Return ``(tokenizer, model)`` for *model_name*, cached across reruns.

    Streamlit re-executes this script on every interaction, so without the
    cache the TAPAS weights would be reloaded for every single question.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTableQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model


def response(user_question, table_data):
    """Answer *user_question* about *table_data* using a TAPAS model.

    Args:
        user_question: Natural-language question about the table.
        table_data: ``pandas.DataFrame`` whose cells are all strings
            (the TAPAS tokenizer requires text cells).

    Returns:
        A dict ``{"Resposta": text}``. ``text`` is the selected cell values
        joined with ", ", prefixed with ``"<AGGREGATOR> > "`` when the model
        predicts an aggregation other than ``NONE``. When no cell is
        selected, ``text`` is a fallback message instead.
    """
    a = datetime.datetime.now()

    model_name = "google/tapas-base-finetuned-wtq"
    tokenizer, model = _load_tapas(model_name)

    # The query should be passed as a list
    encoding = tokenizer(table=table_data, queries=[user_question], padding=True, return_tensors="pt", truncation=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**encoding)

    # convert_logits_to_predictions expects the raw per-token cell-selection
    # logits (plus the aggregation logits when the model has an aggregation
    # head). Passing an argmax index instead yields meaningless coordinates.
    logits_agg = getattr(outputs, "logits_aggregation", None)
    predictions = tokenizer.convert_logits_to_predictions(
        encoding,
        outputs.logits.detach(),
        None if logits_agg is None else logits_agg.detach(),
    )
    # predictions[0] holds one list of (row, column) coordinates per query;
    # there is exactly one query here.
    coordinates = predictions[0][0]

    # predictions[1] (present only when aggregation logits were given) holds
    # one aggregation-label index per query.
    aggregator = "NONE"
    if len(predictions) > 1:
        labels = getattr(model.config, "aggregation_labels", None) or {}
        aggregator = labels.get(int(predictions[1][0]), "NONE")

    if coordinates:
        # Map coordinates back to the actual cell text.
        answer_text = ", ".join(str(table_data.iat[row, column]) for row, column in coordinates)
        if aggregator != "NONE":
            answer_text = f"{aggregator} > {answer_text}"
    else:
        answer_text = "Não foi possível encontrar uma resposta"

    query_result = {
        "Resposta": answer_text
    }

    b = datetime.datetime.now()
    print(b - a)

    return query_result

# Streamlit interface: a title banner with three coloured dots (green, red,
# yellow) followed by the application name, rendered as raw HTML.
_BANNER_DOT = (
    "    <div style='width: 40px; height: 40px; background-color: {color}; "
    "border-radius: 50%; margin-right: 5px;'></div>\n"
)
_banner_html = (
    "\n<div style='display: flex; align-items: center;'>\n"
    + "".join(_BANNER_DOT.format(color=dot_color) for dot_color in ("green", "red", "yellow"))
    + "    <span style='font-size: 40px; font-weight: bold;'>Chatbot do Tesouro RS</span>\n"
    + "</div>\n"
)
st.markdown(_banner_html, unsafe_allow_html=True)

# Chat history, kept in session state so it survives Streamlit reruns.
# Each entry is a (sender_emoji, text) pair.
if 'history' not in st.session_state:
    st.session_state['history'] = []


def _render_message(sender, message):
    """Render one chat entry: user lines bold, bot lines bold and right-aligned.

    The bot text is HTML-escaped because it is interpolated into raw HTML
    (``unsafe_allow_html=True``) and originates from spreadsheet cells.
    Markdown ``**bold**`` is not parsed inside a raw HTML block, so the bot
    line uses ``<strong>`` instead.
    """
    if sender == '👤':
        st.markdown(f"**👤 {message}**")
    elif sender == '🤖':
        safe_text = html.escape(str(message))
        st.markdown(
            f"<div style='text-align: right'><strong>🤖 {safe_text}</strong></div>",
            unsafe_allow_html=True,
        )


# Input box for user question
user_question = st.text_input("Escreva sua questão aqui:", "")

# Streamlit reruns the whole script on every widget interaction and the text
# box keeps its value. Only answer a question that was not just answered;
# otherwise clicking "Limpar" re-runs the model and re-appends the exchange.
if user_question and user_question != st.session_state.get('last_question'):
    st.session_state['last_question'] = user_question
    st.session_state['history'].append(('👤', user_question))
    bot_response = response(user_question, table_data)
    # Store only the answer text, not the dict, so the dict repr is never shown.
    st.session_state['history'].append(('🤖', bot_response['Resposta']))

# Clear history button
if st.button("Limpar"):
    st.session_state['history'] = []

# Display chat history. The latest exchange is part of it, so it is rendered
# exactly once.
for sender, message in st.session_state['history']:
    _render_message(sender, message)