File size: 6,391 Bytes
2e654c9
7797cc9
4665c84
 
 
879d20e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4665c84
 
 
 
 
 
 
7797cc9
 
 
 
 
 
 
 
 
4665c84
 
 
 
7797cc9
 
 
 
 
4665c84
 
 
 
7797cc9
 
 
 
4665c84
7797cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
879d20e
 
 
 
fcf406d
 
 
 
 
 
 
 
 
 
 
 
 
7797cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcf406d
7797cc9
fcf406d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import logging

import pandas as pd
import streamlit as st
import torch
from transformers import TapasForQuestionAnswering, TapasTokenizer, T5ForConditionalGeneration, T5Tokenizer

# Placeholder table for the demo; swap this for real data loading
# (for example an uploaded file) when wiring the app to actual input.
data = dict(
    Column1=[1, 2, 3, 4],
    Column2=[5.5, 6.5, 7.5, 8.5],
    Column3=['a', 'b', 'c', 'd'],
)
df = pd.DataFrame(data)

# Numeric-only view of the table. A missing or empty table yields an empty
# frame, so later column lookups against it simply find nothing.
df_numeric = (
    df.select_dtypes(include='number')
    if df is not None and not df.empty
    else pd.DataFrame()
)

# NOTE(review): st.cache_resource needs Streamlit >= 1.18 - confirm the
# deployed version before merging.
@st.cache_resource
def _load_models():
    """Load the TAPAS and T5 models together with their tokenizers.

    Streamlit re-executes this script from top to bottom on every user
    interaction. Caching the loader keeps one loaded copy per server process
    instead of re-reading the checkpoints on each rerun.

    Returns:
        Tuple ``(tqa_model, tqa_tokenizer, t5_model, t5_tokenizer)``.
    """
    tapas_checkpoint = "google/tapas-large-finetuned-wtq"
    t5_checkpoint = "t5-small"
    return (
        TapasForQuestionAnswering.from_pretrained(tapas_checkpoint),
        TapasTokenizer.from_pretrained(tapas_checkpoint),
        T5ForConditionalGeneration.from_pretrained(t5_checkpoint),
        T5Tokenizer.from_pretrained(t5_checkpoint),
    )


# TAPAS answers the question over the table; T5 rephrases the result.
tqa_model, tqa_tokenizer, t5_model, t5_tokenizer = _load_models()

# User input for the question
question = st.text_input('Type your question')

# Values the rest of the script reads after this section. They are bound up
# front because the code below this section uses them on every run, while the
# button branch only executes (and only completes) on a successful click.
answer = ""
aggregator = "average"
coordinates = []
cells = []

# Aggregation label for each index predicted by the TAPAS aggregation head
# (mapping as given in the Hugging Face TAPAS usage docs for WTQ checkpoints).
TAPAS_AGGREGATION_LABELS = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}


def _ask_tapas(table, query):
    """Run TAPAS on *table* for *query* and decode its prediction.

    Args:
        table: pandas DataFrame to query. Cells are converted to strings
            first because the TAPAS tokenizer only accepts text cells.
        query: The natural-language question.

    Returns:
        Tuple ``(answer, aggregator, coordinates, cells)``: ``coordinates`` is
        the list of predicted ``(row, column)`` positions, ``cells`` the
        matching cell texts, ``answer`` those texts joined with ``", "`` and
        ``aggregator`` one of the ``TAPAS_AGGREGATION_LABELS`` values.
    """
    text_table = table.astype(str)
    # The tokenizer's keyword for the question is `queries`.
    encoded = tqa_tokenizer(table=text_table, queries=query, return_tensors="pt")
    with torch.no_grad():
        outputs = tqa_model(**encoded)

    # The logits score table cells rather than vocabulary tokens, so they are
    # converted to cell coordinates instead of being decoded as token ids.
    predicted_coordinates, predicted_aggregations = tqa_tokenizer.convert_logits_to_predictions(
        encoded, outputs.logits.detach(), outputs.logits_aggregation.detach()
    )
    cell_coordinates = predicted_coordinates[0]  # single query -> first batch entry
    cell_texts = [text_table.iat[coordinate] for coordinate in cell_coordinates]
    label = TAPAS_AGGREGATION_LABELS.get(int(predicted_aggregations[0]), "NONE")
    return ", ".join(cell_texts), label, cell_coordinates, cell_texts


def _rephrase_with_t5(query, base_sentence):
    """Ask T5 to rephrase *base_sentence* as a response to *query*.

    Returns:
        The decoded text generated by the T5 model.
    """
    input_text = f"Given the question: '{query}', generate a more human-readable response: {base_sentence}"
    input_ids = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = t5_model.generate(input_ids, max_length=150, num_beams=4, early_stopping=True)
    return t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


# Process the answer using TAPAS and T5
if st.button('Answer'):
    if not question.strip():
        st.warning("Please type a question before pressing Answer.")
    else:
        try:
            # The spinner wraps only the model work, not the button itself.
            with st.spinner():
                answer, aggregator, coordinates, cells = _ask_tapas(df, question)
                raw_answer = answer if aggregator == "NONE" else f"{aggregator} > {answer}"

                st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
                st.success(raw_answer)

                # Base sentence for T5, extended with the selected cells when
                # TAPAS picked any.
                base_sentence = f"The {question.lower()} of the selected data is {answer}."
                if coordinates and cells:
                    rows_info = [
                        f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
                        for coordinate, cell in zip(coordinates, cells)
                    ]
                    rows_description = " and ".join(rows_info)
                    base_sentence += f" This includes the following data: {rows_description}."

                generated_text = _rephrase_with_t5(question, base_sentence)

                # Display the final generated response
                st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
                st.success(generated_text)

        except Exception:
            # UI boundary: keep the app alive and show the friendly hint, but
            # record the real error so failures can be diagnosed.
            logging.getLogger(__name__).exception("Question answering failed for %r", question)
            st.warning("Please retype your question and make sure to use the column name and cell value correctly.")


# Column the aggregate is computed over, chosen by the user from every column
# of the table (adjust this if the app obtains the column some other way).
column_name = st.selectbox(label="Select a column", options=df.columns)

def _correct_aggregator(question_text, predicted):
    """Return *predicted* adjusted by aggregation keywords in the question.

    Matching is a case-insensitive substring test on both arguments. Rules are
    tried in order and only the first match applies; when none matches,
    *predicted* is returned unchanged.

    Args:
        question_text: The user's question.
        predicted: The aggregator name reported for the question.

    Returns:
        The replacement aggregator name, or *predicted* itself.
    """
    # (keyword in the question, aggregators it may replace, replacement)
    rules = (
        ('MEDIAN', ('AVERAGE',), 'MEDIAN'),
        ('MIN', ('AVERAGE',), 'MIN'),
        ('MAX', ('AVERAGE',), 'MAX'),
        # Accepting only 'SUM' here would make this rule a no-op, so 'AVERAGE'
        # is replaced as well, in line with the other rules.
        ('TOTAL', ('AVERAGE', 'SUM'), 'SUM'),
    )
    asked = question_text.upper()
    reported = predicted.upper()
    for keyword, replaceable, replacement in rules:
        if keyword in asked and any(name in reported for name in replaceable):
            return replacement
    return predicted


# Manually fix the aggregator if it returns an incorrect one
aggregator = _correct_aggregator(question, aggregator)

# Use the corrected aggregator for further processing
summary_type = aggregator.lower()

# Name of the pandas Series method that implements each summary type.
_SERIES_REDUCERS = {
    'sum': 'sum',
    'max': 'max',
    'min': 'min',
    'average': 'mean',
    'count': 'count',
    'median': 'median',
    'std_dev': 'std',
}

# Only columns present in the numeric view can be aggregated; anything else
# is reported as invalid.
if not column_name or column_name not in df_numeric.columns:
    numeric_value = "Invalid column"
else:
    reducer_name = _SERIES_REDUCERS.get(summary_type)
    if reducer_name is None:
        # Unrecognised summary type: fall back to the `answer` value.
        numeric_value = answer
    else:
        numeric_value = getattr(df_numeric[column_name], reducer_name)()

# Sentence template for each summary type. The short forms 'max' / 'min' are
# the values `summary_type` actually takes in this script; the long forms
# 'maximum' / 'minimum' are kept so those spellings keep working too.
_ANSWER_TEMPLATES = {
    'sum': "The total {column} is {value}.",
    'max': "The highest {column} is {value}.",
    'maximum': "The highest {column} is {value}.",
    'min': "The lowest {column} is {value}.",
    'minimum': "The lowest {column} is {value}.",
    'average': "The average {column} is {value}.",
    'count': "The number of entries in {column} is {value}.",
    'median': "The median {column} is {value}.",
    'std_dev': "The standard deviation of {column} is {value}.",
}

# Construct a natural language response; unknown summary types get a generic
# sentence.
natural_language_answer = _ANSWER_TEMPLATES.get(
    summary_type, "The value for {column} is {value}."
).format(column=column_name, value=numeric_value)

# Render the analysis: a small HTML heading, then a success box summarising
# the answer, the column it refers to and the query that was asked.
results_heading = "<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>"
st.markdown(results_heading, unsafe_allow_html=True)

results_body = f"""
    • Answer: {natural_language_answer}

    Data Location:
    • Column: {column_name}

    Additional Context:
    • Query Asked: "{question}"
"""
st.success(results_body)