File size: 6,843 Bytes
97b07c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98bef5d
97b07c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e69cc97
97b07c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e3aa94
 
 
97b07c2
be69684
1e3aa94
97b07c2
 
 
1e3aa94
 
 
 
 
be69684
 
 
 
 
 
 
97b07c2
be69684
 
 
 
 
 
97b07c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaa1fa5
 
97b07c2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import re

import pandas as pd
import plotly.express as px
import streamlit as st
from st_aggrid import AgGrid
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer

# Set the page layout for Streamlit
st.set_page_config(layout="wide")


# Streamlit re-executes this script on every widget interaction. Building the
# models inside st.cache_resource-decorated loaders lets reruns reuse the
# already-constructed objects instead of reloading the weights each time.
@st.cache_resource
def _load_tapas_pipeline():
    """Build the TAPAS table-question-answering pipeline (CPU)."""
    return pipeline(
        task="table-question-answering",
        model="google/tapas-large-finetuned-wtq",
        device="cpu",
    )


@st.cache_resource
def _load_t5():
    """Load the ``t5-small`` tokenizer and model and return them as a pair."""
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    return tokenizer, model


# Initialize TAPAS pipeline
tqa = _load_tapas_pipeline()

# Initialize T5 tokenizer and model for text generation
# NOTE(review): neither T5 object is referenced by the rest of the visible
# script; they are kept so the module-level names stay available.
t5_tokenizer, t5_model = _load_t5()

# Page header: title plus the usage notes shown above the data section.
# NOTE(review): the feature list below is user-facing copy; the visible script
# obtains answers from the TAPAS pipeline rather than calling DataFrame
# mean()/sum()/max()/min()/count() itself -- confirm the wording is intended.
APP_TITLE = "Table Question Answering and Data Analysis App"
INTRO_MARKDOWN = """ 
    This app allows you to upload a table (CSV or Excel) and ask questions about the data.
    Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.

    ### Available Features:
    - **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
    - **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
    - **max()**: For "max", it computes the maximum value in the DataFrame.
    - **min()**: For "min", it computes the minimum value in the DataFrame.
    - **count()**: For "count", it counts the non-null values in the entire DataFrame.
    - **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.
    
    Upload your data and ask questions to get both answers and visualizations.
"""

st.title(APP_TITLE)
st.markdown(INTRO_MARKDOWN)

# Sidebar upload widget; the value is None until the user picks a file.
ACCEPTED_FILE_TYPES = ['csv', 'xlsx']
file_name = st.sidebar.file_uploader("Upload file:", type=ACCEPTED_FILE_TYPES)

# File processing and question answering

# Characters stripped from both ends of a column name typed inside a question.
_NAME_PADDING = " \t\n?.!,;:'\""


def _load_table(uploaded_file):
    """Read the uploaded file into a DataFrame based on its extension.

    Returns the DataFrame, or None (after showing an error in the app) when
    the extension is neither ``.csv`` nor ``.xlsx``. Parser errors propagate
    to the caller.
    """
    # Compare in lower case so "DATA.CSV" is handled like "data.csv".
    lowered_name = uploaded_file.name.lower()
    if lowered_name.endswith('.csv'):
        # NOTE(review): the delimiter is fixed to ';' and the encoding to
        # ISO-8859-1; a comma-separated upload will load as a single column.
        return pd.read_csv(uploaded_file, sep=';', encoding='ISO-8859-1')
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(uploaded_file, engine='openpyxl')
    st.error("Unsupported file type")
    return None


def _coerce_numeric_columns(table):
    """Convert object-dtype columns to numeric where the whole column parses.

    Columns that cannot be parsed are left unchanged. This reproduces the
    deprecated ``pd.to_numeric(..., errors='ignore')`` behaviour explicitly.
    The frame is modified in place and also returned.
    """
    for col in table.select_dtypes(include=['object']).columns:
        try:
            table[col] = pd.to_numeric(table[col])
        except (ValueError, TypeError):
            continue
    return table


def _resolve_column(name, columns):
    """Map a column name typed in a question to the real column label.

    An exact match wins. Otherwise surrounding punctuation is stripped and the
    comparison is case-insensitive. Returns None when nothing matches.
    """
    raw = name.strip()
    if raw in columns:
        return raw
    cleaned = raw.strip(_NAME_PADDING).lower()
    if not cleaned:
        return None
    for col in columns:
        if str(col).strip().lower() == cleaned:
            return col
    return None


def _columns_between(question, columns):
    """Return ``[x, y]`` labels for a "... between X and Y" question, else None."""
    # Whole-word, case-insensitive splits so names such as "brand" or
    # "Between-group" are not cut apart by the keywords themselves.
    tail = re.split(r"\bbetween\b", question, flags=re.IGNORECASE)[-1]
    parts = re.split(r"\band\b", tail, flags=re.IGNORECASE)
    if len(parts) != 2:
        return None
    resolved = [_resolve_column(part, columns) for part in parts]
    if any(label is None for label in resolved):
        return None
    return resolved


def _column_after_of(question, columns):
    """Return the label named in a "... of [column] X" question, else None."""
    after_of = re.split(r"\bof\b", question, flags=re.IGNORECASE)[-1]
    candidates = (
        after_of,
        # "graph of column sales" -> "sales"
        re.sub(r"^\s*(?:the\s+)?column\s+", "", after_of, flags=re.IGNORECASE),
        # "plot a graph for column sales" -> "sales"
        re.split(r"\bcolumn\b", question, flags=re.IGNORECASE)[-1],
    )
    for candidate in candidates:
        label = _resolve_column(candidate, columns)
        if label is not None:
            return label
    return None


def _show_table_answer(question, table):
    """Run TAPAS on the string-typed table and render its answer.

    For questions containing "average", the numeric cells TAPAS selected are
    also drawn as a line chart.
    """
    raw_answer = tqa(table=table, query=question, truncation=True)

    # Display raw answer from TAPAS on the screen
    st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Raw TAPAS Answer: </p>", unsafe_allow_html=True)
    st.write(raw_answer)

    # Extract relevant values for Plotly
    answer = raw_answer.get('answer', '')
    coordinates = raw_answer.get('coordinates', [])
    cells = raw_answer.get('cells', [])

    st.markdown("<p style='font-family:sans-serif;font-size: 1rem;'>Relevant Data for Plotly: </p>", unsafe_allow_html=True)
    st.write(f"Answer: {answer}")
    st.write(f"Coordinates: {coordinates}")
    st.write(f"Cells: {cells}")

    plot_values = []
    if "average" in question.lower() and cells:
        # Cells arrive as strings; drop the ones that are not numbers instead
        # of letting a single bad cell abort the whole answer.
        plot_values = pd.to_numeric(pd.Series(cells), errors='coerce').dropna().tolist()

    if plot_values:
        plot_df = pd.DataFrame({
            'Index': list(range(1, len(plot_values) + 1)),
            'Value': plot_values,
        })
        fig = px.line(plot_df, x='Index', y='Value', title=f"Graph for '{question}'")
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.write(f"No data to plot for the question: '{question}'")


def _show_graph(question, table):
    """Render the chart requested by a graph question, or warn if unparseable.

    ``table`` should be the typed (non-stringified) frame so numeric columns
    get numeric axes.
    """
    lowered = question.lower()
    if re.search(r"\bbetween\b", lowered) and re.search(r"\band\b", lowered):
        pair = _columns_between(question, table.columns)
        if pair is None:
            st.warning("Columns not found in the dataset.")
            return
        x_col, y_col = pair
        fig = px.scatter(table, x=x_col, y=y_col, title=f"Graph between {x_col} and {y_col}")
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Here is the graph between '{x_col}' and '{y_col}'.")
    elif 'column' in lowered:
        column = _column_after_of(question, table.columns)
        if column is None:
            st.warning("Column not found in the dataset.")
            return
        fig = px.line(table, x=table.index, y=column, title=f"Graph of column '{column}'")
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.warning(
            'Could not tell which column(s) to plot. Try "make a graph of column sales" '
            'or "make a graph between sales and expenses".'
        )


if file_name is None:
    st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
else:
    df = None          # all-string copy of the table, which is what TAPAS is given
    df_numeric = None  # typed copy of the table, used for plotting

    try:
        loaded = _load_table(file_name)
        if loaded is not None:
            df_numeric = _coerce_numeric_columns(loaded)

            st.write("Original Data:")
            st.write(df_numeric)

            df = df_numeric.astype(str)

            # Display the first 5 rows of the dataframe in an editable grid
            AgGrid(
                df.head(5),
                fit_columns_on_grid_load=True,
                editable=True,
                height=300,
                width='100%',
            )
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")

    # Only offer the question box once a table is actually available;
    # otherwise every question would fail on a missing table.
    if df is not None:
        question = st.text_input('Type your question')
        is_graph_query = 'graph' in question.lower()

        if st.button('Answer'):
            if not question.strip():
                st.warning("Please type a question first.")
            else:
                # Keep the spinner around the slow work only.
                with st.spinner():
                    try:
                        if is_graph_query:
                            _show_graph(question, df_numeric)
                        else:
                            _show_table_answer(question, df)
                    except Exception as e:
                        st.warning(f"Error processing question or generating answer: {str(e)}")
                        st.warning("Please retype your question and make sure to use the column name and cell value correctly.")