hertogateis commited on
Commit
2e654c9
·
verified ·
1 Parent(s): 8a33362

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -91
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import streamlit as st
3
- import pandas as pd
4
  from st_aggrid import AgGrid
 
5
  from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
 
7
  # Set the page layout for Streamlit
@@ -14,59 +14,19 @@ style = '''
14
  header {visibility: hidden;}
15
  div.block-container {padding-top:4rem;}
16
  section[data-testid="stSidebar"] div:first-child {
17
- padding-top: 0;
18
- }
19
- .font {
20
- text-align:center;
21
- font-family:sans-serif;font-size: 1.25rem;
22
- }
23
- .button-container {
24
- display: flex;
25
- justify-content: center;
26
- margin-top: 50px;
27
- }
28
- .upload-button {
29
- font-size: 1.5rem;
30
- padding: 15px 30px;
31
- background-color: #4CAF50;
32
- color: white;
33
- border-radius: 8px;
34
- cursor: pointer;
35
- border: none;
36
- }
37
- .upload-button:hover {
38
- background-color: #45a049;
39
- }
40
- .file-upload-container {
41
- margin-top: 20px;
42
- display: none;
43
- }
44
  </style>
45
  '''
46
  st.markdown(style, unsafe_allow_html=True)
47
 
48
- # Title and description
49
- st.markdown('<p style="font-family:sans-serif;font-size: 1.9rem;"> HertogAI Table Q&A using TAPAS+Data Analysis and Model Language</p>', unsafe_allow_html=True)
50
- st.markdown('<p style="font-family:sans-serif;font-size: 1.0rem;"> This code is based on Jordan Skinner. I recoded and enhanced for </p>', unsafe_allow_html=True)
51
- st.markdown('<p style="font-family:sans-serif;font-size: 1.2rem;"> the Data analysis are (SUM, MAX, MIN, AVG, COUNT, MEAN, STDDEV) </p>', unsafe_allow_html=True)
52
- st.markdown("<p style='font-family:sans-serif;font-size: 0.8rem;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)
53
-
54
- # Create a large clickable button for the user to interact with
55
- if st.button("Click to upload file", key="upload_button", help="Click here to upload a file", use_container_width=True):
56
- # Show the file uploader widget when the button is clicked
57
- uploaded_file = st.file_uploader("Upload your file", type=["csv", "xlsx"])
58
-
59
- # Check if the user uploaded a file
60
- if uploaded_file is not None:
61
- st.success(f"File '{uploaded_file.name}' uploaded successfully!")
62
-
63
- # Process the file (example: display the first 5 rows)
64
- if uploaded_file.name.endswith('.csv'):
65
- df = pd.read_csv(uploaded_file)
66
- st.write(df.head())
67
- elif uploaded_file.name.endswith('.xlsx'):
68
- df = pd.read_excel(uploaded_file)
69
- st.write(df.head())
70
 
71
  # Initialize TAPAS pipeline
72
  tqa = pipeline(task="table-question-answering",
@@ -77,46 +37,165 @@ tqa = pipeline(task="table-question-answering",
77
  t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
78
  t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
79
 
80
- # User input for the question
81
- question = st.text_input('Type your question')
82
 
83
- # Process the answer using TAPAS and T5
84
- with st.spinner():
85
- if st.button('Answer'):
86
- try:
87
- # Get the raw answer from TAPAS
88
- raw_answer = tqa(table=df, query=question, truncation=True)
89
-
90
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
91
- st.success(raw_answer)
92
-
93
- # Extract relevant information from the TAPAS result
94
- answer = raw_answer['answer']
95
- aggregator = raw_answer.get('aggregator', '')
96
- coordinates = raw_answer.get('coordinates', [])
97
- cells = raw_answer.get('cells', [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # Construct a base sentence replacing 'SUM' with the query term
100
- base_sentence = f"The {question.lower()} of the selected data is {answer}."
101
- if coordinates and cells:
102
- rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
103
- for coordinate, cell in zip(coordinates, cells)]
104
- rows_description = " and ".join(rows_info)
105
- base_sentence += f" This includes the following data: {rows_description}."
106
-
107
- # Generate a fluent response using the T5 model, rephrasing the base sentence
108
- input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
109
-
110
- # Tokenize the input and generate a fluent response using T5
111
- inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
112
- summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
113
-
114
- # Decode the generated text
115
- generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
116
-
117
- # Display the final generated response
118
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
119
- st.success(generated_text)
120
-
121
- except Exception as e:
122
- st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import streamlit as st
 
3
  from st_aggrid import AgGrid
4
+ import pandas as pd
5
  from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
 
7
  # Set the page layout for Streamlit
 
14
  header {visibility: hidden;}
15
  div.block-container {padding-top:4rem;}
16
  section[data-testid="stSidebar"] div:first-child {
17
+ padding-top: 0;
18
+ }
19
+ .font {
20
+ text-align:center;
21
+ font-family:sans-serif;font-size: 1.25rem;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  </style>
23
  '''
24
  st.markdown(style, unsafe_allow_html=True)
25
 
26
+ st.markdown('<p style="font-family:sans-serif;font-size: 1.9rem;"> HertogAI Table Q&A using TAPAS and Model Language</p>', unsafe_allow_html=True)
27
+ st.markdown('<p style="font-family:sans-serif;font-size: 1.0rem;"> This code is based on Jordan Skinner. I recoded and enhanced it </p>', unsafe_allow_html=True)
28
+ st.markdown("<p style='font-family:sans-serif;font-size: 1.2rem;'>Pre-trained TAPAS model runs on max 64 rows and 32 columns data. Make sure the file data doesn't exceed these dimensions.</p>", unsafe_allow_html=True)
29
+ st.markdown("<p style='font-family:sans-serif;font-size: 1.5rem;'>Click the side bar > to upload your file.</p>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # Initialize TAPAS pipeline
32
  tqa = pipeline(task="table-question-answering",
 
37
  t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
38
  t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
39
 
40
+ # File uploader in the sidebar
41
+ file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])
42
 
43
+ # File processing and question answering
44
+ if file_name is None:
45
+ st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
46
+ else:
47
+ try:
48
+ # Check file type and handle reading accordingly
49
+ if file_name.name.endswith('.csv'):
50
+ df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1') # Adjust encoding if needed
51
+ elif file_name.name.endswith('.xlsx'):
52
+ df = pd.read_excel(file_name, engine='openpyxl') # Use openpyxl to read .xlsx files
53
+ else:
54
+ st.error("Unsupported file type")
55
+ df = None
56
+
57
+ # Continue with further processing if df is loaded
58
+ if df is not None:
59
+ numeric_columns = df.select_dtypes(include=['object']).columns
60
+ for col in numeric_columns:
61
+ df[col] = pd.to_numeric(df[col], errors='ignore')
62
+
63
+ st.write("Original Data:")
64
+ st.write(df)
65
+
66
+ # Create a copy for numerical operations
67
+ df_numeric = df.copy()
68
+ df = df.astype(str)
69
+
70
+ # Display the first 5 rows of the dataframe in an editable grid
71
+ grid_response = AgGrid(
72
+ df.head(5),
73
+ columns_auto_size_mode='FIT_CONTENTS',
74
+ editable=True,
75
+ height=300,
76
+ width='100%',
77
+ )
78
 
79
+ except Exception as e:
80
+ st.error(f"Error reading file: {str(e)}")
81
+
82
+ # User input for the question
83
+ question = st.text_input('Type your question')
84
+
85
+ # Process the answer using TAPAS and T5
86
+ with st.spinner():
87
+ if st.button('Answer'):
88
+ try:
89
+ # Get the raw answer from TAPAS
90
+ raw_answer = tqa(table=df, query=question, truncation=True)
91
+
92
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
93
+ unsafe_allow_html=True)
94
+ st.success(raw_answer)
95
+
96
+ # Extract relevant information from the TAPAS result
97
+ answer = raw_answer['answer']
98
+ aggregator = raw_answer.get('aggregator', '')
99
+ coordinates = raw_answer.get('coordinates', [])
100
+ cells = raw_answer.get('cells', [])
101
+
102
+ # Construct a base sentence replacing 'SUM' with the query term
103
+ base_sentence = f"The {question.lower()} of the selected data is {answer}."
104
+ if coordinates and cells:
105
+ rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
106
+ for coordinate, cell in zip(coordinates, cells)]
107
+ rows_description = " and ".join(rows_info)
108
+ base_sentence += f" This includes the following data: {rows_description}."
109
+
110
+ # Generate a fluent response using the T5 model, rephrasing the base sentence
111
+ input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
112
+
113
+ # Tokenize the input and generate a fluent response using T5
114
+ inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
115
+ summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
116
+
117
+ # Decode the generated text
118
+ generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
119
+
120
+ # Display the final generated response
121
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
122
+ st.success(generated_text)
123
+
124
+ except Exception as e:
125
+ st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
126
+
127
+ try:
128
+ # Get raw answer again from TAPAS
129
+ raw_answer = tqa(table=df, query=question, truncation=True)
130
+
131
+ # Display raw result for debugging purposes
132
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result: </p>", unsafe_allow_html=True)
133
+ st.success(raw_answer)
134
+
135
+ # Processing the raw_answer
136
+ processed_answer = raw_answer['answer'].replace(';', ' ') # Clean the answer text
137
+ row_idx = raw_answer['coordinates'][0][0] # Row index from TAPAS
138
+ col_idx = raw_answer['coordinates'][0][1] # Column index from TAPAS
139
+ column_name = df.columns[col_idx] # Column name from the DataFrame
140
+ row_data = df.iloc[row_idx].to_dict() # Row data corresponding to the row index
141
+
142
+ # Handle different types of answers (e.g., 'SUM', 'MAX', 'MIN', 'AVG', etc.)
143
+ if 'SUM' in processed_answer:
144
+ summary_type = 'sum'
145
+ numeric_value = df_numeric[column_name].sum()
146
+ elif 'MAX' in processed_answer:
147
+ summary_type = 'maximum'
148
+ numeric_value = df_numeric[column_name].max()
149
+ elif 'MIN' in processed_answer:
150
+ summary_type = 'minimum'
151
+ numeric_value = df_numeric[column_name].min()
152
+ elif 'AVG' in processed_answer or 'AVERAGE' in processed_answer:
153
+ summary_type = 'average'
154
+ numeric_value = df_numeric[column_name].mean()
155
+ elif 'COUNT' in processed_answer:
156
+ summary_type = 'count'
157
+ numeric_value = df_numeric[column_name].count()
158
+ elif 'MEDIAN' in processed_answer:
159
+ summary_type = 'median'
160
+ numeric_value = df_numeric[column_name].median()
161
+ elif 'STD' in processed_answer or 'STANDARD DEVIATION' in processed_answer:
162
+ summary_type = 'std_dev'
163
+ numeric_value = df_numeric[column_name].std()
164
+ else:
165
+ summary_type = 'value'
166
+ numeric_value = processed_answer # In case of a general answer
167
+
168
+ # Build a natural language response based on the aggregation type
169
+ if summary_type == 'sum':
170
+ natural_language_answer = f"The total {column_name} is {numeric_value}."
171
+ elif summary_type == 'maximum':
172
+ natural_language_answer = f"The highest {column_name} is {numeric_value}, recorded for '{row_data.get('Name', 'Unknown')}'."
173
+ elif summary_type == 'minimum':
174
+ natural_language_answer = f"The lowest {column_name} is {numeric_value}, recorded for '{row_data.get('Name', 'Unknown')}'."
175
+ elif summary_type == 'average':
176
+ natural_language_answer = f"The average {column_name} is {numeric_value}."
177
+ elif summary_type == 'count':
178
+ natural_language_answer = f"The number of entries in {column_name} is {numeric_value}."
179
+ elif summary_type == 'median':
180
+ natural_language_answer = f"The median {column_name} is {numeric_value}."
181
+ elif summary_type == 'std_dev':
182
+ natural_language_answer = f"The standard deviation of {column_name} is {numeric_value}."
183
+ else:
184
+ natural_language_answer = f"The {column_name} value is {numeric_value} for '{row_data.get('Name', 'Unknown')}'."
185
+
186
+ # Display the final natural language answer
187
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
188
+ st.success(f"""
189
+ • Answer: {natural_language_answer}
190
+
191
+ Data Location:
192
+ • Row: {row_idx + 1}
193
+ • Column: {column_name}
194
+
195
+ Additional Context:
196
+ • Full Row Data: {row_data}
197
+ • Query Asked: "{question}"
198
+ """)
199
+
200
+ except Exception as e:
201
+ st.warning("Please retype your question and make sure to use the column name and cell value correctly.")