File size: 5,774 Bytes
2e654c9
7797cc9
4665c84
 
 
 
 
 
 
 
 
 
7797cc9
 
 
 
 
 
 
 
 
 
 
 
4665c84
 
 
 
7797cc9
 
 
 
 
4665c84
 
 
 
7797cc9
 
 
 
4665c84
7797cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcf406d
 
 
 
 
 
 
 
 
 
 
 
 
7797cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcf406d
7797cc9
fcf406d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import re

import pandas as pd
import streamlit as st
import torch
from transformers import TapasForQuestionAnswering, TapasTokenizer, T5ForConditionalGeneration, T5Tokenizer

# Hugging Face checkpoints used by this page.
TAPAS_CHECKPOINT = "google/tapas-large-finetuned-wtq"
T5_CHECKPOINT = "t5-small"

# Aggregation-head labels of the WTQ-fine-tuned TAPAS checkpoint
# (index -> name, as listed in the Hugging Face TAPAS documentation).
ID2AGGREGATION = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}

# Manual aggregator fixes: when TAPAS predicts AVERAGE but the question names
# one of these statistics (as a whole word), trust the question instead.
AGGREGATOR_OVERRIDES = (
    ({"median"}, "MEDIAN"),
    ({"min", "minimum"}, "MIN"),
    ({"max", "maximum"}, "MAX"),
    ({"total", "sum"}, "SUM"),
)

# pandas Series method used to recompute each summary type.
PANDAS_METHODS = {
    "sum": "sum",
    "max": "max",
    "min": "min",
    "average": "mean",
    "count": "count",
    "median": "median",
    "std_dev": "std",
}

# Sentence used to present each summary type; keys match PANDAS_METHODS so the
# numeric and textual branches can never disagree about a type's name.
ANSWER_TEMPLATES = {
    "sum": "The total {column} is {value}.",
    "max": "The highest {column} is {value}.",
    "min": "The lowest {column} is {value}.",
    "average": "The average {column} is {value}.",
    "count": "The number of entries in {column} is {value}.",
    "median": "The median {column} is {value}.",
    "std_dev": "The standard deviation of {column} is {value}.",
}
DEFAULT_TEMPLATE = "The value for {column} is {value}."

RETRY_MESSAGE = "Please retype your question and make sure to use the column name and cell value correctly."


@st.cache_resource
def _load_tapas():
    """Load the TAPAS model and tokenizer once per server process.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached load would reload the checkpoint on each rerun.
    """
    model = TapasForQuestionAnswering.from_pretrained(TAPAS_CHECKPOINT)
    tokenizer = TapasTokenizer.from_pretrained(TAPAS_CHECKPOINT)
    return model, tokenizer


@st.cache_resource
def _load_t5():
    """Load the T5 model and tokenizer used for rephrasing, once per process."""
    model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT)
    tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
    return model, tokenizer


# Load TAPAS model and tokenizer
tqa_model, tqa_tokenizer = _load_tapas()

# Load T5 model and tokenizer for rephrasing
t5_model, t5_tokenizer = _load_t5()

# Assuming 'df' is the DataFrame you are using and has numeric columns.
# NOTE(review): `df` is not defined in this script; it must be created before
# this point (e.g. by code that loads the user's data). Confirm.
df_numeric = df.select_dtypes(include='number')


def _run_tapas(table, query):
    """Answer ``query`` over ``table`` with TAPAS.

    Returns:
        (coordinates, cells, aggregator): the selected (row, column) positions
        in ``table``, the matching cell values as strings, and the predicted
        aggregation label (one of ``ID2AGGREGATION``'s values).
    """
    # TapasTokenizer only accepts text, so stringify cells and column labels.
    table_text = table.astype(str).rename(columns=str)
    inputs = tqa_tokenizer(
        table=table_text, queries=query, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = tqa_model(**inputs)
    # The model emits per-token cell-selection logits and aggregation logits,
    # not vocabulary logits, so they are mapped back to table cells rather
    # than decoded as text.
    coordinates_batch, aggregation_batch = tqa_tokenizer.convert_logits_to_predictions(
        inputs, outputs.logits.detach(), outputs.logits_aggregation.detach()
    )
    coordinates = coordinates_batch[0]
    cells = [table_text.iat[coordinate] for coordinate in coordinates]
    aggregator = ID2AGGREGATION.get(int(aggregation_batch[0]), "NONE")
    return coordinates, cells, aggregator


def _correct_aggregator(query, aggregator):
    """Manually fix the aggregator if it returns an incorrect one.

    Only an AVERAGE prediction is overridden, and only when the question
    contains one of the ``AGGREGATOR_OVERRIDES`` keywords as a whole word
    (so e.g. "minutes" does not trigger MIN).
    """
    if aggregator.upper() != "AVERAGE":
        return aggregator
    words = set(re.findall(r"[a-z]+", query.lower()))
    for keywords, replacement in AGGREGATOR_OVERRIDES:
        if words & keywords:
            return replacement
    return aggregator


def _recompute(summary_type, column_name, rows, fallback):
    """Recompute the answer with pandas over the rows TAPAS selected.

    ``count`` works on any column; the other statistics need a numeric column.
    Returns ``fallback`` (the TAPAS answer) when there is no pandas equivalent
    for ``summary_type`` (e.g. "none") or the column is not numeric.
    """
    method = PANDAS_METHODS.get(summary_type)
    if method is None:
        return fallback
    source = df if summary_type == "count" else df_numeric
    if column_name not in source.columns:
        return fallback
    return getattr(source[column_name].iloc[rows], method)()


def _rephrase(query, answer, coordinates, cells):
    """Build a base sentence from the TAPAS answer and rephrase it with T5."""
    base_sentence = f"The {query.lower()} of the selected data is {answer}."
    if coordinates and cells:
        rows_info = [
            f"Row {row + 1}, Column '{df.columns[column]}' with value {cell}"
            for (row, column), cell in zip(coordinates, cells)
        ]
        rows_description = " and ".join(rows_info)
        base_sentence += f" This includes the following data: {rows_description}."

    input_text = f"Given the question: '{query}', generate a more human-readable response: {base_sentence}"
    input_ids = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = t5_model.generate(input_ids, max_length=150, num_beams=4, early_stopping=True)
    return t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def _answer_question(query):
    """Run the TAPAS -> T5 -> pandas pipeline for ``query`` and render results."""
    coordinates, cells, aggregator = _run_tapas(df, query)
    if not coordinates:
        # TAPAS selected no cells, so there is nothing to aggregate or show.
        st.warning(RETRY_MESSAGE)
        return

    answer = ", ".join(cells)
    raw_answer = answer if aggregator == "NONE" else f"{aggregator} > {answer}"
    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
    st.success(raw_answer)

    generated_text = _rephrase(query, answer, coordinates, cells)
    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
    st.success(generated_text)

    # Use the corrected aggregator for further processing.
    summary_type = _correct_aggregator(query, aggregator).lower()

    # The answer column is the column of the first selected cell; aggregate
    # only over the selected cells that lie in that column.
    column_index = coordinates[0][1]
    column_name = df.columns[column_index]
    rows = [row for row, column in coordinates if column == column_index]

    numeric_value = _recompute(summary_type, column_name, rows, answer)
    template = ANSWER_TEMPLATES.get(summary_type, DEFAULT_TEMPLATE)
    natural_language_answer = template.format(column=column_name, value=numeric_value)

    # Display the result to the user
    st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
    st.success(f"""
    • Answer: {natural_language_answer}

    Data Location:
    • Column: {column_name}

    Additional Context:
    • Query Asked: "{query}"
""")


# User input for the question
question = st.text_input('Type your question')

# Everything downstream depends on a TAPAS answer, so it runs only when the
# button is pressed (and the spinner covers the actual work).
if st.button('Answer'):
    if not question.strip():
        st.warning(RETRY_MESSAGE)
    else:
        with st.spinner():
            try:
                _answer_question(question)
            except Exception as error:  # UI boundary: report instead of crashing
                st.warning(RETRY_MESSAGE)
                st.caption(f"Details: {type(error).__name__}: {error}")