adnanaman committed on
Commit 2d7e5a3 · verified · 1 Parent(s): ed068a1

Upload 2 files

Files changed (2)
  1. app.py +209 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,209 @@
+ # streamlit_tapex_app.py
+
+ import streamlit as st
+ import pandas as pd
+ import torch
+ from transformers import TapexTokenizer, BartForConditionalGeneration
+ import xml.etree.ElementTree as ET
+ from io import StringIO
+ import logging
+ from datetime import datetime
+ import time
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ @st.cache_resource
+ def load_model():
+     """
+     Load and cache the TAPEX model and tokenizer using Streamlit's caching
+     """
+     try:
+         tokenizer = TapexTokenizer.from_pretrained(
+             "microsoft/tapex-large-finetuned-wtq",
+             model_max_length=1024
+         )
+         model = BartForConditionalGeneration.from_pretrained(
+             "microsoft/tapex-large-finetuned-wtq"
+         )
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         model = model.to(device)
+         model.eval()
+         return tokenizer, model
+     except Exception as e:
+         st.error(f"Error loading model: {str(e)}")
+         return None, None
+
+ def parse_xml_to_dataframe(xml_string: str):
+     """
+     Parse XML string to DataFrame with error handling
+     """
+     try:
+         tree = ET.parse(StringIO(xml_string))
+         root = tree.getroot()
+
+         data = []
+         columns = set()
+
+         # First pass: collect all possible columns
+         for record in root.findall('.//record'):
+             columns.update(elem.tag for elem in record)
+
+         # Second pass: create data rows
+         for record in root.findall('.//record'):
+             row_data = {col: None for col in columns}
+             for elem in record:
+                 row_data[elem.tag] = elem.text
+             data.append(row_data)
+
+         df = pd.DataFrame(data)
+
+         # Convert numeric columns (automatically detect)
+         for col in df.columns:
+             try:
+                 df[col] = pd.to_numeric(df[col])
+             except (ValueError, TypeError):
+                 continue
+
+         return df, None
+     except Exception as e:
+         return None, f"Error parsing XML: {str(e)}"
+
+ def process_query(tokenizer, model, df, query: str):
+     """
+     Process a single query using the TAPEX model
+     """
+     try:
+         start_time = time.time()
+
+         # Handle direct DataFrame operations for common queries
+         query_lower = query.lower()
+         if "highest" in query_lower or "maximum" in query_lower:
+             for col in df.select_dtypes(include=['number']).columns:
+                 if col.lower() in query_lower:
+                     return df.loc[df[col].idxmax()].to_dict()
+         elif "average" in query_lower or "mean" in query_lower:
+             for col in df.select_dtypes(include=['number']).columns:
+                 if col.lower() in query_lower:
+                     return f"Average {col}: {df[col].mean():.2f}"
+         elif "total" in query_lower or "sum" in query_lower:
+             for col in df.select_dtypes(include=['number']).columns:
+                 if col.lower() in query_lower:
+                     return f"Total {col}: {df[col].sum():.2f}"
+
+         # Use TAPEX for more complex queries
+         with torch.no_grad():
+             encoding = tokenizer(
+                 table=df.astype(str),
+                 query=query,
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True
+             ).to(model.device)  # keep the encoded inputs on the same device as the model
+             outputs = model.generate(**encoding)
+             answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+
+         processing_time = time.time() - start_time
+         return f"Answer: {answer} (Processing time: {processing_time:.2f}s)"
+
+     except Exception as e:
+         return f"Error processing query: {str(e)}"
+
+ def main():
+     st.title("XML Data Query System")
+     st.write("Upload your XML data and ask questions about it!")
+
+     # Initialize session state for XML input if not exists
+     if 'xml_input' not in st.session_state:
+         st.session_state.xml_input = ""
+
+     # Load model
+     with st.spinner("Loading TAPEX model... (this may take a few moments)"):
+         tokenizer, model = load_model()
+     if tokenizer is None or model is None:
+         st.error("Failed to load the model. Please refresh the page.")
+         return
+
+     # XML Input
+     xml_input = st.text_area(
+         "Enter your XML data here:",
+         value=st.session_state.xml_input,
+         height=200,
+         help="Paste your XML data here. Make sure it's properly formatted."
+     )
+
+     # Sample XML button
+     if st.button("Load Sample XML"):
+         st.session_state.xml_input = """<?xml version="1.0" encoding="UTF-8"?>
+ <data>
+     <records>
+         <record>
+             <company>Apple</company>
+             <revenue>365.7</revenue>
+             <employees>147000</employees>
+             <year>2021</year>
+         </record>
+         <record>
+             <company>Microsoft</company>
+             <revenue>168.1</revenue>
+             <employees>181000</employees>
+             <year>2021</year>
+         </record>
+         <record>
+             <company>Amazon</company>
+             <revenue>386.1</revenue>
+             <employees>1608000</employees>
+             <year>2021</year>
+         </record>
+     </records>
+ </data>"""
+         st.rerun()
+
+     if xml_input:
+         df, error = parse_xml_to_dataframe(xml_input)
+         if error:
+             st.error(error)
+         else:
+             st.success("XML parsed successfully!")
+
+             # Display DataFrame
+             st.subheader("Parsed Data:")
+             st.dataframe(df)
+
+             # Query input
+             query = st.text_input(
+                 "Enter your question about the data:",
+                 help="Example: 'Which company has the highest revenue?'"
+             )
+
+             # Process query
+             if query:
+                 with st.spinner("Processing query..."):
+                     result = process_query(tokenizer, model, df, query)
+                     st.write(result)
+
+             # Sample queries
+             st.subheader("Sample Questions:")
+             sample_queries = [
+                 "Which company has the highest revenue?",
+                 "What is the average revenue of all companies?",
+                 "How many employees does Microsoft have?",
+                 "Which company has the most employees?",
+                 "What is the total revenue of all companies?"
+             ]
+
+             # Create columns for sample query buttons
+             cols = st.columns(len(sample_queries))
+             for idx, (col, sample_query) in enumerate(zip(cols, sample_queries)):
+                 with col:
+                     if st.button(f"Query {idx + 1}", help=sample_query):
+                         with st.spinner("Processing query..."):
+                             result = process_query(tokenizer, model, df, sample_query)
+                             st.write(result)
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ streamlit>=1.28.0
+ pandas>=1.5.0
+ torch>=2.0.0
+ transformers>=4.34.0
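
With both files in place, the app can be run locally in the usual Streamlit way: install the dependencies with pip install -r requirements.txt, then start the UI with streamlit run app.py. On first launch the microsoft/tapex-large-finetuned-wtq checkpoint is downloaded from the Hugging Face Hub, so the initial "Loading TAPEX model" spinner can take a while depending on bandwidth.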