hertogateis commited on
Commit
97b07c2
·
verified ·
1 Parent(s): 95aba37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -20
app.py CHANGED
@@ -1,21 +1,149 @@
1
- if is_graph_query:
2
- # Handle graph-related questions here
3
- if 'between' in question.lower() and 'and' in question.lower():
4
- columns = question.split('between')[-1].split('and')
5
- columns = [col.strip() for col in columns]
6
- if len(columns) == 2 and all(col in df.columns for col in columns):
7
- fig = px.scatter(df, x=columns[0], y=columns[1], title=f"Graph between {columns[0]} and {columns[1]}")
8
- st.plotly_chart(fig, use_container_width=True)
9
- st.success(f"Here is the graph between '{columns[0]}' and '{columns[1]}'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  else:
11
- st.warning("Columns not found in the dataset.")
12
- elif 'column' in question.lower():
13
- column = question.split('of')[-1].strip()
14
- if column in df.columns:
15
- fig = px.line(df, x=df.index, y=column, title=f"Graph of column '{column}'")
16
- st.plotly_chart(fig, use_container_width=True)
17
- st.success(f"Here is the graph of column '{column}'.")
18
- else:
19
- st.warning(f"Column '{column}' not found in the data.")
20
- # **Do not proceed with TAPAS processing for graph queries**
21
- st.stop() # This will stop the code from running further
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from st_aggrid import AgGrid
4
+ import pandas as pd
5
+ from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
+ import plotly.express as px
7
+
8
+ # Set the page layout for Streamlit
9
+ st.set_page_config(layout="wide")
10
+
11
+ # Initialize TAPAS pipeline
12
+ tqa = pipeline(task="table-question-answering",
13
+ model="google/tapas-large-finetuned-wtq",
14
+ device="cpu")
15
+
16
+ # Initialize T5 tokenizer and model for text generation
17
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
18
+ t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
19
+
20
+ # Title and Introduction
21
+ st.title("Table Question Answering and Data Analysis App")
22
+ st.markdown("""
23
+ This app allows you to upload a table (CSV or Excel) and ask questions about the data.
24
+ Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.
25
+
26
+ ### Available Features:
27
+ - **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
28
+ - **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
29
+ - **max()**: For "max", it computes the maximum value in the DataFrame.
30
+ - **min()**: For "min", it computes the minimum value in the DataFrame.
31
+ - **count()**: For "count", it counts the non-null values in the entire DataFrame.
32
+ - **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.
33
+
34
+ Upload your data and ask questions to get both answers and visualizations.
35
+ """)
36
+
37
+ # File uploader in the sidebar
38
+ file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])
39
+
40
+ # File processing and question answering
41
+ if file_name is None:
42
+ st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
43
+ else:
44
+ try:
45
+ # Check file type and handle reading accordingly
46
+ if file_name.name.endswith('.csv'):
47
+ df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1') # Adjust encoding if needed
48
+ elif file_name.name.endswith('.xlsx'):
49
+ df = pd.read_excel(file_name, engine='openpyxl') # Use openpyxl to read .xlsx files
50
  else:
51
+ st.error("Unsupported file type")
52
+ df = None
53
+
54
+ if df is not None:
55
+ numeric_columns = df.select_dtypes(include=['object']).columns
56
+ for col in numeric_columns:
57
+ df[col] = pd.to_numeric(df[col], errors='ignore')
58
+
59
+ st.write("Original Data:")
60
+ st.write(df)
61
+
62
+ df_numeric = df.copy()
63
+ df = df.astype(str)
64
+
65
+ # Display the first 5 rows of the dataframe in an editable grid
66
+ grid_response = AgGrid(
67
+ df.head(5),
68
+ columns_auto_size_mode='FIT_CONTENTS',
69
+ editable=True,
70
+ height=300,
71
+ width='100%',
72
+ )
73
+
74
+ except Exception as e:
75
+ st.error(f"Error reading file: {str(e)}")
76
+
77
+ # User input for the question
78
+ question = st.text_input('Type your question')
79
+
80
+ # Check if the question is about generating a graph
81
+ is_graph_query = False
82
+
83
+ if 'graph' in question.lower():
84
+ is_graph_query = True
85
+
86
+ # Process the answer using TAPAS and T5
87
+ with st.spinner():
88
+ if st.button('Answer'):
89
+ try:
90
+ if not is_graph_query:
91
+ # Process TAPAS-related questions if it's not a graph query
92
+ raw_answer = tqa(table=df, query=question, truncation=True)
93
+
94
+ # Display raw answer from TAPAS
95
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
96
+ st.write(raw_answer) # Display the raw result
97
+
98
+ answer = raw_answer['answer']
99
+ aggregator = raw_answer.get('aggregator', '')
100
+ coordinates = raw_answer.get('coordinates', [])
101
+ cells = raw_answer.get('cells', [])
102
+
103
+ # Handle different aggregators
104
+ if 'average' in question.lower() or aggregator == 'AVG':
105
+ avg_value = df.mean().mean() # Calculate overall average
106
+ base_sentence = f"The average for '{question}' is {avg_value:.2f}."
107
+ elif 'sum' in question.lower() or aggregator == 'SUM':
108
+ total_sum = df.sum().sum() # Calculate overall sum
109
+ base_sentence = f"The sum for '{question}' is {total_sum:.2f}."
110
+ elif 'max' in question.lower() or aggregator == 'MAX':
111
+ max_value = df.max().max() # Find overall max value
112
+ base_sentence = f"The maximum value for '{question}' is {max_value:.2f}."
113
+ elif 'min' in question.lower() or aggregator == 'MIN':
114
+ min_value = df.min().min() # Find overall min value
115
+ base_sentence = f"The minimum value for '{question}' is {min_value:.2f}."
116
+ elif 'count' in question.lower() or aggregator == 'COUNT':
117
+ count_value = df.count().sum() # Count all values
118
+ base_sentence = f"The total count of non-null values for '{question}' is {count_value}."
119
+ else:
120
+ base_sentence = f"The answer from TAPAS for '{question}' is {answer}."
121
+
122
+ # Display the final response
123
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response: </p>", unsafe_allow_html=True)
124
+ st.success(base_sentence)
125
+
126
+ else:
127
+ # Handle graph-related questions
128
+ if 'between' in question.lower() and 'and' in question.lower():
129
+ columns = question.split('between')[-1].split('and')
130
+ columns = [col.strip() for col in columns]
131
+ if len(columns) == 2 and all(col in df.columns for col in columns):
132
+ fig = px.scatter(df, x=columns[0], y=columns[1], title=f"Graph between {columns[0]} and {columns[1]}")
133
+ st.plotly_chart(fig, use_container_width=True)
134
+ st.success(f"Here is the graph between '{columns[0]}' and '{columns[1]}'.")
135
+ else:
136
+ st.warning("Columns not found in the dataset.")
137
+ elif 'column' in question.lower():
138
+ column = question.split('of')[-1].strip()
139
+ if column in df.columns:
140
+ fig = px.line(df, x=df.index, y=column, title=f"Graph of column '{column}'")
141
+ st.plotly_chart(fig, use_container_width=True)
142
+ st.success(f"Here is the graph of column '{column}'.")
143
+ else:
144
+ st.warning(f"Column '{column}' not found in the data.")
145
+ return # Skip TAPAS processing for graph-related queries
146
+
147
+ except Exception as e:
148
+ st.warning(f"Error processing question or generating answer: {str(e)}")
149
+ st.warning("Please retype your question and make sure to use the column name and cell value correctly.")