luulinh90s commited on
Commit
3909e19
·
1 Parent(s): 479b115
Files changed (1) hide show
  1. app.py +182 -154
app.py CHANGED
@@ -4,6 +4,16 @@ import random
4
  import os
5
  import string
6
  from flask_session import Session
 
 
 
 
 
 
 
 
 
 
7
 
8
  app = Flask(__name__)
9
  app.config['SECRET_KEY'] = 'supersecretkey' # Change this to a random secret key
@@ -28,6 +38,7 @@ VISUALIZATION_DIRS_CHAIN_OF_TABLE = {
28
 
29
  # Load all sample files from the directories based on the selected method
30
  def load_samples(method):
 
31
  if method == "Chain-of-Table":
32
  visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
33
  else:
@@ -35,17 +46,26 @@ def load_samples(method):
35
 
36
  samples = {"TP": [], "TN": [], "FP": [], "FN": []}
37
  for category, dir_path in visualization_dirs.items():
38
- for filename in os.listdir(dir_path):
39
- if filename.endswith(".html"):
40
- samples[category].append(filename)
 
 
 
 
41
  return samples
42
 
43
 
44
  # Randomly select balanced samples
45
  def select_balanced_samples(samples):
46
- tp_fp_samples = random.sample(samples["TP"] + samples["FP"], 5)
47
- tn_fn_samples = random.sample(samples["TN"] + samples["FN"], 5)
48
- return tp_fp_samples + tn_fn_samples
 
 
 
 
 
49
 
50
 
51
  def generate_random_string(length=8):
@@ -54,195 +74,203 @@ def generate_random_string(length=8):
54
 
55
  @app.route('/', methods=['GET', 'POST'])
56
  def index():
 
57
  if request.method == 'POST':
58
  username = request.form.get('username')
59
  seed = request.form.get('seed')
60
  method = request.form.get('method')
61
 
62
  if not username or not seed or not method:
 
63
  return "Missing username, seed, or method", 400
64
 
65
- seed = int(seed)
66
- random.seed(seed)
67
- all_samples = load_samples(method)
68
- selected_samples = select_balanced_samples(all_samples)
69
- random_string = generate_random_string()
70
- filename = f'{username}_{seed}_{method}_{random_string}.json' # Append method to filename
 
 
 
71
 
72
- session['selected_samples'] = selected_samples
73
- session['responses'] = [] # Initialize responses list
74
- session['method'] = method # Store the selected method
75
 
76
- return redirect(url_for('experiment', username=username, sample_index=0, seed=seed, filename=filename))
 
 
 
77
  return render_template('index.html')
78
 
79
 
80
  @app.route('/experiment/<username>/<sample_index>/<seed>/<filename>', methods=['GET'])
81
  def experiment(username, sample_index, seed, filename):
82
- sample_index = int(sample_index)
83
- selected_samples = session.get('selected_samples', [])
84
- method = session.get('method') # Retrieve the selected method
 
85
 
86
- if sample_index >= len(selected_samples):
87
- return redirect(url_for('completed', filename=filename))
88
 
89
- visualization_file = selected_samples[sample_index]
90
- visualization_path = None
91
 
92
- # Determine the correct visualization directory based on the method
93
- if method == "Chain-of-Table":
94
- visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
95
- else:
96
- visualization_dirs = VISUALIZATION_DIRS_PLAN_OF_SQLS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- # Find the correct visualization path
99
- for category, dir_path in visualization_dirs.items():
100
- if visualization_file in os.listdir(dir_path):
101
- visualization_path = f"{category}/{visualization_file}"
102
- break
103
-
104
- if not visualization_path:
105
- return "Visualization file not found", 404
106
-
107
- statement = "Please make a decision to Accept/Reject the AI prediction based on the explanation."
108
- return render_template('experiment.html',
109
- sample_id=sample_index,
110
- statement=statement,
111
- visualization=visualization_path,
112
- username=username,
113
- seed=seed,
114
- sample_index=sample_index,
115
- filename=filename)
116
 
117
  @app.route('/visualizations/<path:path>')
118
  def send_visualization(path):
119
- # Determine which visualization folder to use based on the selected method
120
- method = session.get('method')
121
- if method == "Chain-of-Table":
122
- visualization_dir = 'htmls_COT'
123
- else: # Default to Plan-of-SQLs
124
- visualization_dir = 'visualizations'
125
 
126
- # Serve the file from the appropriate directory
127
- return send_from_directory(visualization_dir, path)
 
 
128
 
129
 
130
  @app.route('/feedback', methods=['POST'])
131
  def feedback():
132
- sample_id = request.form['sample_id']
133
- feedback = request.form['feedback']
134
- username = request.form['username']
135
- seed = request.form['seed']
136
- sample_index = int(request.form['sample_index'])
137
- filename = request.form['filename']
138
-
139
- selected_samples = session.get('selected_samples', [])
140
- responses = session.get('responses', [])
141
-
142
- # Store the feedback
143
- responses.append({
144
- 'sample_id': sample_id,
145
- 'feedback': feedback
146
- })
147
- session['responses'] = responses
148
-
149
- # Create the result directory if it doesn't exist
150
- result_dir = 'human_study'
151
- os.makedirs(result_dir, exist_ok=True)
152
-
153
- # Load existing data if the JSON file exists
154
- filepath = os.path.join(result_dir, filename)
155
- if os.path.exists(filepath):
156
- with open(filepath, 'r') as f:
157
- data = json.load(f)
158
- else:
159
- data = {}
160
 
161
- # Update data with the current feedback
162
- data[sample_index] = {
163
- 'Username': username,
164
- 'Seed': seed,
165
- 'Sample ID': sample_id,
166
- 'Task': f"Please make a decision to Accept/Reject the AI prediction based on the explanation.",
167
- 'User Feedback': feedback
168
- }
169
 
170
- # Save updated data to the file
171
- with open(filepath, 'w') as f:
172
- json.dump(data, f, indent=4)
173
 
174
- next_sample_index = sample_index + 1
175
- if next_sample_index >= len(selected_samples):
176
- return redirect(url_for('completed', filename=filename))
 
 
 
 
 
 
 
 
177
 
178
- return redirect(
179
- url_for('experiment', username=username, sample_index=next_sample_index, seed=seed, filename=filename))
180
 
181
  @app.route('/completed/<filename>')
182
  def completed(filename):
183
- # Load responses from the session
184
- responses = session.get('responses', [])
185
-
186
- # Determine which JSON file to load based on the method
187
- method = session.get('method')
188
- if method == "Chain-of-Table":
189
- json_file = 'Tabular_LLMs_human_study_vis_6_COT.json'
190
- else: # Default to Plan-of-SQLs
191
- json_file = 'Tabular_LLMs_human_study_vis_6.json'
192
-
193
- # Load the ground truth data from the appropriate JSON file
194
- with open(json_file, 'r') as f:
195
- ground_truth = json.load(f)
196
-
197
- # Initialize counters
198
- correct_responses = 0
199
- accept_count = 0
200
- reject_count = 0
201
-
202
- for response in responses:
203
- sample_id = response['sample_id']
204
- feedback = response['feedback']
205
- index = sample_id.split('-')[1].split('.')[0] # Extract index from filename
206
-
207
- # Count the feedback
208
- if feedback.upper() == "TRUE":
209
- accept_count += 1
210
- elif feedback.upper() == "FALSE":
211
- reject_count += 1
212
-
213
- # Construct the ground truth key
214
  if method == "Chain-of-Table":
215
- ground_truth_key = f"COT_test-{index}.html" # Adjust this based on your actual key format in the CoTable JSON
216
- else:
217
- ground_truth_key = f"POS_test-{index}.html"
218
 
219
- # Check if the key exists in the ground truth data
220
- if ground_truth_key in ground_truth and ground_truth[ground_truth_key]['answer'].upper() == feedback.upper():
221
- correct_responses += 1
222
- else:
223
- print(f"Missing or mismatched key: {ground_truth_key}")
224
 
225
- # Calculate accuracy
226
- accuracy = (correct_responses / len(responses)) * 100 if responses else 0
227
- accuracy = round(accuracy, 2)
228
 
229
- # Calculate percentages
230
- total_responses = len(responses)
231
- accept_percentage = (accept_count / total_responses) * 100 if total_responses else 0
232
- reject_percentage = (reject_count / total_responses) * 100 if total_responses else 0
233
 
234
- # Round percentages
235
- accept_percentage = round(accept_percentage, 2)
236
- reject_percentage = round(reject_percentage, 2)
 
237
 
238
- return render_template('completed.html',
239
- accuracy=accuracy,
240
- accept_percentage=accept_percentage,
241
- reject_percentage=reject_percentage)
242
 
 
 
 
 
 
243
 
244
- if __name__ == '__main__':
245
- # app.run(debug=True, port=8080)
 
 
 
 
 
 
246
 
247
- # change for running on HuggingFace
 
 
 
 
 
 
 
 
 
248
  app.run(debug=False, port=7860)
 
4
  import os
5
  import string
6
  from flask_session import Session
7
+ import logging
8
+
# Logging configuration: INFO and above goes both to the file "app.log" and to
# the console stream, with a timestamp / logger name / level prefix per record.
logging.basicConfig(
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by the helpers and routes in this file.
logger = logging.getLogger(__name__)
17
 
app = Flask(__name__)
# Flask uses SECRET_KEY to sign session data. Read it from the SECRET_KEY
# environment variable so the real key does not have to live in the repository;
# the literal below is only a development fallback and keeps the previous
# behavior when the variable is unset.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'supersecretkey')  # Change this to a random secret key
 
38
 
39
  # Load all sample files from the directories based on the selected method
40
  def load_samples(method):
41
+ logger.info(f"Loading samples for method: {method}")
42
  if method == "Chain-of-Table":
43
  visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
44
  else:
 
46
 
47
  samples = {"TP": [], "TN": [], "FP": [], "FN": []}
48
  for category, dir_path in visualization_dirs.items():
49
+ try:
50
+ for filename in os.listdir(dir_path):
51
+ if filename.endswith(".html"):
52
+ samples[category].append(filename)
53
+ logger.info(f"Loaded {len(samples[category])} samples for category {category}")
54
+ except Exception as e:
55
+ logger.exception(f"Error loading samples from {dir_path}: {e}")
56
  return samples
57
 
58
 
59
  # Randomly select balanced samples
def select_balanced_samples(samples, per_group=5):
    """Randomly pick an equal number of samples from TP+FP and from TN+FN.

    Args:
        samples: Mapping with the keys "TP", "FP", "TN" and "FN", each holding
            a list of sample file names.
        per_group: How many samples to draw from each of the two pools
            (defaults to 5, the value that used to be hard-coded).

    Returns:
        A list with ``per_group`` names drawn from TP+FP followed by
        ``per_group`` names drawn from TN+FN, or an empty list when the
        selection fails (missing category, or a pool smaller than
        ``per_group``). The failure is logged with its traceback.
    """
    try:
        # The name lists are filled from os.listdir elsewhere in this file, and
        # os.listdir returns entries in arbitrary order. Sorting the pools
        # makes the draw depend only on the random seed, so a given seed picks
        # the same samples on every machine.
        tp_fp_pool = sorted(samples["TP"] + samples["FP"])
        tn_fn_pool = sorted(samples["TN"] + samples["FN"])
        selected = (random.sample(tp_fp_pool, per_group)
                    + random.sample(tn_fn_pool, per_group))
    except (KeyError, TypeError, ValueError):
        # KeyError/TypeError: malformed `samples`; ValueError: pool too small.
        logger.exception("Error selecting balanced samples")
        return []

    logger.info(f"Selected balanced samples: {len(selected)}")
    return selected
69
 
70
 
71
  def generate_random_string(length=8):
 
74
 
def _filename_safe(text):
    """Return *text* with every character outside [A-Za-z0-9_-] replaced by '-'."""
    allowed = set(string.ascii_letters + string.digits + '-_')
    return ''.join(ch if ch in allowed else '-' for ch in str(text))


@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the start page (GET) or start a new study session (POST).

    A POST must carry the form fields ``username``, ``seed`` and ``method``.
    The seed initialises the ``random`` module, samples are loaded and selected
    for the chosen method, the selection is stored in the session together with
    an empty response list and the method, and the client is redirected to the
    first experiment page.

    Returns:
        The rendered ``index.html`` for GET, a redirect to ``experiment`` for a
        valid POST, a 400 response for missing fields or a non-integer seed,
        and a 500 response when no samples could be selected or an unexpected
        error occurs.
    """
    logger.info("Rendering index page.")
    if request.method != 'POST':
        return render_template('index.html')

    username = request.form.get('username')
    seed = request.form.get('seed')
    method = request.form.get('method')

    if not username or not seed or not method:
        logger.error("Missing username, seed, or method.")
        return "Missing username, seed, or method", 400

    # A malformed seed is a client error, not a server failure.
    try:
        seed = int(seed)
    except ValueError:
        logger.error(f"Seed is not an integer: {seed!r}")
        return "Seed must be an integer", 400

    try:
        random.seed(seed)
        all_samples = load_samples(method)
        selected_samples = select_balanced_samples(all_samples)
        if not selected_samples:
            # select_balanced_samples returns [] on failure; without this check
            # the user would be sent straight to an empty results page.
            logger.error(f"No samples could be selected for method: {method}")
            return "No samples available for the selected method", 500

        random_string = generate_random_string()
        # The file name later becomes a URL segment and a path on disk, so the
        # user-controlled parts are reduced to a conservative character set.
        filename = (f'{_filename_safe(username)}_{seed}_'
                    f'{_filename_safe(method)}_{random_string}.json')

        logger.info(f"Generated filename: {filename}")

        session['selected_samples'] = selected_samples
        session['responses'] = []  # Initialize responses list
        session['method'] = method  # Store the selected method

        return redirect(url_for('experiment', username=username, sample_index=0, seed=seed, filename=filename))
    except Exception as e:
        logger.exception(f"Error in index route: {e}")
        return "An error occurred", 500
106
 
107
 
@app.route('/experiment/<username>/<sample_index>/<seed>/<filename>', methods=['GET'])
def experiment(username, sample_index, seed, filename):
    """Render the experiment page for one of the session's selected samples.

    Args:
        username: Participant name taken from the URL.
        sample_index: Position of the sample in the session's
            ``selected_samples`` list (arrives as a string from the URL).
        seed: Seed value taken from the URL; passed through to the template.
        filename: Name of the participant's result file; passed through to the
            template and to the ``completed`` redirect.

    Returns:
        The rendered ``experiment.html``; a redirect to ``completed`` once the
        index is past the last selected sample; 400 for a non-integer or
        negative index; 404 when the sample file is in none of the method's
        directories; 500 on unexpected errors.
    """
    try:
        try:
            sample_index = int(sample_index)
        except ValueError:
            logger.error(f"Invalid sample index: {sample_index!r}")
            return "Invalid sample index", 400
        if sample_index < 0:
            # A negative index would otherwise silently address the list from
            # its end.
            logger.error(f"Negative sample index: {sample_index}")
            return "Invalid sample index", 400

        selected_samples = session.get('selected_samples', [])
        method = session.get('method')  # Retrieve the selected method

        if sample_index >= len(selected_samples):
            return redirect(url_for('completed', filename=filename))

        visualization_file = selected_samples[sample_index]

        # Determine the correct visualization directories based on the method
        if method == "Chain-of-Table":
            visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
        else:
            visualization_dirs = VISUALIZATION_DIRS_PLAN_OF_SQLS

        # Find the category whose directory holds the file. A direct existence
        # check avoids listing every directory on each request.
        visualization_path = None
        for category, dir_path in visualization_dirs.items():
            if os.path.isfile(os.path.join(dir_path, visualization_file)):
                visualization_path = f"{category}/{visualization_file}"
                break

        if not visualization_path:
            logger.error("Visualization file not found.")
            return "Visualization file not found", 404

        statement = "Please make a decision to Accept/Reject the AI prediction based on the explanation."
        # NOTE(review): sample_id is the numeric index here, while the
        # `completed` route parses sample ids as "<prefix>-<index>.<ext>".
        # Confirm what experiment.html actually posts back as sample_id.
        return render_template('experiment.html',
                               sample_id=sample_index,
                               statement=statement,
                               visualization=visualization_path,
                               username=username,
                               seed=seed,
                               sample_index=sample_index,
                               filename=filename)
    except Exception as e:
        logger.exception(f"An error occurred in the experiment route: {e}")
        return "An error occurred", 500
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
@app.route('/visualizations/<path:path>')
def send_visualization(path):
    """Serve a visualization file from the directory of the session's method.

    Args:
        path: File path relative to the visualization root, taken from the URL.

    Returns:
        The file response produced by ``send_from_directory``. HTTP errors it
        raises (such as 404 for a missing file) are propagated unchanged; any
        other failure is logged and answered with 500.
    """
    try:
        method = session.get('method')
        if method == "Chain-of-Table":
            visualization_dir = 'htmls_COT'
        else:  # Default to Plan-of-SQLs
            visualization_dir = 'visualizations'

        return send_from_directory(visualization_dir, path)
    except Exception as e:
        # send_from_directory reports a missing or unsafe path by raising an
        # HTTP exception. Those objects can build their own response, so let
        # Flask handle them instead of masking a 404 as a 500.
        if callable(getattr(e, 'get_response', None)):
            raise
        logger.exception(f"Error sending visualization: {e}")
        return "An error occurred", 500
164
 
165
 
@app.route('/feedback', methods=['POST'])
def feedback():
    """Record one participant answer and move on to the next sample.

    Expects the form fields ``sample_id``, ``feedback``, ``username``,
    ``seed``, ``sample_index`` and ``filename``. The answer is appended to the
    session's ``responses`` list and written into the JSON file
    ``human_study/<filename>`` under the sample index.

    Returns:
        A redirect to the next ``experiment`` page, or to ``completed`` after
        the last selected sample; 400 when a field is missing, the index is not
        an integer or the file name is unusable; 500 on unexpected errors.
    """
    try:
        form = request.form
        required = ('sample_id', 'feedback', 'username', 'seed', 'sample_index', 'filename')
        missing = [field for field in required if field not in form]
        if missing:
            logger.error(f"Feedback form is missing fields: {missing}")
            return "Missing form fields", 400

        sample_id = form['sample_id']
        user_feedback = form['feedback']
        username = form['username']
        seed = form['seed']

        try:
            sample_index = int(form['sample_index'])
        except ValueError:
            logger.error(f"Invalid sample index: {form['sample_index']!r}")
            return "Invalid sample index", 400

        # The file name comes from the client: keep only its last path
        # component so the write cannot escape the result directory.
        filename = os.path.basename(form['filename'])
        if not filename or filename in ('.', '..'):
            logger.error(f"Invalid result filename: {form['filename']!r}")
            return "Invalid filename", 400

        selected_samples = session.get('selected_samples', [])
        responses = session.get('responses', [])

        responses.append({
            'sample_id': sample_id,
            'feedback': user_feedback
        })
        session['responses'] = responses

        result_dir = 'human_study'
        os.makedirs(result_dir, exist_ok=True)

        filepath = os.path.join(result_dir, filename)
        if os.path.exists(filepath):
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        else:
            data = {}

        # JSON object keys are always strings after a reload. Using a string
        # key here means a resubmission overwrites the earlier entry instead of
        # adding a second, duplicate key to the file.
        data[str(sample_index)] = {
            'Username': username,
            'Seed': seed,
            'Sample ID': sample_id,
            'Task': "Please make a decision to Accept/Reject the AI prediction based on the explanation.",
            'User Feedback': user_feedback
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)

        logger.info(f"Feedback saved for sample {sample_id}")

        next_sample_index = sample_index + 1
        if next_sample_index >= len(selected_samples):
            return redirect(url_for('completed', filename=filename))

        return redirect(
            url_for('experiment', username=username, sample_index=next_sample_index, seed=seed, filename=filename))
    except Exception as e:
        logger.exception(f"Error in feedback route: {e}")
        return "An error occurred", 500
217
 
 
 
218
 
@app.route('/completed/<filename>')
def completed(filename):
    """Show the participant's accuracy and accept/reject rates.

    The answers stored in the session's ``responses`` list are compared with
    the ground-truth JSON file of the session's method. Percentages are taken
    over all recorded responses and rounded to two decimals.

    Args:
        filename: Result file name from the URL (not used for the computation).

    Returns:
        The rendered ``completed.html`` with ``accuracy``,
        ``accept_percentage`` and ``reject_percentage``, or a 500 response when
        an unexpected error occurs (for example an unreadable ground-truth
        file).
    """
    try:
        responses = session.get('responses', [])
        method = session.get('method')

        # Ground-truth file and key prefix both depend only on the method, so
        # they are resolved once instead of on every loop iteration.
        if method == "Chain-of-Table":
            json_file = 'Tabular_LLMs_human_study_vis_6_COT.json'
            key_prefix = "COT"
        else:  # Default to Plan-of-SQLs
            json_file = 'Tabular_LLMs_human_study_vis_6.json'
            key_prefix = "POS"

        with open(json_file, 'r') as f:
            ground_truth = json.load(f)

        correct_responses = 0
        accept_count = 0
        reject_count = 0

        for response in responses:
            sample_id = response['sample_id']
            answer = response['feedback'].upper()

            if answer == "TRUE":
                accept_count += 1
            elif answer == "FALSE":
                reject_count += 1

            # Sample ids are expected to look like "<prefix>-<index>.<ext>".
            # A malformed id counts as not correct instead of failing the page.
            try:
                index = sample_id.split('-')[1].split('.')[0]  # Extract index from filename
            except IndexError:
                logger.warning(f"Cannot extract index from sample id: {sample_id}")
                continue

            ground_truth_key = f"{key_prefix}_test-{index}.html"

            if ground_truth_key not in ground_truth:
                # Warn only when the key is really absent; a wrong answer is
                # an expected outcome and is not logged.
                logger.warning(f"Missing ground truth key: {ground_truth_key}")
            elif ground_truth[ground_truth_key]['answer'].upper() == answer:
                correct_responses += 1

        total_responses = len(responses)
        if total_responses:
            accuracy = (correct_responses / total_responses) * 100
            accept_percentage = (accept_count / total_responses) * 100
            reject_percentage = (reject_count / total_responses) * 100
        else:
            accuracy = accept_percentage = reject_percentage = 0

        accuracy = round(accuracy, 2)
        accept_percentage = round(accept_percentage, 2)
        reject_percentage = round(reject_percentage, 2)

        return render_template('completed.html',
                               accuracy=accuracy,
                               accept_percentage=accept_percentage,
                               reject_percentage=reject_percentage)
    except Exception as e:
        logger.exception(f"Error in completed route: {e}")
        return "An error occurred", 500
273
+
274
+
# Script entry point: start Flask's built-in server on port 7860 with the
# debugger switched off.
if __name__ == '__main__':
    app.run(port=7860, debug=False)