import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Cache the model and tokenizer so Streamlit does not reload them on every rerun
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    # float16 is only well supported on GPU; keep the default float32 on CPU
    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained("distilgpt2", torch_dtype=torch.float16).to("cuda")
    else:
        model = AutoModelForCausalLM.from_pretrained("distilgpt2")
    # GPT-2 has no padding token; reuse the end-of-sequence token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model

tokenizer, model = load_model()

df = pd.read_csv('anomalies.csv')
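# Assumed schema of anomalies.csv, inferred from the prompt below (not
# verified against the actual file):
#   ds     - observation timestamp (DateTime)
#   real   - expense amount
#   group  - expense group/category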

# Generate an answer from the model
def response(question):
    # Prompt (in Portuguese): "Considering the data ..., where 'ds' is a
    # DateTime, 'real' is the expense amount and 'group' is the expense
    # group. Question: ... Answer:"
    prompt = (
        f"Considerando os dados: {df.to_string(index=False)}, "
        f"onde 'ds' está em formato DateTime, 'real' é o valor da despesa "
        f"e 'group' é o grupo da despesa. "
        f"Pergunta: {question} Resposta:"
    )
    # Truncate instead of padding to max_length (a large dataframe can still
    # push the question past the 256-token limit); move tensors to the
    # model's device so GPU inference does not fail on CPU tensors.
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=256).to(model.device)

    generated_ids = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=50,        # cap only the new tokens to keep responses fast
        do_sample=True,           # temperature/top_p are ignored without sampling
        temperature=0.7,
        top_p=0.9,
        no_repeat_ngram_size=2,
        num_beams=3,              # beams for more reliable generation
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens and keep the first sentence
    new_tokens = generated_ids[0][inputs['input_ids'].shape[1]:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    final_response = generated_text.split(".")[0].strip() + "."

    return final_response
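# Illustrative call (hypothetical question; the output depends entirely on
# what distilgpt2 samples):
#   response("Qual grupo teve a maior despesa?")
# returns the first sentence the model generates after the prompt.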

# Streamlit interface
st.markdown("""
<div style='display: flex; align-items: center;'>
    <div style='width: 40px; height: 40px; background-color: green; border-radius: 50%; margin-right: 5px;'></div>
    <div style='width: 40px; height: 40px; background-color: red; border-radius: 50%; margin-right: 5px;'></div>
    <div style='width: 40px; height: 40px; background-color: yellow; border-radius: 50%; margin-right: 5px;'></div>
    <span style='font-size: 40px; font-weight: bold;'>Chatbot do Tesouro RS</span>
</div>
""", unsafe_allow_html=True)

# Conversation history
if 'history' not in st.session_state:
    st.session_state['history'] = []

# Input box for the user's question
user_question = st.text_input("Escreva sua questão aqui:", "")

if user_question and st.session_state.get('last_question') != user_question:
    # Guard against Streamlit reruns (e.g. button clicks) re-logging the
    # same question from the persistent text_input value
    st.session_state['last_question'] = user_question

    # Log the question with a person emoji
    st.session_state['history'].append(('👤', user_question))

    # Generate the answer and log it with a robot emoji
    bot_response = response(user_question)
    st.session_state['history'].append(('🤖', bot_response))

# Button to clear the history
if st.button("Limpar"):
    st.session_state['history'] = []

# Render the conversation history once; bot replies are right-aligned.
# Markdown bold (**) is not parsed inside raw HTML, so use <strong> there.
for sender, message in st.session_state['history']:
    if sender == '👤':
        st.markdown(f"**👤 {message}**")
    elif sender == '🤖':
        st.markdown(
            f"<div style='text-align: right'><strong>🤖 {message}</strong></div>",
            unsafe_allow_html=True,
        )