fschwartzer committed on
Commit
7212b4f
·
verified ·
1 Parent(s): c898242

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -22
app.py CHANGED
@@ -26,17 +26,23 @@ html_content = f"""
26
  # Aplicar o markdown combinado no Streamlit
27
  st.markdown(html_content, unsafe_allow_html=True)
28
 
29
- # Inicialização de variáveis de estado
30
- if 'all_anomalies' not in st.session_state:
31
- st.session_state['all_anomalies'] = pd.DataFrame()
32
- if 'history' not in st.session_state:
33
- st.session_state['history'] = []
34
-
35
- # Carregar os modelos de tradução e TAPEX
36
- pt_en_translator = T5ForConditionalGeneration.from_pretrained("unicamp-dl/translation-pt-en-t5")
37
- en_pt_translator = T5ForConditionalGeneration.from_pretrained("unicamp-dl/translation-en-pt-t5")
38
- tapex_model = BartForConditionalGeneration.from_pretrained("microsoft/tapex-large-finetuned-wtq")
39
- tapex_tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
 
 
 
 
 
 
40
  tokenizer = T5Tokenizer.from_pretrained("unicamp-dl/translation-pt-en-t5")
41
 
42
  def translate(text, model, tokenizer, source_lang="pt", target_lang="en"):
@@ -45,6 +51,7 @@ def translate(text, model, tokenizer, source_lang="pt", target_lang="en"):
45
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
46
  return translated_text
47
 
 
48
  def response(user_question, table_data):
49
  question_en = translate(user_question, pt_en_translator, tokenizer, source_lang="pt", target_lang="en")
50
  encoding = tapex_tokenizer(table=table_data, query=[question_en], padding=True, return_tensors="pt", truncation=True)
@@ -53,6 +60,7 @@ def response(user_question, table_data):
53
  response_pt = translate(response_en, en_pt_translator, tokenizer, source_lang="en", target_lang="pt")
54
  return response_pt
55
 
 
56
  def load_data(uploaded_file):
57
  if uploaded_file.name.endswith('.csv'):
58
  df = pd.read_csv(uploaded_file, quotechar='"', encoding='utf-8')
@@ -95,6 +103,8 @@ def preprocess_data(df):
95
  df_clean = new_df.copy()
96
  return df_clean
97
 
 
 
98
  def apply_prophet(df_clean):
99
  if df_clean.empty:
100
  st.error("DataFrame está vazio após o pré-processamento.")
@@ -156,6 +166,13 @@ def apply_prophet(df_clean):
156
  # Return the dataframe of all anomalies
157
  return all_anomalies
158
 
 
 
 
 
 
 
 
159
  tab1, tab2 = st.tabs(["Meta Prophet", "Microsoft TAPEX"])
160
 
161
  # Interface para carregar arquivo
@@ -169,31 +186,24 @@ with tab1:
169
  if df_clean.empty:
170
  st.warning("Não há dados válidos para processar.")
171
  else:
172
- # Check if 'all_anomalies' is already in session state to avoid re-running Prophet
173
- if 'all_anomalies' not in st.session_state:
174
  with st.spinner('Aplicando modelo de série temporal...'):
175
  all_anomalies = apply_prophet(df_clean)
176
  st.session_state['all_anomalies'] = all_anomalies
177
 
178
  with tab2:
179
- # Ensure 'all_anomalies' exists in session state before allowing user interaction
180
  if 'all_anomalies' in st.session_state and not st.session_state['all_anomalies'].empty:
181
- # Interface para perguntas do usuário
182
  user_question = st.text_input("Escreva sua questão aqui:", "")
183
  if user_question:
184
  bot_response = response(user_question, st.session_state['all_anomalies'])
185
  st.session_state['history'].append(('👤', user_question))
186
  st.session_state['history'].append(('🤖', bot_response))
187
 
188
- # Mostrar histórico de conversa
189
  for sender, message in st.session_state['history']:
190
- if sender == '👤':
191
- st.markdown(f"**👤 {message}**")
192
- elif sender == '🤖':
193
- st.markdown(f"**🤖 {message}**", unsafe_allow_html=True)
194
 
195
- # Botão para limpar histórico
196
  if st.button("Limpar histórico"):
197
  st.session_state['history'] = []
198
  else:
199
- st.warning("Por favor, processe os dados no Meta Prophet primeiro.")
 
26
  # Aplicar o markdown combinado no Streamlit
27
  st.markdown(html_content, unsafe_allow_html=True)
28
 
29
+ # Cache models to prevent re-loading on every run
30
+ @st.cache_resource
31
+ def load_translation_model(model_name):
32
+ return T5ForConditionalGeneration.from_pretrained(model_name)
33
+
34
+ @st.cache_resource
35
+ def load_tapex_model():
36
+ return BartForConditionalGeneration.from_pretrained("microsoft/tapex-large-finetuned-wtq")
37
+
38
+ @st.cache_resource
39
+ def load_tapex_tokenizer():
40
+ return TapexTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
41
+
42
+ pt_en_translator = load_translation_model("unicamp-dl/translation-pt-en-t5")
43
+ en_pt_translator = load_translation_model("unicamp-dl/translation-en-pt-t5")
44
+ tapex_model = load_tapex_model()
45
+ tapex_tokenizer = load_tapex_tokenizer()
46
  tokenizer = T5Tokenizer.from_pretrained("unicamp-dl/translation-pt-en-t5")
47
 
48
  def translate(text, model, tokenizer, source_lang="pt", target_lang="en"):
 
51
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
  return translated_text
53
 
54
+ # Function to translate and interact with TAPEX model
55
  def response(user_question, table_data):
56
  question_en = translate(user_question, pt_en_translator, tokenizer, source_lang="pt", target_lang="en")
57
  encoding = tapex_tokenizer(table=table_data, query=[question_en], padding=True, return_tensors="pt", truncation=True)
 
60
  response_pt = translate(response_en, en_pt_translator, tokenizer, source_lang="en", target_lang="pt")
61
  return response_pt
62
 
63
+ # Load and preprocess the data
64
  def load_data(uploaded_file):
65
  if uploaded_file.name.endswith('.csv'):
66
  df = pd.read_csv(uploaded_file, quotechar='"', encoding='utf-8')
 
103
  df_clean = new_df.copy()
104
  return df_clean
105
 
106
+ # Cache the Prophet computation to avoid recomputing
107
+ @st.cache_data
108
  def apply_prophet(df_clean):
109
  if df_clean.empty:
110
  st.error("DataFrame está vazio após o pré-processamento.")
 
166
  # Return the dataframe of all anomalies
167
  return all_anomalies
168
 
169
+ # Initialize session states
170
+ if 'all_anomalies' not in st.session_state:
171
+ st.session_state['all_anomalies'] = pd.DataFrame()
172
+
173
+ if 'history' not in st.session_state:
174
+ st.session_state['history'] = []
175
+
176
  tab1, tab2 = st.tabs(["Meta Prophet", "Microsoft TAPEX"])
177
 
178
  # Interface para carregar arquivo
 
186
  if df_clean.empty:
187
  st.warning("Não há dados válidos para processar.")
188
  else:
189
+ # Cache the Prophet results
190
+ if st.session_state['all_anomalies'].empty:
191
  with st.spinner('Aplicando modelo de série temporal...'):
192
  all_anomalies = apply_prophet(df_clean)
193
  st.session_state['all_anomalies'] = all_anomalies
194
 
195
  with tab2:
 
196
  if 'all_anomalies' in st.session_state and not st.session_state['all_anomalies'].empty:
 
197
  user_question = st.text_input("Escreva sua questão aqui:", "")
198
  if user_question:
199
  bot_response = response(user_question, st.session_state['all_anomalies'])
200
  st.session_state['history'].append(('👤', user_question))
201
  st.session_state['history'].append(('🤖', bot_response))
202
 
 
203
  for sender, message in st.session_state['history']:
204
+ st.markdown(f"**{sender} {message}**")
 
 
 
205
 
 
206
  if st.button("Limpar histórico"):
207
  st.session_state['history'] = []
208
  else:
209
+ st.warning("Por favor, processe os dados no Meta Prophet primeiro.")