luulinh90s commited on
Commit
9aa727e
·
1 Parent(s): d4f39f7
Files changed (1) hide show
  1. app.py +393 -0
app.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, redirect, url_for, send_from_directory, session
2
+ import json
3
+ import random
4
+ import os
5
+ import string
6
+ import logging
7
+ from datetime import datetime
8
+
9
+
10
+ import os
11
+ from huggingface_hub import login
12
+
13
+ # Use the Hugging Face token from environment variables
14
+ hf_token = os.environ.get("HF_TOKEN")
15
+ if hf_token:
16
+ login(token=hf_token)
17
+ else:
18
+ logger.error("HF_TOKEN not found in environment variables")
19
+
20
+ # Set up logging
21
+ logging.basicConfig(level=logging.INFO,
22
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
23
+ handlers=[
24
+ logging.FileHandler("app.log"),
25
+ logging.StreamHandler()
26
+ ])
27
+ logger = logging.getLogger(__name__)
28
+
29
+ app = Flask(__name__)
30
+ app.config['SECRET_KEY'] = 'supersecretkey' # Change this to a random secret key
31
+
32
+ # Directories for visualizations
33
+ VISUALIZATION_DIRS_PLAN_OF_SQLS = {
34
+ "TP": "htmls_POS/TP",
35
+ "TN": "htmls_POS/TN",
36
+ "FP": "htmls_POS/FP",
37
+ "FN": "htmls_POS/FN"
38
+ }
39
+
40
+ VISUALIZATION_DIRS_CHAIN_OF_TABLE = {
41
+ "TP": "htmls_COT/TP",
42
+ "TN": "htmls_COT/TN",
43
+ "FP": "htmls_COT/FP",
44
+ "FN": "htmls_COT/FN"
45
+ }
46
+
47
+ VISUALIZATION_DIRS_NO_XAI = {
48
+ "TP": "htmls_NO_XAI/TP",
49
+ "TN": "htmls_NO_XAI/TN",
50
+ "FP": "htmls_NO_XAI/FP",
51
+ "FN": "htmls_NO_XAI/FN"
52
+ }
53
+
54
+ VISUALIZATION_DIRS_DATER = {
55
+ "TP": "htmls_DATER/TP",
56
+ "TN": "htmls_DATER/TN",
57
+ "FP": "htmls_DATER/FP",
58
+ "FN": "htmls_DATER/FN"
59
+ }
60
+
61
+ import json
62
+ import os
63
+ from datetime import datetime
64
+ from huggingface_hub import HfApi
65
+
66
+
67
+ def save_session_data(username, data):
68
+ try:
69
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
70
+ file_name = f'{username}_{timestamp}_session.json'
71
+
72
+ # Convert data to JSON string
73
+ json_data = json.dumps(data, indent=4)
74
+
75
+ # Create a temporary file
76
+ temp_file_path = f"/tmp/{file_name}"
77
+ with open(temp_file_path, 'w') as f:
78
+ f.write(json_data)
79
+
80
+ # Upload the file to Hugging Face
81
+ api = HfApi()
82
+ api.upload_file(
83
+ path_or_fileobj=temp_file_path,
84
+ path_in_repo=f"session_data/{file_name}",
85
+ repo_id="luulinh90s/Tabular-LLM-Study-Preference", # Replace with your actual Space name
86
+ repo_type="space",
87
+ )
88
+
89
+ # Remove the temporary file
90
+ os.remove(temp_file_path)
91
+
92
+ logger.info(f"Session data saved for user {username} in Hugging Face Space")
93
+ except Exception as e:
94
+ logger.exception(f"Error saving session data for user {username}: {e}")
95
+
96
+
97
+ from huggingface_hub import hf_hub_download
98
+
99
+
100
+ def load_session_data(username):
101
+ try:
102
+ # List files in the session_data directory
103
+ api = HfApi()
104
+ files = api.list_repo_files(repo_id="luulinh90s/Tabular-LLM-Study-Preference", repo_type="space", path="session_data")
105
+
106
+ # Filter and sort files for the user
107
+ user_files = [f for f in files if f.startswith(f'session_data/{username}_') and f.endswith('_session.json')]
108
+
109
+ if not user_files:
110
+ logger.warning(f"No session data found for user {username}")
111
+ return None
112
+
113
+ # Get the most recent file
114
+ latest_file = sorted(user_files, reverse=True)[0]
115
+
116
+ # Download the file
117
+ file_path = hf_hub_download(repo_id="luulinh90s/Tabular-LLM-Study-Preference", repo_type="space", filename=latest_file)
118
+
119
+ with open(file_path, 'r') as f:
120
+ data = json.load(f)
121
+
122
+ logger.info(f"Session data loaded for user {username} from Hugging Face Space")
123
+ return data
124
+ except Exception as e:
125
+ logger.exception(f"Error loading session data for user {username}: {e}")
126
+ return None
127
+
128
+
129
+ def load_samples(methods):
130
+ logger.info(f"Loading samples for methods: {methods}")
131
+ samples = []
132
+ categories = ["TP", "TN", "FP", "FN"]
133
+
134
+ method_dirs = []
135
+ for method in methods:
136
+ if method == 'No-XAI':
137
+ method_dirs.append('NO_XAI')
138
+ elif method == 'Dater':
139
+ method_dirs.append('DATER')
140
+ elif method == 'Chain-of-Table':
141
+ method_dirs.append('COT')
142
+ elif method == 'Plan-of-SQLs':
143
+ method_dirs.append('POS')
144
+
145
+ for category in categories:
146
+ dir_a = f'htmls_{method_dirs[0].upper()}/{category}'
147
+ dir_b = f'htmls_{method_dirs[1].upper()}/{category}'
148
+
149
+ files_a = set(os.listdir(dir_a))
150
+ files_b = set(os.listdir(dir_b))
151
+
152
+ matching_files = files_a & files_b
153
+
154
+ for file in matching_files:
155
+ samples.append({
156
+ 'category': category,
157
+ 'file': file
158
+ })
159
+
160
+ return samples
161
+
162
+
163
+ def select_balanced_samples(samples):
164
+ try:
165
+ selected_samples = random.sample(samples, min(10, len(samples)))
166
+ logger.info(f"Selected balanced samples: {len(selected_samples)}")
167
+ return selected_samples
168
+ except Exception as e:
169
+ logger.exception("Error selecting balanced samples")
170
+ return []
171
+
172
+ def generate_random_string(length=8):
173
+ return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
174
+
175
+
176
+ @app.route('/', methods=['GET', 'POST'])
177
+ def index():
178
+ logger.info("Rendering index page.")
179
+ if request.method == 'POST':
180
+ username = request.form.get('username')
181
+ seed = request.form.get('seed')
182
+ methods = request.form.get('method').split(',')
183
+
184
+ if not username or not seed or len(methods) != 2:
185
+ logger.error("Missing username, seed, or incorrect number of methods.")
186
+ return "Please fill in all fields and select exactly two methods.", 400
187
+
188
+ try:
189
+ seed = int(seed)
190
+ random.seed(seed)
191
+ all_samples = load_samples(methods)
192
+ selected_samples = select_balanced_samples(all_samples)
193
+ logger.info(f"Number of selected samples: {len(selected_samples)}")
194
+
195
+ if len(selected_samples) == 0:
196
+ logger.error("No samples were selected.")
197
+ return "No samples were selected", 500
198
+
199
+ session_data = {
200
+ 'username': username,
201
+ 'seed': seed,
202
+ 'methods': methods,
203
+ 'selected_samples': selected_samples,
204
+ 'current_index': 0,
205
+ 'responses': [],
206
+ 'start_time': datetime.now().isoformat()
207
+ }
208
+ save_session_data(username, session_data)
209
+ logger.info(f"Session data initialized for user: {username}")
210
+
211
+ return redirect(url_for('experiment', username=username))
212
+ except Exception as e:
213
+ logger.exception(f"Error in index route: {e}")
214
+ return "An error occurred", 500
215
+ return render_template('index.html')
216
+
217
+
218
+ @app.route('/experiment/<username>', methods=['GET', 'POST'])
219
+ def experiment(username):
220
+ try:
221
+ session_data = load_session_data(username)
222
+ if not session_data:
223
+ logger.error(f"No session data found for user: {username}")
224
+ return redirect(url_for('index'))
225
+
226
+ selected_samples = session_data['selected_samples']
227
+ methods = session_data['methods']
228
+ current_index = session_data['current_index']
229
+
230
+ if current_index >= len(selected_samples):
231
+ return redirect(url_for('completed', username=username))
232
+
233
+ sample = selected_samples[current_index]
234
+ method_a, method_b = methods
235
+
236
+ # Find matching files for both methods
237
+ file_a = None
238
+ file_b = None
239
+
240
+ if method_a == 'No-XAI':
241
+ method_a_dir = ('NO_XAI')
242
+ elif method_a == 'Dater':
243
+ method_a_dir = ('DATER')
244
+ elif method_a == 'Chain-of-Table':
245
+ method_a_dir = ('COT')
246
+ elif method_a == 'Plan-of-SQLs':
247
+ method_a_dir = ('POS')
248
+
249
+ if method_b == 'No-XAI':
250
+ method_b_dir = ('NO_XAI')
251
+ elif method_b == 'Dater':
252
+ method_b_dir = ('DATER')
253
+ elif method_b == 'Chain-of-Table':
254
+ method_b_dir = ('COT')
255
+ elif method_b == 'Plan-of-SQLs':
256
+ method_b_dir = ('POS')
257
+
258
+ for category in ['TP', 'TN', 'FP', 'FN']:
259
+ dir_a = f'htmls_{method_a_dir.upper()}/{category}'
260
+ dir_b = f'htmls_{method_b_dir.upper()}/{category}'
261
+
262
+ files_a = os.listdir(dir_a)
263
+ files_b = os.listdir(dir_b)
264
+
265
+ matching_files = set(files_a) & set(files_b)
266
+ if matching_files:
267
+ file_a = os.path.join(dir_a, next(iter(matching_files)))
268
+ file_b = os.path.join(dir_b, next(iter(matching_files)))
269
+ break
270
+
271
+ if not file_a or not file_b:
272
+ logger.error(f"Missing files for comparison at index {current_index}")
273
+ session_data['current_index'] += 1
274
+ save_session_data(username, session_data)
275
+ return redirect(url_for('experiment', username=username))
276
+
277
+ visualization_a = url_for('send_visualization', filename=file_a)
278
+ visualization_b = url_for('send_visualization', filename=file_b)
279
+
280
+ statement = """
281
+ You are given two explanations that describe the reasoning process of the Table QA model.
282
+ Please analyze the explanations and determine which one provides a clearer and more accurate reasoning process.
283
+ """
284
+
285
+ return render_template('experiment.html',
286
+ sample_id=current_index,
287
+ statement=statement,
288
+ visualization_a=visualization_a,
289
+ visualization_b=visualization_b,
290
+ method_a=method_a,
291
+ method_b=method_b,
292
+ username=username)
293
+ except Exception as e:
294
+ logger.exception(f"An error occurred in the experiment route: {e}")
295
+ return "An error occurred", 500
296
+
297
+ def get_visualization_dir(method):
298
+ if method == "No-XAI":
299
+ return 'htmls_NO_XAI'
300
+ elif method == "Dater":
301
+ return 'htmls_DATER'
302
+ elif method == "Chain-of-Table":
303
+ return 'htmls_COT'
304
+ else: # Plan-of-SQLs
305
+ return 'htmls_POS'
306
+
307
+ @app.route('/feedback', methods=['POST'])
308
+ def feedback():
309
+ try:
310
+ username = request.form['username']
311
+ feedback = request.form['feedback']
312
+
313
+ session_data = load_session_data(username)
314
+ if not session_data:
315
+ logger.error(f"No session data found for user: {username}")
316
+ return redirect(url_for('index'))
317
+
318
+ # Store the feedback
319
+ session_data['responses'].append({
320
+ 'sample_id': session_data['current_index'],
321
+ 'preferred_method': feedback,
322
+ 'timestamp': datetime.now().isoformat()
323
+ })
324
+
325
+ # Move to the next sample
326
+ session_data['current_index'] += 1
327
+
328
+ # Save updated session data
329
+ save_session_data(username, session_data)
330
+ logger.info(f"Feedback saved for user {username}, sample {session_data['current_index'] - 1}")
331
+
332
+ if session_data['current_index'] >= len(session_data['selected_samples']):
333
+ return redirect(url_for('completed', username=username))
334
+
335
+ return redirect(url_for('experiment', username=username))
336
+ except Exception as e:
337
+ logger.exception(f"Error in feedback route: {e}")
338
+ return "An error occurred", 500
339
+
340
+
341
+ @app.route('/completed/<username>')
342
+ def completed(username):
343
+ try:
344
+ session_data = load_session_data(username)
345
+ if not session_data:
346
+ logger.error(f"No session data found for user: {username}")
347
+ return redirect(url_for('index'))
348
+
349
+ session_data['end_time'] = datetime.now().isoformat()
350
+
351
+ methods = session_data['methods']
352
+ responses = session_data['responses']
353
+
354
+ preferences = {method: 0 for method in methods}
355
+ total_responses = len(responses)
356
+
357
+ for response in responses:
358
+ preferred_method = response['preferred_method']
359
+ preferences[preferred_method] += 1
360
+
361
+ for method in preferences:
362
+ preferences[method] = round((preferences[method] / total_responses) * 100, 2)
363
+
364
+ session_data['preferences'] = preferences
365
+ save_session_data(username, session_data)
366
+
367
+ return render_template('completed.html', preferences=preferences)
368
+ except Exception as e:
369
+ logger.exception(f"An error occurred in the completed route: {e}")
370
+ return "An error occurred", 500
371
+
372
+
373
+ @app.route('/visualizations/<path:filename>')
374
+ def send_visualization(filename):
375
+ logger.info(f"Attempting to serve file: {filename}")
376
+ # Ensure the path is safe and doesn't allow access to files outside the intended directory
377
+ base_dir = os.getcwd()
378
+ file_path = os.path.normpath(os.path.join(base_dir, filename))
379
+ if not file_path.startswith(base_dir):
380
+ return "Access denied", 403
381
+
382
+ if not os.path.exists(file_path):
383
+ return "File not found", 404
384
+
385
+ directory = os.path.dirname(file_path)
386
+ file_name = os.path.basename(file_path)
387
+ logger.info(f"Serving file from directory: {directory}, filename: {file_name}")
388
+ return send_from_directory(directory, file_name)
389
+
390
+
391
+ if __name__ == "__main__":
392
+ os.makedirs('session_data', exist_ok=True) # Ensure the directory for session files exists
393
+ app.run(host="0.0.0.0", port=7860, debug=True)