hertogateis committed on
Commit
fcf406d
·
verified ·
1 Parent(s): ad7ff43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py CHANGED
@@ -129,6 +129,81 @@ else:
129
  # Get raw answer again from TAPAS: call `tqa` with the table `df` and the
  # user's `question` (truncation enabled) and keep the unprocessed result
  # so it can be displayed below.
130
  raw_answer = tqa(table=df, query=question, truncation=True)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # Display raw result for debugging purposes: a small HTML-styled
  # "Raw Result:" label followed by the unprocessed `raw_answer` object.
133
  st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result: </p>", unsafe_allow_html=True)
134
  st.success(raw_answer)
 
# NOTE: this fragment lives inside an `else:` branch of app.py (see the hunk
# header); re-indent it to the enclosing level when applying it to the file.

# Get the raw answer from TAPAS. The model is queried exactly once here and
# the same result feeds both the analysis and the raw debug output below.
raw_answer = tqa(table=df, query=question, truncation=True)

# Get the answer, coordinates, cells, and aggregator from the raw TAPAS output.
# NOTE(review): `answer`, `coordinates` and `cells` are not read anywhere in
# this fragment; they are kept in case code further down the file uses them.
answer = raw_answer['answer']
# `or ''` also covers an 'aggregator' key that is present but set to None,
# which would otherwise crash on `.upper()` / `.lower()` below.
aggregator = raw_answer.get('aggregator') or ''
coordinates = raw_answer.get('coordinates', [])
cells = raw_answer.get('cells', [])

# Manually fix the aggregator if it returns an incorrect one: when TAPAS
# reports AVERAGE but the question mentions a median / min / max, trust the
# question. Keywords are checked in priority order and the first hit wins.
# This is a plain substring test, so e.g. 'MIN' also matches "minimum".
# NOTE(review): the original also had a rule "'TOTAL' in question and 'SUM' in
# aggregator -> 'SUM'", which could never change the outcome; it was dropped.
# If the intent was to force SUM for "total" questions, add that rule here.
question_upper = question.upper()
if 'AVERAGE' in aggregator.upper():
    for keyword in ('MEDIAN', 'MIN', 'MAX'):
        if keyword in question_upper:
            aggregator = keyword
            break

# Use the corrected aggregator for further processing
summary_type = aggregator.lower()

# Now, calculate the correct value using pandas based on the corrected
# aggregator. Maps each summary type to the Series method that computes it.
pandas_method_by_summary = {
    'sum': 'sum',
    'max': 'max',
    'min': 'min',
    'average': 'mean',
    'count': 'count',
    'median': 'median',
    'std_dev': 'std',
}
pandas_method = pandas_method_by_summary.get(summary_type)
if pandas_method is not None:
    numeric_value = getattr(df_numeric[column_name], pandas_method)()
else:
    numeric_value = processed_answer  # Fallback if something went wrong

# Construct a natural language response. The keys must be the same summary
# types used above ('max' / 'min', not 'maximum' / 'minimum'); otherwise the
# "highest" / "lowest" sentences can never be selected.
answer_template_by_summary = {
    'sum': "The total {column} is {value}.",
    'max': "The highest {column} is {value}.",
    'min': "The lowest {column} is {value}.",
    'average': "The average {column} is {value}.",
    'count': "The number of entries in {column} is {value}.",
    'median': "The median {column} is {value}.",
    'std_dev': "The standard deviation of {column} is {value}.",
}
answer_template = answer_template_by_summary.get(
    summary_type, "The value for {column} is {value}."
)
natural_language_answer = answer_template.format(
    column=column_name, value=numeric_value
)

# Display the result to the user
st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
st.success(f"""
• Answer: {natural_language_answer}

Data Location:
• Column: {column_name}

Additional Context:
• Query Asked: "{question}"
""")

# Display raw result for debugging purposes
st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result: </p>", unsafe_allow_html=True)
st.success(raw_answer)