Spaces:
Sleeping
Sleeping
File size: 6,391 Bytes
2e654c9 7797cc9 4665c84 879d20e 4665c84 7797cc9 4665c84 7797cc9 4665c84 7797cc9 4665c84 7797cc9 879d20e fcf406d 7797cc9 fcf406d 7797cc9 fcf406d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import pandas as pd
import streamlit as st
from transformers import TapasForQuestionAnswering, TapasTokenizer, T5ForConditionalGeneration, T5Tokenizer
import torch
# Assuming df is uploaded or pre-defined (you can replace with actual data loading logic)
# Example DataFrame (replace with your actual file or data)
# Example DataFrame standing in for user-supplied data (swap in real
# data loading — file upload, database, etc. — when wiring this up).
data = {
    'Column1': [1, 2, 3, 4],
    'Column2': [5.5, 6.5, 7.5, 8.5],
    'Column3': ['a', 'b', 'c', 'd'],
}
df = pd.DataFrame(data)

# Keep only the numeric columns for later aggregate calculations; fall
# back to an empty frame when no usable table was provided.
if df is None or df.empty:
    df_numeric = pd.DataFrame()
else:
    df_numeric = df.select_dtypes(include='number')
# Model loading. In Streamlit the whole script reruns on every widget
# interaction, so module-level `from_pretrained` calls would reload both
# (large) checkpoints on each rerun. `st.cache_resource` keeps one shared
# instance per server process.
@st.cache_resource
def _load_tapas():
    """Load the TAPAS table-QA model and tokenizer (WTQ fine-tuned)."""
    model = TapasForQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
    tokenizer = TapasTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
    return model, tokenizer


@st.cache_resource
def _load_t5():
    """Load the T5 model and tokenizer used to rephrase answers."""
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    return model, tokenizer


tqa_model, tqa_tokenizer = _load_tapas()
t5_model, t5_tokenizer = _load_t5()

# User input for the question
question = st.text_input('Type your question')
# --- Question-answering pipeline: TAPAS -> T5 rephrase -> pandas check ------


def _aggregate_value(series, summary_type):
    """Return one summary statistic of a numeric Series.

    Returns None when `summary_type` is not a supported statistic so the
    caller can fall back to the raw model answer.
    """
    operations = {
        'sum': series.sum,
        'max': series.max,
        'min': series.min,
        'average': series.mean,
        'count': series.count,
        'median': series.median,
        'std_dev': series.std,
    }
    operation = operations.get(summary_type)
    return operation() if operation is not None else None


def _phrase_answer(summary_type, column_name, value):
    """Render a computed statistic as a natural-language sentence."""
    # NOTE: keys here match `summary_type` values ('max'/'min'); the
    # original chain tested 'maximum'/'minimum' and never matched.
    templates = {
        'sum': "The total {col} is {val}.",
        'max': "The highest {col} is {val}.",
        'min': "The lowest {col} is {val}.",
        'average': "The average {col} is {val}.",
        'count': "The number of entries in {col} is {val}.",
        'median': "The median {col} is {val}.",
        'std_dev': "The standard deviation of {col} is {val}.",
    }
    template = templates.get(summary_type, "The value for {col} is {val}.")
    return template.format(col=column_name, val=value)


with st.spinner():
    if st.button('Answer'):
        # Defaults so the post-processing below never hits an undefined
        # name when the TAPAS/T5 stage fails part-way through (the
        # original code could NameError on `aggregator`/`answer` here).
        answer = None
        aggregator = "average"
        coordinates = []
        cells = []
        try:
            # TapasTokenizer requires every table cell to be a string.
            inputs = tqa_tokenizer(table=df.astype(str), query=question, return_tensors="pt")
            with torch.no_grad():
                outputs = tqa_model(**inputs)

            # Decode the cell-selection and aggregation heads with the
            # tokenizer's dedicated API. (Decoding `logits.argmax()` as
            # token ids — the original approach — is meaningless: TAPAS
            # logits score table cells, not vocabulary tokens.)
            pred_coordinates, pred_aggregation = tqa_tokenizer.convert_logits_to_predictions(
                inputs,
                outputs.logits.detach(),
                outputs.logits_aggregation.detach(),
            )
            coordinates = pred_coordinates[0]
            cells = [df.iat[row, col] for row, col in coordinates]
            # WTQ aggregation head indices: 0=NONE, 1=SUM, 2=AVERAGE, 3=COUNT.
            agg_names = {0: "none", 1: "sum", 2: "average", 3: "count"}
            aggregator = agg_names.get(pred_aggregation[0], "none")

            raw_answer = ", ".join(str(cell) for cell in cells)
            st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
            st.success(raw_answer)
            answer = raw_answer

            # Construct a base sentence for the rephrasing model.
            base_sentence = f"The {question.lower()} of the selected data is {answer}."
            if coordinates and cells:
                rows_info = [
                    f"Row {row + 1}, Column '{df.columns[col]}' with value {cell}"
                    for (row, col), cell in zip(coordinates, cells)
                ]
                base_sentence += f" This includes the following data: {' and '.join(rows_info)}."

            # Generate a fluent response by rephrasing the base sentence with T5.
            input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
            t5_input_ids = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
            summary_ids = t5_model.generate(t5_input_ids, max_length=150, num_beams=4, early_stopping=True)
            generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

            st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
            st.success(generated_text)
        except Exception:
            st.warning("Please retype your question and make sure to use the column name and cell value correctly.")

        # Column the pandas re-computation should run over.
        column_name = st.selectbox("Select a column", df.columns)

        # Override the aggregator when the question names a statistic the
        # TAPAS WTQ head cannot predict (it only knows NONE/SUM/AVERAGE/COUNT).
        question_upper = question.upper()
        if 'MEDIAN' in question_upper and 'AVERAGE' in aggregator.upper():
            aggregator = 'MEDIAN'
        elif 'MIN' in question_upper and 'AVERAGE' in aggregator.upper():
            aggregator = 'MIN'
        elif 'MAX' in question_upper and 'AVERAGE' in aggregator.upper():
            aggregator = 'MAX'
        elif 'TOTAL' in question_upper and 'SUM' in aggregator.upper():
            aggregator = 'SUM'
        summary_type = aggregator.lower()

        # Recompute the exact value with pandas so the displayed number
        # does not depend on model decoding.
        if column_name and column_name in df_numeric.columns:
            numeric_value = _aggregate_value(df_numeric[column_name], summary_type)
            if numeric_value is None:
                numeric_value = answer  # Fallback if something went wrong
        else:
            numeric_value = "Invalid column"

        natural_language_answer = _phrase_answer(summary_type, column_name, numeric_value)

        # Display the result to the user
        st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
        st.success(f"""
        • Answer: {natural_language_answer}
        Data Location:
        • Column: {column_name}
        Additional Context:
        • Query Asked: "{question}"
        """)
|