Spaces:
Sleeping
Sleeping
File size: 7,005 Bytes
98bef5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import re

import pandas as pd
import plotly.express as px
import streamlit as st
from st_aggrid import AgGrid
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
# Set the page layout for Streamlit. Kept ahead of every other Streamlit call
# in the script, as Streamlit requires for set_page_config.
st.set_page_config(layout="wide")


# Streamlit re-executes the whole script on every widget interaction, so the
# model loaders are cached: each model is built once per server process
# instead of once per rerun.
@st.cache_resource
def _load_tqa_pipeline():
    """Build the TAPAS table-question-answering pipeline (CPU only)."""
    return pipeline(
        task="table-question-answering",
        model="google/tapas-large-finetuned-wtq",
        device="cpu",
    )


@st.cache_resource
def _load_t5():
    """Load the t5-small tokenizer and model; returns (tokenizer, model)."""
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    return tokenizer, model


# Initialize TAPAS pipeline
tqa = _load_tqa_pipeline()

# Initialize T5 tokenizer and model for text generation
t5_tokenizer, t5_model = _load_t5()
# Page heading and the Markdown usage guide shown under it.
APP_TITLE = "Table Question Answering and Data Analysis App"

INTRO_MARKDOWN = """
This app allows you to upload a table (CSV or Excel) and ask questions about the data.
Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.
### Available Features:
- **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
- **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
- **max()**: For "max", it computes the maximum value in the DataFrame.
- **min()**: For "min", it computes the minimum value in the DataFrame.
- **count()**: For "count", it counts the non-null values in the entire DataFrame.
- **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.
Upload your data and ask questions to get both answers and visualizations.
"""

st.title(APP_TITLE)
st.markdown(INTRO_MARKDOWN)
# Whole-word keyword patterns that request a whole-table aggregation. They are
# checked in this order, mirroring the original if/elif chain. Whole-word
# matching keeps e.g. "country" from triggering "count".
_AGGREGATION_PATTERNS = {
    "average": re.compile(r"\baverage\b"),
    "sum": re.compile(r"\bsum\b"),
    "max": re.compile(r"\bmax(?:imum)?\b"),
    "min": re.compile(r"\bmin(?:imum)?\b"),
    "count": re.compile(r"\bcount\b"),
}

# Aggregator labels reported by the TAPAS pipeline, mapped to the names above.
# The WTQ checkpoint reports "AVERAGE"; "AVG" is kept for compatibility.
_TAPAS_AGGREGATORS = {
    "AVERAGE": "average",
    "AVG": "average",
    "SUM": "sum",
    "MAX": "max",
    "MIN": "min",
    "COUNT": "count",
}

# (label used in the response sentence, reducer over the numeric columns)
_NUMERIC_REDUCERS = {
    "average": ("average", lambda frame: frame.mean().mean()),
    "sum": ("sum", lambda frame: frame.sum().sum()),
    "max": ("maximum value", lambda frame: frame.max().max()),
    "min": ("minimum value", lambda frame: frame.min().min()),
}

# Characters trimmed from both ends of a column name typed inside a question.
_EDGE_PUNCTUATION = "?.!,;:\"'` \t"


def _load_table(uploaded_file):
    """Read an uploaded CSV or XLSX file into a DataFrame.

    Returns None when the file extension is neither .csv nor .xlsx.
    """
    lowered_name = uploaded_file.name.lower()
    if lowered_name.endswith('.csv'):
        # Semicolon-separated, Latin-1 encoded CSV, as in the original script.
        return pd.read_csv(uploaded_file, sep=';', encoding='ISO-8859-1')
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(uploaded_file, engine='openpyxl')
    return None


def _coerce_numeric_columns(frame):
    """Convert object columns that parse entirely as numbers; return frame.

    Columns that fail to parse are left unchanged. This replaces the
    deprecated ``pd.to_numeric(..., errors='ignore')``.
    """
    for col in frame.select_dtypes(include=['object']).columns:
        try:
            frame[col] = pd.to_numeric(frame[col])
        except (ValueError, TypeError):
            continue  # not a numeric column: keep the original values
    return frame


def _resolve_column(text, columns):
    """Return the column in ``columns`` named by ``text``, or None.

    Tries the text as typed, then with surrounding punctuation removed, each
    first exactly and then case-insensitively.
    """
    exact = {str(col).strip(): col for col in columns}
    folded = {str(col).strip().casefold(): col for col in columns}
    stripped = text.strip()
    for candidate in (stripped, stripped.strip(_EDGE_PUNCTUATION)):
        if candidate in exact:
            return exact[candidate]
        if candidate.casefold() in folded:
            return folded[candidate.casefold()]
    return None


def _find_column_pair(question, columns):
    """Return (x, y) columns from '... between X and Y', or None.

    Every ' and ' after 'between' is tried as the separator, so column names
    that themselves contain 'and' (e.g. 'brand') still resolve.
    """
    match = re.search(r"\bbetween\b(.*)", question, flags=re.IGNORECASE)
    if match is None:
        return None
    tail = match.group(1)
    for joiner in re.finditer(r"\s+and\s+", tail, flags=re.IGNORECASE):
        first = _resolve_column(tail[:joiner.start()], columns)
        second = _resolve_column(tail[joiner.end():], columns)
        if first is not None and second is not None:
            return first, second
    return None


def _find_single_column(question, columns):
    """Return (column, requested_text) for a single-column graph request.

    The text after the word 'column' is tried first, then the text after
    'of'. ``column`` is None when nothing resolves; ``requested_text`` is the
    first candidate seen (empty when the question has neither marker).
    """
    requested = ""
    for marker in (r"\bcolumn\b", r"\bof\b"):
        match = re.search(marker + r"(.*)", question, flags=re.IGNORECASE)
        if match is None:
            continue
        text = match.group(1)
        requested = requested or text.strip().strip(_EDGE_PUNCTUATION)
        resolved = _resolve_column(text, columns)
        if resolved is not None:
            return resolved, requested
    return None, requested


def _graph_sentence(question, typed_df):
    """Draw the requested graph from the typed data and describe the result."""
    if re.search(r"\bbetween\b", question, flags=re.IGNORECASE):
        pair = _find_column_pair(question, typed_df.columns)
        if pair is None:
            return ("Could not find two known columns in the question. "
                    "Please use the exact column names, e.g. "
                    "'make a graph between sales and expenses'.")
        x_col, y_col = pair
        fig = px.scatter(typed_df, x=x_col, y=y_col,
                         title=f"Graph between {x_col} and {y_col}")
        st.plotly_chart(fig, use_container_width=True)
        return f"Here is the graph between '{x_col}' and '{y_col}'."

    column, requested = _find_single_column(question, typed_df.columns)
    if column is not None:
        fig = px.line(typed_df, x=typed_df.index, y=column,
                      title=f"Graph of column '{column}'")
        st.plotly_chart(fig, use_container_width=True)
        return f"Here is the graph of column '{column}'."
    if requested:
        return f"Column '{requested}' not found in the data."
    return ("To request a graph, ask for 'a graph of column <name>' or "
            "'a graph between <x> and <y>'.")


def _detect_aggregation(question, aggregator):
    """Return the aggregation name requested by keyword or by TAPAS, or None."""
    lowered = question.lower()
    tapas_choice = _TAPAS_AGGREGATORS.get(str(aggregator).upper())
    for name, pattern in _AGGREGATION_PATTERNS.items():
        if pattern.search(lowered) or tapas_choice == name:
            return name
    return None


def _aggregation_sentence(kind, question, typed_df):
    """Aggregate over the whole typed table and phrase the result."""
    if kind == "count":
        # Counted on the typed frame: the string copy given to TAPAS turns
        # missing values into the text 'nan', which would be counted too.
        count_value = typed_df.count().sum()
        return f"The total count of non-null values for '{question}' is {count_value}."
    numeric = typed_df.select_dtypes(include='number')
    if numeric.empty:
        return "The data has no numeric values to aggregate."
    label, reducer = _NUMERIC_REDUCERS[kind]
    return f"The {label} for '{question}' is {reducer(numeric):.2f}."


def _build_response(question, raw_answer, typed_df):
    """Turn the question and the TAPAS result into the final sentence.

    Graph requests are handled first so that a column name containing an
    aggregation keyword, or an aggregator guessed by TAPAS, cannot divert them.
    """
    if re.search(r"\bgraphs?\b", question, flags=re.IGNORECASE):
        return _graph_sentence(question, typed_df)
    kind = _detect_aggregation(question, raw_answer.get('aggregator', ''))
    if kind is not None:
        return _aggregation_sentence(kind, question, typed_df)
    return f"The answer from TAPAS for '{question}' is {raw_answer['answer']}."


# File uploader in the sidebar
file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])

df = None          # all-string copy of the table, the form handed to TAPAS
df_numeric = None  # typed copy used for aggregation, counting and plotting

# File processing
if file_name is None:
    st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
else:
    try:
        loaded = _load_table(file_name)
        if loaded is None:
            st.error("Unsupported file type")
        else:
            df_numeric = _coerce_numeric_columns(loaded)
            st.write("Original Data:")
            st.write(df_numeric)
            df = df_numeric.astype(str)
            # Display the first 5 rows of the dataframe in an editable grid
            grid_response = AgGrid(
                df.head(5),
                columns_auto_size_mode='FIT_CONTENTS',
                editable=True,
                height=300,
                width='100%',
            )
    except Exception as e:  # UI boundary: report the failure, keep the app alive
        st.error(f"Error reading file: {str(e)}")

# Question answering is only offered once a table has been loaded, so the
# handler below never runs against an undefined table.
if df is not None:
    # User input for the question
    question = st.text_input('Type your question')

    if st.button('Answer'):
        if not question.strip():
            st.warning("Please type a question first.")
        else:
            # Spinner wraps the work itself so it is visible while TAPAS runs.
            with st.spinner():
                try:
                    raw_answer = tqa(table=df, query=question, truncation=True)
                    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
                                unsafe_allow_html=True)
                    st.success(raw_answer)

                    base_sentence = _build_response(question, raw_answer, df_numeric)

                    # Display the final response
                    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response: </p>", unsafe_allow_html=True)
                    st.success(base_sentence)
                except Exception as e:  # UI boundary: surface the error to the user
                    st.warning(f"Error processing question or generating answer: {str(e)}")
                    st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
|