hertogateis committed on
Commit
de716c3
·
verified ·
1 Parent(s): d973d75

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from st_aggrid import AgGrid
4
+ import pandas as pd
5
+ from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
+
7
# Use the wide page layout for the whole app.
st.set_page_config(layout="wide")

# Global CSS overrides injected as raw HTML: light page background, hidden
# Streamlit header, extra top padding on the main container, no top padding
# on the sidebar's first child, and a centred sans-serif `.font` helper class.
style = '''
<style>
body {background-color: #F5F5F5; color: #000000;}
header {visibility: hidden;}
div.block-container {padding-top:4rem;}
section[data-testid="stSidebar"] div:first-child {
padding-top: 0;
}
.font {
text-align:center;
font-family:sans-serif;font-size: 1.25rem;}
</style>
'''
st.markdown(body=style, unsafe_allow_html=True)
25
+
26
# Heading paragraphs rendered at the top of the page, in display order:
# app title, attribution, the list of analyses, and the table-size notice.
_HEADER_PARAGRAPHS = (
    '<p style="font-family:sans-serif;font-size: 1.9rem;"> HertogAI Table Q&A using TAPAS+Data Analysis and Model Language</p>',
    '<p style="font-family:sans-serif;font-size: 1.9rem;"> This code is based on Jordan Skinner. I recoded and enhanced for </p>',
    '<p style="font-family:sans-serif;font-size: 1.9rem;"> the Data analisys are (SUM, MAX, MIN, AVG, COUNT, MEAN, STDDEV) </p>',
    "<p style='font-family:sans-serif;font-size: 0.9rem;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>",
)

for _paragraph_html in _HEADER_PARAGRAPHS:
    # Raw HTML is required because each paragraph carries inline styling.
    st.markdown(_paragraph_html, unsafe_allow_html=True)
30
+
31
# Model initialisation.
#
# Streamlit re-executes this script from the top on every widget interaction,
# so the loaders are wrapped in `st.cache_resource`: each model is built once
# per server process and the same objects are reused on later reruns.


@st.cache_resource
def _load_tapas_pipeline():
    """Build the TAPAS table-question-answering pipeline, running on CPU."""
    return pipeline(task="table-question-answering",
                    model="google/tapas-large-finetuned-wtq",
                    device="cpu")


@st.cache_resource
def _load_t5():
    """Load the t5-small tokenizer and model used for text generation.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    return tokenizer, model


# Module-level handles used by the rest of the script.
tqa = _load_tapas_pipeline()
t5_tokenizer, t5_model = _load_t5()
39
+
40
# Table upload and question answering.
#
# The user uploads a CSV/XLSX table in the sidebar; the table is previewed and
# a typed question is answered with the TAPAS pipeline (`tqa`). The TAPAS
# result is shown raw, rephrased with T5, and summarised with a pandas
# aggregate over the answer's column.

_RETRY_HINT = "Please retype your question and make sure to use the column name and cell value correctly."

# (keywords looked up in the TAPAS aggregator label, summary type, Series method).
# Checked in order; the first entry with a matching keyword wins.
_AGGREGATIONS = (
    (('SUM',), 'sum', 'sum'),
    (('MAX',), 'maximum', 'max'),
    (('MIN',), 'minimum', 'min'),
    (('AVG', 'AVERAGE'), 'average', 'mean'),
    (('COUNT',), 'count', 'count'),
    (('MEDIAN',), 'median', 'median'),
    (('STD', 'STANDARD DEVIATION'), 'std_dev', 'std'),
)


def _warn_retry(exc=None):
    """Show the retry hint, appending the exception text when one is given."""
    if exc is None:
        st.warning(_RETRY_HINT)
    else:
        st.warning(f"{_RETRY_HINT} (Details: {exc})")


def _load_table(uploaded):
    """Read the uploaded file into a DataFrame.

    Returns None (after showing an error) when the file name has neither a
    ``.csv`` nor a ``.xlsx`` extension.
    """
    if uploaded.name.endswith('.csv'):
        # CSV input is parsed as semicolon-separated, ISO-8859-1 encoded text.
        return pd.read_csv(uploaded, sep=';', encoding='ISO-8859-1')
    if uploaded.name.endswith('.xlsx'):
        return pd.read_excel(uploaded, engine='openpyxl')
    st.error("Unsupported file type")
    return None


def _convert_text_columns_to_numeric(frame):
    """Convert object-dtype columns to numeric in place where the whole column parses.

    Columns that cannot be parsed are left untouched.
    """
    for col in frame.select_dtypes(include=['object']).columns:
        try:
            frame[col] = pd.to_numeric(frame[col])
        except (ValueError, TypeError):
            # Not a purely numeric column; keep its original values.
            continue


def _render_llm_response(table, question, raw_answer):
    """Rephrase the TAPAS answer with T5 and display the generated text."""
    answer = raw_answer['answer']
    coordinates = raw_answer.get('coordinates', [])
    cells = raw_answer.get('cells', [])

    base_sentence = f"The {question.lower()} of the selected data is {answer}."
    if coordinates and cells:
        # Describe every selected cell with a 1-based row number and its column name.
        rows_info = [f"Row {coordinate[0] + 1}, Column '{table.columns[coordinate[1]]}' with value {cell}"
                     for coordinate, cell in zip(coordinates, cells)]
        rows_description = " and ".join(rows_info)
        base_sentence += f" This includes the following data: {rows_description}."

    input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
    inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
    generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
    st.success(generated_text)


def _render_analysis(table, numeric_table, question, raw_answer):
    """Display a pandas-based summary for the column of the first selected cell.

    `table` is the all-string copy used for labels and row data;
    `numeric_table` keeps the converted dtypes and is used for the aggregate.
    """
    coordinates = raw_answer.get('coordinates') or []
    if not coordinates:
        # TAPAS selected no cell, so there is no row/column to report on.
        _warn_retry()
        return

    processed_answer = raw_answer['answer'].replace(';', ' ')
    row_idx, col_idx = coordinates[0][0], coordinates[0][1]
    column_name = table.columns[col_idx]
    row_data = table.iloc[row_idx].to_dict()

    # Take the aggregation from the dedicated field rather than searching the
    # answer text, which also contains cell values (e.g. a cell "SUMMIT").
    label = str(raw_answer.get('aggregator', '')).upper()
    summary_type, numeric_value = 'value', processed_answer
    for keywords, kind, method_name in _AGGREGATIONS:
        if any(keyword in label for keyword in keywords):
            summary_type = kind
            # NOTE(review): the aggregate covers the whole column, not only the
            # cells TAPAS selected -- confirm this is the intended behaviour.
            numeric_value = getattr(numeric_table[column_name], method_name)()
            break

    record_name = row_data.get('Name', 'Unknown')
    templates = {
        'sum': f"The total {column_name} is {numeric_value}.",
        'maximum': f"The highest {column_name} is {numeric_value}, recorded for '{record_name}'.",
        'minimum': f"The lowest {column_name} is {numeric_value}, recorded for '{record_name}'.",
        'average': f"The average {column_name} is {numeric_value}.",
        'count': f"The number of entries in {column_name} is {numeric_value}.",
        'median': f"The median {column_name} is {numeric_value}.",
        'std_dev': f"The standard deviation of {column_name} is {numeric_value}.",
        'value': f"The {column_name} value is {numeric_value} for '{record_name}'.",
    }
    natural_language_answer = templates[summary_type]

    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
    # Built line by line so no source indentation leaks into the rendered text.
    st.success("\n".join([
        "",
        f"• Answer: {natural_language_answer}",
        "",
        "Data Location:",
        f"• Row: {row_idx + 1}",
        f"• Column: {column_name}",
        "",
        "Additional Context:",
        f"• Full Row Data: {row_data}",
        f"• Query Asked: \"{question}\"",
        "",
    ]))


def _answer_question(table, numeric_table, question):
    """Query TAPAS once for `question` and render the raw, rephrased and analysed results."""
    if not question.strip():
        _warn_retry()
        return

    try:
        raw_answer = tqa(table=table, query=question, truncation=True)
    except Exception as exc:
        # Page-level boundary: report the failure instead of crashing the app.
        _warn_retry(exc)
        return

    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
                unsafe_allow_html=True)
    st.success(raw_answer)

    # The two renderings are independent: a failure in one must not hide the other.
    try:
        _render_llm_response(table, question, raw_answer)
    except Exception as exc:
        _warn_retry(exc)

    try:
        _render_analysis(table, numeric_table, question, raw_answer)
    except Exception as exc:
        _warn_retry(exc)


# File uploader in the sidebar
file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])

# Both stay None until a table has been read successfully, so the
# question-answering section below never runs against an undefined table.
df = None
df_numeric = None

if file_name is None:
    st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
else:
    try:
        df = _load_table(file_name)

        if df is not None:
            _convert_text_columns_to_numeric(df)

            st.write("Original Data:")
            st.write(df)

            # Keep the converted dtypes for aggregation; the all-string copy
            # is the table that gets passed to TAPAS.
            df_numeric = df.copy()
            df = df.astype(str)

            # Display the first 5 rows of the dataframe in an editable grid
            grid_response = AgGrid(
                df.head(5),
                columns_auto_size_mode='FIT_CONTENTS',
                editable=True,
                height=300,
                width='100%',
            )

    except Exception as e:
        st.error(f"Error reading file: {str(e)}")

if df is not None and df_numeric is not None:
    # User input for the question
    question = st.text_input('Type your question')

    # Process the answer using TAPAS and T5
    with st.spinner():
        if st.button('Answer'):
            _answer_question(df, df_numeric, question)