hertogateis commited on
Commit
98bef5d
·
verified ·
1 Parent(s): 27c94ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from st_aggrid import AgGrid
4
+ import pandas as pd
5
+ from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
+ import plotly.express as px
7
+
8
+ # Set the page layout for Streamlit
9
+ st.set_page_config(layout="wide")
10
+
11
+ # Initialize TAPAS pipeline
12
+ tqa = pipeline(task="table-question-answering",
13
+ model="google/tapas-large-finetuned-wtq",
14
+ device="cpu")
15
+
16
+ # Initialize T5 tokenizer and model for text generation
17
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
18
+ t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
19
+
20
+ # Title and Introduction
21
+ st.title("Table Question Answering and Data Analysis App")
22
+ st.markdown("""
23
+ This app allows you to upload a table (CSV or Excel) and ask questions about the data.
24
+ Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.
25
+
26
+ ### Available Features:
27
+ - **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
28
+ - **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
29
+ - **max()**: For "max", it computes the maximum value in the DataFrame.
30
+ - **min()**: For "min", it computes the minimum value in the DataFrame.
31
+ - **count()**: For "count", it counts the non-null values in the entire DataFrame.
32
+ - **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.
33
+
34
+ Upload your data and ask questions to get both answers and visualizations.
35
+ """)
36
+
37
+ # File uploader in the sidebar
38
+ file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])
39
+
40
+ # File processing and question answering
41
+ if file_name is None:
42
+ st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
43
+ else:
44
+ try:
45
+ # Check file type and handle reading accordingly
46
+ if file_name.name.endswith('.csv'):
47
+ df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1') # Adjust encoding if needed
48
+ elif file_name.name.endswith('.xlsx'):
49
+ df = pd.read_excel(file_name, engine='openpyxl') # Use openpyxl to read .xlsx files
50
+ else:
51
+ st.error("Unsupported file type")
52
+ df = None
53
+
54
+ if df is not None:
55
+ numeric_columns = df.select_dtypes(include=['object']).columns
56
+ for col in numeric_columns:
57
+ df[col] = pd.to_numeric(df[col], errors='ignore')
58
+
59
+ st.write("Original Data:")
60
+ st.write(df)
61
+
62
+ df_numeric = df.copy()
63
+ df = df.astype(str)
64
+
65
+ # Display the first 5 rows of the dataframe in an editable grid
66
+ grid_response = AgGrid(
67
+ df.head(5),
68
+ columns_auto_size_mode='FIT_CONTENTS',
69
+ editable=True,
70
+ height=300,
71
+ width='100%',
72
+ )
73
+
74
+ except Exception as e:
75
+ st.error(f"Error reading file: {str(e)}")
76
+
77
+ # User input for the question
78
+ question = st.text_input('Type your question')
79
+
80
+ # Process the answer using TAPAS and T5
81
+ with st.spinner():
82
+ if st.button('Answer'):
83
+ try:
84
+ raw_answer = tqa(table=df, query=question, truncation=True)
85
+
86
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
87
+ unsafe_allow_html=True)
88
+ st.success(raw_answer)
89
+
90
+ answer = raw_answer['answer']
91
+ aggregator = raw_answer.get('aggregator', '')
92
+ coordinates = raw_answer.get('coordinates', [])
93
+ cells = raw_answer.get('cells', [])
94
+
95
+ # Handle aggregation based on user question or TAPAS output
96
+ if 'average' in question.lower() or aggregator == 'AVG':
97
+ avg_value = df.mean().mean() # Calculate overall average
98
+ base_sentence = f"The average for '{question}' is {avg_value:.2f}."
99
+ elif 'sum' in question.lower() or aggregator == 'SUM':
100
+ total_sum = df.sum().sum() # Calculate overall sum
101
+ base_sentence = f"The sum for '{question}' is {total_sum:.2f}."
102
+ elif 'max' in question.lower() or aggregator == 'MAX':
103
+ max_value = df.max().max() # Find overall max value
104
+ base_sentence = f"The maximum value for '{question}' is {max_value:.2f}."
105
+ elif 'min' in question.lower() or aggregator == 'MIN':
106
+ min_value = df.min().min() # Find overall min value
107
+ base_sentence = f"The minimum value for '{question}' is {min_value:.2f}."
108
+ elif 'count' in question.lower() or aggregator == 'COUNT':
109
+ count_value = df.count().sum() # Count all values
110
+ base_sentence = f"The total count of non-null values for '{question}' is {count_value}."
111
+ elif 'graph' in question.lower():
112
+ # Check for graph-related queries
113
+ if 'between' in question.lower() and 'and' in question.lower():
114
+ columns = question.split('between')[-1].split('and')
115
+ columns = [col.strip() for col in columns]
116
+ if len(columns) == 2 and all(col in df.columns for col in columns):
117
+ fig = px.scatter(df, x=columns[0], y=columns[1], title=f"Graph between {columns[0]} and {columns[1]}")
118
+ st.plotly_chart(fig, use_container_width=True)
119
+ base_sentence = f"Here is the graph between '{columns[0]}' and '{columns[1]}'."
120
+ elif 'column' in question.lower():
121
+ column = question.split('of')[-1].strip()
122
+ if column in df.columns:
123
+ fig = px.line(df, x=df.index, y=column, title=f"Graph of column '{column}'")
124
+ st.plotly_chart(fig, use_container_width=True)
125
+ base_sentence = f"Here is the graph of column '{column}'."
126
+ else:
127
+ base_sentence = f"Column '{column}' not found in the data."
128
+ else:
129
+ base_sentence = f"The answer from TAPAS for '{question}' is {answer}."
130
+
131
+ # Display the final response
132
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response: </p>", unsafe_allow_html=True)
133
+ st.success(base_sentence)
134
+
135
+ except Exception as e:
136
+ st.warning(f"Error processing question or generating answer: {str(e)}")
137
+ st.warning("Please retype your question and make sure to use the column name and cell value correctly.")