import pandas as pd
import streamlit as st

# `tapas` is assumed to be a local helper module that exposes a TAPAS
# table-question-answering pipeline (`tqa`) plus the T5 tokenizer and model
# used to rephrase the answer.
from tapas import tqa, t5_tokenizer, t5_model
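
# A minimal sketch of what such a `tapas.py` helper could look like, built on the
# Hugging Face transformers API (model names here are illustrative assumptions,
# not taken from the original app):
#
#     from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
#
#     tqa = pipeline("table-question-answering",
#                    model="google/tapas-base-finetuned-wtq")
#     t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
#     t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")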

# Assuming 'df' is the DataFrame you are using and has numeric columns
df_numeric = df.select_dtypes(include='number')

# `column_name` is resolved later from the TAPAS coordinates; None means "not yet known"
column_name = None

# User input for the question
question = st.text_input('Type your question')

# Process the answer using TAPAS and T5
with st.spinner():
    if st.button('Answer'):
        try:
            # Get the raw answer from TAPAS (the TAPAS tokenizer expects string-valued cells)
            raw_answer = tqa(table=df.astype(str), query=question, truncation=True)

            st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
            st.success(raw_answer)
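
            # For reference, the table-QA pipeline returns a dict along these lines
            # (the values shown are illustrative, not real output from this app):
            #   {'answer': 'SUM > 110, 120',
            #    'coordinates': [(0, 3), (2, 3)],
            #    'cells': ['110', '120'],
            #    'aggregator': 'SUM'}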

            # Extract relevant information from the TAPAS result
            answer = raw_answer['answer']
            aggregator = raw_answer.get('aggregator', '')
            coordinates = raw_answer.get('coordinates', [])
            cells = raw_answer.get('cells', [])

            # Extract the column name based on coordinates
            if coordinates:
                row, col = coordinates[0]  # assuming single cell result
                column_name = df.columns[col]  # Get the column name

            # The raw TAPAS answer may carry an aggregator prefix (e.g. "SUM > 110, 120"),
            # so strip it before building the base sentence around the query term
            clean_answer = answer.split('>', 1)[-1].strip() if '>' in answer else answer
            base_sentence = f"The {question.lower()} of the selected data is {clean_answer}."
            if coordinates and cells:
                rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
                             for coordinate, cell in zip(coordinates, cells)]
                rows_description = " and ".join(rows_info)
                base_sentence += f" This includes the following data: {rows_description}."

            # Generate a fluent response using the T5 model, rephrasing the base sentence
            input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"

            # Tokenize the input and generate a fluent response using T5
            inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
            summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)

            # Decode the generated text
            generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

            # Display the final generated response
            st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
            st.success(generated_text)

            # Post-process the answer: correct the aggregator, then recompute the value with pandas
            # (this must run inside the button handler so `aggregator` and `answer` are defined)

            # Manually fix the aggregator when TAPAS picks the wrong one for the question
            if 'MEDIAN' in question.upper() and 'AVERAGE' in aggregator.upper():
                aggregator = 'MEDIAN'
            elif 'MIN' in question.upper() and 'AVERAGE' in aggregator.upper():
                aggregator = 'MIN'
            elif 'MAX' in question.upper() and 'AVERAGE' in aggregator.upper():
                aggregator = 'MAX'
            elif 'TOTAL' in question.upper() and 'SUM' in aggregator.upper():
                aggregator = 'SUM'

            # Use the corrected aggregator for further processing
            summary_type = aggregator.lower()

            # Check that `column_name` refers to a numeric column before aggregating
            if column_name and column_name in df_numeric.columns:
                # Recalculate the value with pandas using the corrected aggregator
                if summary_type == 'sum':
                    numeric_value = df_numeric[column_name].sum()
                elif summary_type == 'max':
                    numeric_value = df_numeric[column_name].max()
                elif summary_type == 'min':
                    numeric_value = df_numeric[column_name].min()
                elif summary_type == 'average':
                    numeric_value = df_numeric[column_name].mean()
                elif summary_type == 'count':
                    numeric_value = df_numeric[column_name].count()
                elif summary_type == 'median':
                    numeric_value = df_numeric[column_name].median()
                elif summary_type == 'std_dev':
                    numeric_value = df_numeric[column_name].std()
                else:
                    numeric_value = answer  # Fall back to the raw TAPAS answer
            else:
                numeric_value = "Invalid column"

            # Construct a natural language response
            if summary_type == 'sum':
                natural_language_answer = f"The total {column_name} is {numeric_value}."
            elif summary_type == 'max':
                natural_language_answer = f"The highest {column_name} is {numeric_value}."
            elif summary_type == 'min':
                natural_language_answer = f"The lowest {column_name} is {numeric_value}."
            elif summary_type == 'average':
                natural_language_answer = f"The average {column_name} is {numeric_value}."
            elif summary_type == 'count':
                natural_language_answer = f"The number of entries in {column_name} is {numeric_value}."
            elif summary_type == 'median':
                natural_language_answer = f"The median {column_name} is {numeric_value}."
            elif summary_type == 'std_dev':
                natural_language_answer = f"The standard deviation of {column_name} is {numeric_value}."
            else:
                natural_language_answer = f"The value for {column_name} is {numeric_value}."

            # Display the recalculated result to the user
            st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
            st.success(f"""
                • Answer: {natural_language_answer}

                Data Location:
                • Column: {column_name}

                Additional Context:
                • Query Asked: "{question}"
            """)

        except Exception as e:
            st.warning("Please retype your question and make sure to use the column name and cell value correctly.")