Spaces:
Running
Running
File size: 3,128 Bytes
c6dc269 ff6536f c6dc269 d0b6b88 c6dc269 10379c1 c6dc269 7abb0f2 d0b6b88 c6dc269 d0b6b88 c6dc269 d0b6b88 c6dc269 4ade980 86088c2 9b8124b 86088c2 4ade980 c6dc269 31a3643 c6dc269 31a3643 c6dc269 31a3643 c6dc269 31a3643 c6dc269 31a3643 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Pick the device first so the model dtype can match it.
# BUG FIX: float16 was requested unconditionally, but many fp16 kernels are
# not implemented on CPU, so generate() crashes ("... not implemented for
# 'Half'") on CPU-only hosts. Use fp16 only when CUDA is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2", torch_dtype=dtype)
model = model.to(device)

# GPT-2 ships without a padding token; reuse EOS so the tokenizer can pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Anomaly table injected into every prompt. Expected columns (per the prompt
# text below): 'ds' (DateTime), 'real' (expense value), 'group' (expense group).
df = pd.read_csv('anomalies.csv')
# Generate an answer for the user's question, conditioned on the anomaly table.
def response(question):
    """Answer `question` with distilgpt2, grounding the prompt on `df`.

    The whole table is serialized into the prompt (only viable for small
    tables; input is truncated to 256 tokens). Returns the first sentence
    found after the last "Resposta:" marker in the generated text.
    """
    prompt = f"Considerando os dados: {df.to_string(index=False)}, onde 'ds' está em formato DateTime, 'real' é o valor da despesa e 'group' é o grupo da despesa. Pergunta: {question}"
    inputs = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=256)

    # BUG FIX: tensors must live on the same device as the model; the model
    # may have been moved to CUDA at module load, while the tokenizer always
    # returns CPU tensors.
    device = next(model.parameters()).device
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=input_ids.shape[1] + 50,  # cap new tokens to keep latency low
        do_sample=True,  # BUG FIX: temperature/top_p are ignored without sampling
        temperature=0.7,
        top_p=0.9,
        no_repeat_ngram_size=2,
        num_beams=3,  # beams for more reliable generation
        pad_token_id=tokenizer.pad_token_id,  # silence the missing-pad warning
    )
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Keep only the first sentence after the last "Resposta:" marker.
    final_response = generated_text.split("Resposta:")[-1].split(".")[0] + "."
    return final_response
# --- Streamlit interface ---------------------------------------------------
st.markdown("""
<div style='display: flex; align-items: center;'>
    <div style='width: 40px; height: 40px; background-color: green; border-radius: 50%; margin-right: 5px;'></div>
    <div style='width: 40px; height: 40px; background-color: red; border-radius: 50%; margin-right: 5px;'></div>
    <div style='width: 40px; height: 40px; background-color: yellow; border-radius: 50%; margin-right: 5px;'></div>
    <span style='font-size: 40px; font-weight: bold;'>Chatbot do Tesouro RS</span>
</div>
""", unsafe_allow_html=True)

# Conversation history persists across Streamlit reruns via session_state.
if 'history' not in st.session_state:
    st.session_state['history'] = []

# Input box for the user's question.
user_question = st.text_input("Escreva sua questão aqui:", "")

if user_question:
    # BUG FIX: the original rendered each message inline here AND again in
    # the history loop below, so every exchange appeared twice on screen.
    # Messages are now only appended here and rendered once from history.
    # NOTE(review): text_input keeps its value across reruns, so the same
    # question may be re-appended on unrelated reruns — consider st.form
    # or clearing the input; confirm desired behavior.
    st.session_state['history'].append(('👤', user_question))
    bot_response = response(user_question)
    st.session_state['history'].append(('🤖', bot_response))

# Clear the history BEFORE rendering so the wipe takes effect on this run.
if st.button("Limpar"):
    st.session_state['history'] = []

# Render the conversation: user messages on the left, bot replies right-aligned.
for sender, message in st.session_state['history']:
    if sender == '👤':
        st.markdown(f"**👤 {message}**")
    elif sender == '🤖':
        st.markdown(f"<div style='text-align: right'>**🤖 {message}**</div>", unsafe_allow_html=True)
|