hertogateis commited on
Commit
95aba37
·
verified ·
1 Parent(s): 04ac291

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -148
app.py CHANGED
@@ -1,149 +1,21 @@
1
- import os
2
- import streamlit as st
3
- from st_aggrid import AgGrid
4
- import pandas as pd
5
- from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
- import plotly.express as px
7
-
8
- # Set the page layout for Streamlit
9
- st.set_page_config(layout="wide")
10
-
11
- # Initialize TAPAS pipeline
12
- tqa = pipeline(task="table-question-answering",
13
- model="google/tapas-large-finetuned-wtq",
14
- device="cpu")
15
-
16
- # Initialize T5 tokenizer and model for text generation
17
- t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
18
- t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
19
-
20
- # Title and Introduction
21
- st.title("Table Question Answering and Data Analysis App")
22
- st.markdown("""
23
- This app allows you to upload a table (CSV or Excel) and ask questions about the data.
24
- Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.
25
-
26
- ### Available Features:
27
- - **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
28
- - **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
29
- - **max()**: For "max", it computes the maximum value in the DataFrame.
30
- - **min()**: For "min", it computes the minimum value in the DataFrame.
31
- - **count()**: For "count", it counts the non-null values in the entire DataFrame.
32
- - **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.
33
-
34
- Upload your data and ask questions to get both answers and visualizations.
35
- """)
36
-
37
- # File uploader in the sidebar
38
- file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])
39
-
40
- # File processing and question answering
41
- if file_name is None:
42
- st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
43
- else:
44
- try:
45
- # Check file type and handle reading accordingly
46
- if file_name.name.endswith('.csv'):
47
- df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1') # Adjust encoding if needed
48
- elif file_name.name.endswith('.xlsx'):
49
- df = pd.read_excel(file_name, engine='openpyxl') # Use openpyxl to read .xlsx files
50
  else:
51
- st.error("Unsupported file type")
52
- df = None
53
-
54
- if df is not None:
55
- numeric_columns = df.select_dtypes(include=['object']).columns
56
- for col in numeric_columns:
57
- df[col] = pd.to_numeric(df[col], errors='ignore')
58
-
59
- st.write("Original Data:")
60
- st.write(df)
61
-
62
- df_numeric = df.copy()
63
- df = df.astype(str)
64
-
65
- # Display the first 5 rows of the dataframe in an editable grid
66
- grid_response = AgGrid(
67
- df.head(5),
68
- columns_auto_size_mode='FIT_CONTENTS',
69
- editable=True,
70
- height=300,
71
- width='100%',
72
- )
73
-
74
- except Exception as e:
75
- st.error(f"Error reading file: {str(e)}")
76
-
77
- # User input for the question
78
- question = st.text_input('Type your question')
79
-
80
- # Check if the question is about generating a graph
81
- is_graph_query = False
82
-
83
- if 'graph' in question.lower():
84
- is_graph_query = True
85
-
86
- # Process the answer using TAPAS and T5
87
- with st.spinner():
88
- if st.button('Answer'):
89
- try:
90
- if not is_graph_query:
91
- # Process TAPAS-related questions if it's not a graph query
92
- raw_answer = tqa(table=df, query=question, truncation=True)
93
-
94
- # Display raw answer from TAPAS
95
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
96
- st.write(raw_answer) # Display the raw result
97
-
98
- answer = raw_answer['answer']
99
- aggregator = raw_answer.get('aggregator', '')
100
- coordinates = raw_answer.get('coordinates', [])
101
- cells = raw_answer.get('cells', [])
102
-
103
- # Handle different aggregators
104
- if 'average' in question.lower() or aggregator == 'AVG':
105
- avg_value = df.mean().mean() # Calculate overall average
106
- base_sentence = f"The average for '{question}' is {avg_value:.2f}."
107
- elif 'sum' in question.lower() or aggregator == 'SUM':
108
- total_sum = df.sum().sum() # Calculate overall sum
109
- base_sentence = f"The sum for '{question}' is {total_sum:.2f}."
110
- elif 'max' in question.lower() or aggregator == 'MAX':
111
- max_value = df.max().max() # Find overall max value
112
- base_sentence = f"The maximum value for '{question}' is {max_value:.2f}."
113
- elif 'min' in question.lower() or aggregator == 'MIN':
114
- min_value = df.min().min() # Find overall min value
115
- base_sentence = f"The minimum value for '{question}' is {min_value:.2f}."
116
- elif 'count' in question.lower() or aggregator == 'COUNT':
117
- count_value = df.count().sum() # Count all values
118
- base_sentence = f"The total count of non-null values for '{question}' is {count_value}."
119
- else:
120
- base_sentence = f"The answer from TAPAS for '{question}' is {answer}."
121
-
122
- # Display the final response
123
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response: </p>", unsafe_allow_html=True)
124
- st.success(base_sentence)
125
-
126
- else:
127
- # Handle graph-related questions
128
- if 'between' in question.lower() and 'and' in question.lower():
129
- columns = question.split('between')[-1].split('and')
130
- columns = [col.strip() for col in columns]
131
- if len(columns) == 2 and all(col in df.columns for col in columns):
132
- fig = px.scatter(df, x=columns[0], y=columns[1], title=f"Graph between {columns[0]} and {columns[1]}")
133
- st.plotly_chart(fig, use_container_width=True)
134
- st.success(f"Here is the graph between '{columns[0]}' and '{columns[1]}'.")
135
- else:
136
- st.warning("Columns not found in the dataset.")
137
- elif 'column' in question.lower():
138
- column = question.split('of')[-1].strip()
139
- if column in df.columns:
140
- fig = px.line(df, x=df.index, y=column, title=f"Graph of column '{column}'")
141
- st.plotly_chart(fig, use_container_width=True)
142
- st.success(f"Here is the graph of column '{column}'.")
143
- else:
144
- st.warning(f"Column '{column}' not found in the data.")
145
- return # Skip TAPAS processing for graph-related queries
146
-
147
- except Exception as e:
148
- st.warning(f"Error processing question or generating answer: {str(e)}")
149
- st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
 
1
+ if is_graph_query:
2
+ # Handle graph-related questions here
3
+ if 'between' in question.lower() and 'and' in question.lower():
4
+ columns = question.split('between')[-1].split('and')
5
+ columns = [col.strip() for col in columns]
6
+ if len(columns) == 2 and all(col in df.columns for col in columns):
7
+ fig = px.scatter(df, x=columns[0], y=columns[1], title=f"Graph between {columns[0]} and {columns[1]}")
8
+ st.plotly_chart(fig, use_container_width=True)
9
+ st.success(f"Here is the graph between '{columns[0]}' and '{columns[1]}'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  else:
11
+ st.warning("Columns not found in the dataset.")
12
+ elif 'column' in question.lower():
13
+ column = question.split('of')[-1].strip()
14
+ if column in df.columns:
15
+ fig = px.line(df, x=df.index, y=column, title=f"Graph of column '{column}'")
16
+ st.plotly_chart(fig, use_container_width=True)
17
+ st.success(f"Here is the graph of column '{column}'.")
18
+ else:
19
+ st.warning(f"Column '{column}' not found in the data.")
20
+ # **Do not proceed with TAPAS processing for graph queries**
21
+ st.stop() # This will stop the code from running further