hertogateis committed on
Commit
8a33362
·
verified ·
1 Parent(s): 74b544f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -189
app.py CHANGED
@@ -1,50 +1,72 @@
1
  import os
2
  import streamlit as st
3
- from st_aggrid import AgGrid
4
  import pandas as pd
 
5
  from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
 
7
- import streamlit as st
 
8
 
9
- # CSS styling for the toggle button and instructions
10
  style = '''
11
  <style>
12
- .toggle-wrapper {
 
 
 
 
 
 
 
 
 
 
13
  display: flex;
14
- align-items: center;
15
  justify-content: center;
16
- padding: 10px;
17
- background-color: #f0f0f0;
18
- border-radius: 10px;
19
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
20
- margin-top: 20px;
 
 
 
 
 
21
  }
22
- .toggle-text {
23
- font-family: sans-serif;
24
- font-size: 1.2rem;
25
- margin-right: 10px;
26
  }
27
- .slide-to-enter {
28
- font-size: 1rem;
29
- font-family: sans-serif;
30
- color: #333;
31
  }
32
  </style>
33
  '''
34
-
35
- # Include the CSS to the page
36
  st.markdown(style, unsafe_allow_html=True)
37
 
38
- # Render the toggle button with "Slide to enter data" text
39
- st.markdown('<div class="toggle-wrapper"><span class="toggle-text">> </span><span class="slide-to-enter">Slide to enter data</span></div>', unsafe_allow_html=True)
40
-
41
- # Your main app content goes here
42
-
43
-
44
  st.markdown('<p style="font-family:sans-serif;font-size: 1.9rem;"> HertogAI Table Q&A using TAPAS+Data Analysis and Model Language</p>', unsafe_allow_html=True)
45
- st.markdown('<p style="font-family:sans-serif;font-size: 0.9rem;"> This code is based on Jordan Skinner. I recoded and enhanced for </p>', unsafe_allow_html=True)
46
- st.markdown('<p style="font-family:sans-serif;font-size: 0.9rem;"> the Data analisys are (SUM, MAX, MIN, AVG, COUNT, MEAN, STDDEV) </p>', unsafe_allow_html=True)
47
- st.markdown("<p style='font-family:sans-serif;font-size: 0.5rem;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Initialize TAPAS pipeline
50
  tqa = pipeline(task="table-question-answering",
@@ -55,165 +77,46 @@ tqa = pipeline(task="table-question-answering",
55
  t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
56
  t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
57
 
58
- # File uploader in the sidebar
59
- file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])
60
-
61
- # File processing and question answering
62
- if file_name is None:
63
- st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
64
- else:
65
- try:
66
- # Check file type and handle reading accordingly
67
- if file_name.name.endswith('.csv'):
68
- df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1') # Adjust encoding if needed
69
- elif file_name.name.endswith('.xlsx'):
70
- df = pd.read_excel(file_name, engine='openpyxl') # Use openpyxl to read .xlsx files
71
- else:
72
- st.error("Unsupported file type")
73
- df = None
74
 
75
- # Continue with further processing if df is loaded
76
- if df is not None:
77
- numeric_columns = df.select_dtypes(include=['object']).columns
78
- for col in numeric_columns:
79
- df[col] = pd.to_numeric(df[col], errors='ignore')
80
-
81
- st.write("Original Data:")
82
- st.write(df)
83
-
84
- # Create a copy for numerical operations
85
- df_numeric = df.copy()
86
- df = df.astype(str)
87
-
88
- # Display the first 5 rows of the dataframe in an editable grid
89
- grid_response = AgGrid(
90
- df.head(5),
91
- columns_auto_size_mode='FIT_CONTENTS',
92
- editable=True,
93
- height=300,
94
- width='100%',
95
- )
96
 
97
- except Exception as e:
98
- st.error(f"Error reading file: {str(e)}")
99
-
100
- # User input for the question
101
- question = st.text_input('Type your question')
102
-
103
- # Process the answer using TAPAS and T5
104
- with st.spinner():
105
- if st.button('Answer'):
106
- try:
107
- # Get the raw answer from TAPAS
108
- raw_answer = tqa(table=df, query=question, truncation=True)
109
-
110
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
111
- unsafe_allow_html=True)
112
- st.success(raw_answer)
113
-
114
- # Extract relevant information from the TAPAS result
115
- answer = raw_answer['answer']
116
- aggregator = raw_answer.get('aggregator', '')
117
- coordinates = raw_answer.get('coordinates', [])
118
- cells = raw_answer.get('cells', [])
119
-
120
- # Construct a base sentence from the lower-cased question and the TAPAS answer
121
- base_sentence = f"The {question.lower()} of the selected data is {answer}."
122
- if coordinates and cells:
123
- rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
124
- for coordinate, cell in zip(coordinates, cells)]
125
- rows_description = " and ".join(rows_info)
126
- base_sentence += f" This includes the following data: {rows_description}."
127
-
128
- # Generate a fluent response using the T5 model, rephrasing the base sentence
129
- input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
130
-
131
- # Tokenize the input and generate a fluent response using T5
132
- inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
133
- summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
134
-
135
- # Decode the generated text
136
- generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
137
-
138
- # Display the final generated response
139
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
140
- st.success(generated_text)
141
-
142
- except Exception as e:
143
- st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
144
-
145
- try:
146
- # Get raw answer again from TAPAS
147
- raw_answer = tqa(table=df, query=question, truncation=True)
148
-
149
- # Display raw result for debugging purposes
150
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result: </p>", unsafe_allow_html=True)
151
- st.success(raw_answer)
152
-
153
- # Processing the raw_answer
154
- processed_answer = raw_answer['answer'].replace(';', ' ') # Clean the answer text
155
- row_idx = raw_answer['coordinates'][0][0] # Row index from TAPAS
156
- col_idx = raw_answer['coordinates'][0][1] # Column index from TAPAS
157
- column_name = df.columns[col_idx] # Column name from the DataFrame
158
- row_data = df.iloc[row_idx].to_dict() # Row data corresponding to the row index
159
-
160
- # Handle different types of answers (e.g., 'SUM', 'MAX', 'MIN', 'AVG', etc.)
161
- if 'SUM' in processed_answer:
162
- summary_type = 'sum'
163
- numeric_value = df_numeric[column_name].sum()
164
- elif 'MAX' in processed_answer:
165
- summary_type = 'maximum'
166
- numeric_value = df_numeric[column_name].max()
167
- elif 'MIN' in processed_answer:
168
- summary_type = 'minimum'
169
- numeric_value = df_numeric[column_name].min()
170
- elif 'AVG' in processed_answer or 'AVERAGE' in processed_answer:
171
- summary_type = 'average'
172
- numeric_value = df_numeric[column_name].mean()
173
- elif 'COUNT' in processed_answer:
174
- summary_type = 'count'
175
- numeric_value = df_numeric[column_name].count()
176
- elif 'MEDIAN' in processed_answer:
177
- summary_type = 'median'
178
- numeric_value = df_numeric[column_name].median()
179
- elif 'STD' in processed_answer or 'STANDARD DEVIATION' in processed_answer:
180
- summary_type = 'std_dev'
181
- numeric_value = df_numeric[column_name].std()
182
- else:
183
- summary_type = 'value'
184
- numeric_value = processed_answer # In case of a general answer
185
-
186
- # Build a natural language response based on the aggregation type
187
- if summary_type == 'sum':
188
- natural_language_answer = f"The total {column_name} is {numeric_value}."
189
- elif summary_type == 'maximum':
190
- natural_language_answer = f"The highest {column_name} is {numeric_value}, recorded for '{row_data.get('Name', 'Unknown')}'."
191
- elif summary_type == 'minimum':
192
- natural_language_answer = f"The lowest {column_name} is {numeric_value}, recorded for '{row_data.get('Name', 'Unknown')}'."
193
- elif summary_type == 'average':
194
- natural_language_answer = f"The average {column_name} is {numeric_value}."
195
- elif summary_type == 'count':
196
- natural_language_answer = f"The number of entries in {column_name} is {numeric_value}."
197
- elif summary_type == 'median':
198
- natural_language_answer = f"The median {column_name} is {numeric_value}."
199
- elif summary_type == 'std_dev':
200
- natural_language_answer = f"The standard deviation of {column_name} is {numeric_value}."
201
- else:
202
- natural_language_answer = f"The {column_name} value is {numeric_value} for '{row_data.get('Name', 'Unknown')}'."
203
-
204
- # Display the final natural language answer
205
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
206
- st.success(f"""
207
- • Answer: {natural_language_answer}
208
-
209
- Data Location:
210
- • Row: {row_idx + 1}
211
- • Column: {column_name}
212
-
213
- Additional Context:
214
- • Full Row Data: {row_data}
215
- • Query Asked: "{question}"
216
- """)
217
-
218
- except Exception as e:
219
- st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
 
1
  import os
2
  import streamlit as st
 
3
  import pandas as pd
4
+ from st_aggrid import AgGrid
5
  from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
 
7
+ # Set the page layout for Streamlit
8
+ st.set_page_config(layout="wide")
9
 
10
+ # CSS styling
11
  style = '''
12
  <style>
13
+ body {background-color: #F5F5F5; color: #000000;}
14
+ header {visibility: hidden;}
15
+ div.block-container {padding-top:4rem;}
16
+ section[data-testid="stSidebar"] div:first-child {
17
+ padding-top: 0;
18
+ }
19
+ .font {
20
+ text-align:center;
21
+ font-family:sans-serif;font-size: 1.25rem;
22
+ }
23
+ .button-container {
24
  display: flex;
 
25
  justify-content: center;
26
+ margin-top: 50px;
27
+ }
28
+ .upload-button {
29
+ font-size: 1.5rem;
30
+ padding: 15px 30px;
31
+ background-color: #4CAF50;
32
+ color: white;
33
+ border-radius: 8px;
34
+ cursor: pointer;
35
+ border: none;
36
  }
37
+ .upload-button:hover {
38
+ background-color: #45a049;
 
 
39
  }
40
+ .file-upload-container {
41
+ margin-top: 20px;
42
+ display: none;
 
43
  }
44
  </style>
45
  '''
 
 
46
  st.markdown(style, unsafe_allow_html=True)
47
 
48
+ # Title and description
 
 
 
 
 
49
  st.markdown('<p style="font-family:sans-serif;font-size: 1.9rem;"> HertogAI Table Q&A using TAPAS+Data Analysis and Model Language</p>', unsafe_allow_html=True)
50
+ st.markdown('<p style="font-family:sans-serif;font-size: 1.0rem;"> This code is based on Jordan Skinner. I recoded and enhanced for </p>', unsafe_allow_html=True)
51
+ st.markdown('<p style="font-family:sans-serif;font-size: 1.2rem;"> the Data analysis are (SUM, MAX, MIN, AVG, COUNT, MEAN, STDDEV) </p>', unsafe_allow_html=True)
52
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.8rem;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)
53
+
54
+ # Create a large clickable button for the user to interact with
55
+ if st.button("Click to upload file", key="upload_button", help="Click here to upload a file", use_container_width=True):
56
+ # Show the file uploader widget when the button is clicked
57
+ uploaded_file = st.file_uploader("Upload your file", type=["csv", "xlsx"])
58
+
59
+ # Check if the user uploaded a file
60
+ if uploaded_file is not None:
61
+ st.success(f"File '{uploaded_file.name}' uploaded successfully!")
62
+
63
+ # Process the file (example: display the first 5 rows)
64
+ if uploaded_file.name.endswith('.csv'):
65
+ df = pd.read_csv(uploaded_file)
66
+ st.write(df.head())
67
+ elif uploaded_file.name.endswith('.xlsx'):
68
+ df = pd.read_excel(uploaded_file)
69
+ st.write(df.head())
70
 
71
  # Initialize TAPAS pipeline
72
  tqa = pipeline(task="table-question-answering",
 
77
  t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
78
  t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
79
 
80
+ # User input for the question
81
+ question = st.text_input('Type your question')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Process the answer using TAPAS and T5
84
+ with st.spinner():
85
+ if st.button('Answer'):
86
+ try:
87
+ # Get the raw answer from TAPAS
88
+ raw_answer = tqa(table=df, query=question, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
91
+ st.success(raw_answer)
92
+
93
+ # Extract relevant information from the TAPAS result
94
+ answer = raw_answer['answer']
95
+ aggregator = raw_answer.get('aggregator', '')
96
+ coordinates = raw_answer.get('coordinates', [])
97
+ cells = raw_answer.get('cells', [])
98
+
99
+ # Construct a base sentence from the lower-cased question and the TAPAS answer
100
+ base_sentence = f"The {question.lower()} of the selected data is {answer}."
101
+ if coordinates and cells:
102
+ rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
103
+ for coordinate, cell in zip(coordinates, cells)]
104
+ rows_description = " and ".join(rows_info)
105
+ base_sentence += f" This includes the following data: {rows_description}."
106
+
107
+ # Generate a fluent response using the T5 model, rephrasing the base sentence
108
+ input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
109
+
110
+ # Tokenize the input and generate a fluent response using T5
111
+ inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
112
+ summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
113
+
114
+ # Decode the generated text
115
+ generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
116
+
117
+ # Display the final generated response
118
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
119
+ st.success(generated_text)
120
+
121
+ except Exception as e:
122
+ st.warning("Please retype your question and make sure to use the column name and cell value correctly.")