jdev8 committed on
Commit
44e2555
·
verified ·
1 Parent(s): 9eff11a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -82
app.py CHANGED
@@ -6,6 +6,7 @@ import datetime
6
  import logging
7
  from huggingface_hub import hf_hub_download, upload_file, list_repo_tree
8
  from dotenv import load_dotenv
 
9
 
10
  load_dotenv()
11
 
@@ -15,22 +16,26 @@ HF_INPUT_DATASET_PATH = os.getenv("HF_INPUT_DATASET_PATH")
15
  HF_INPUT_DATASET_ID_COLUMN = os.getenv("HF_INPUT_DATASET_ID_COLUMN")
16
  HF_INPUT_DATASET_COLUMN_A = os.getenv("HF_INPUT_DATASET_COLUMN_A")
17
  HF_INPUT_DATASET_COLUMN_B = os.getenv("HF_INPUT_DATASET_COLUMN_B")
 
18
  HF_OUTPUT_DATASET = os.getenv("HF_OUTPUT_DATASET")
19
  HF_OUTPUT_DATASET_DIR = os.getenv("HF_OUTPUT_DATASET_DIR")
20
-
21
  INSTRUCTIONS = """
22
  # Pairwise Model Output Labeling
23
  Please compare the two model outputs shown below and select which one you think is better.
24
- - Choose "Left is better" if the left output is superior
25
- - Choose "Right is better" if the right output is superior
26
- - Choose "Tie" if they are equally good or bad
27
  - Choose "Can't choose" if you cannot make a determination
28
  """
29
-
 
 
30
  class PairwiseLabeler:
31
  def __init__(self):
 
 
32
  self.df = self.read_hf_dataset()
33
- self.results = {}
34
 
35
  def __len__(self):
36
  return len(self.df)
@@ -41,7 +46,7 @@ class PairwiseLabeler:
41
  if local_file.endswith(".json"):
42
  return pd.read_json(local_file)
43
  elif local_file.endswith(".jsonl"):
44
- return pd.read_json(local_file, orient="records", lines=True)
45
  elif local_file.endswith(".csv"):
46
  return pd.read_csv(local_file)
47
  elif local_file.endswith(".parquet"):
@@ -49,112 +54,275 @@ class PairwiseLabeler:
49
  else:
50
  raise ValueError(f"Unsupported file type: {local_file}")
51
  except Exception as e:
 
52
  logging.error(f"Couldn't read HF dataset from {HF_INPUT_DATASET_PATH}. Using sample data instead.")
53
  sample_data = {
54
- HF_INPUT_DATASET_ID_COLUMN: [f"sample_{i}" for i in range(5)],
55
- HF_INPUT_DATASET_COLUMN_A: [f"This is sample generation A {i}" for i in range(5)],
56
- HF_INPUT_DATASET_COLUMN_B: [f"This is sample generation B {i}" for i in range(5)],
57
  }
 
 
 
 
 
58
  return pd.DataFrame(sample_data)
59
-
60
- def get_current_pair(self, user_id, user_index):
61
- if user_index >= len(self.df):
62
- return None, None, None
63
-
64
- item = self.df.iloc[user_index]
65
- item_id = item.get(HF_INPUT_DATASET_ID_COLUMN, f"item_{user_index}")
 
 
 
66
  left_text = item.get(HF_INPUT_DATASET_COLUMN_A, "")
67
  right_text = item.get(HF_INPUT_DATASET_COLUMN_B, "")
68
 
69
- return item_id, left_text, right_text
70
-
71
- def submit_judgment(self, user_id, user_index, item_id, left_text, right_text, choice):
 
 
 
 
72
  if item_id is None:
73
- return None, None, None, user_index
 
 
 
74
 
75
- # Store user votes uniquely
76
- if user_id not in self.results:
77
- self.results[user_id] = []
78
-
79
- # Check if user already voted for this item
80
- existing_vote = next((r for r in self.results[user_id] if r["item_id"] == item_id), None)
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- if existing_vote:
83
- existing_vote["judgment"] = choice
84
- existing_vote["timestamp"] = datetime.datetime.now().isoformat()
 
 
 
 
 
85
  else:
86
- self.results[user_id].append({
87
- "item_id": item_id,
88
- "generation_a": left_text,
89
- "generation_b": right_text,
90
- "judgment": choice,
91
- "timestamp": datetime.datetime.now().isoformat(),
92
- "labeler_id": user_id
93
- })
94
-
95
- # Save immediately
96
- self.save_results(user_id)
97
-
98
- # Move to the next item
99
- user_index += 1
100
- next_id, next_left, next_right = self.get_current_pair(user_id, user_index)
101
- return next_id, next_left, next_right, user_index
102
-
103
- def save_results(self, user_id):
104
- if user_id not in self.results or not self.results[user_id]:
105
  return
106
-
107
  try:
108
- results_df = pd.DataFrame(self.results[user_id])
109
- filename = f"results_{user_id}.jsonl"
110
- results_df.to_json(filename, orient="records", lines=True)
111
-
112
  # Push to Hugging Face Hub
113
- upload_file(repo_id=HF_OUTPUT_DATASET, repo_type="dataset",
114
- path_in_repo=os.path.join(HF_OUTPUT_DATASET_DIR, filename),
115
- path_or_fileobj=filename)
116
-
117
- os.remove(filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  except Exception as e:
119
  logging.error(f"Error saving results: {e}")
 
120
 
121
  # Initialize the labeler
122
  labeler = PairwiseLabeler()
123
 
124
- # Gradio UI
 
 
 
125
  with gr.Blocks() as app:
 
 
 
 
126
  gr.Markdown(INSTRUCTIONS)
 
 
 
 
 
127
 
128
- user_id = gr.Textbox(label="Enter your user ID", interactive=True)
129
- user_index = gr.State(0) # Track each user's progress
130
-
131
  with gr.Row():
132
  with gr.Column():
133
- left_output = gr.Textbox(label="Model Output A", lines=10, interactive=False)
 
 
 
 
 
134
  with gr.Column():
135
- right_output = gr.Textbox(label="Model Output B", lines=10, interactive=False)
 
 
 
 
136
 
137
  item_id = gr.Textbox(visible=False)
138
 
139
  with gr.Row():
140
- left_btn = gr.Button("⬅️ A is better")
141
- right_btn = gr.Button("➡️ B is better")
142
- tie_btn = gr.Button("🤝 Tie")
143
  cant_choose_btn = gr.Button("🤔 Can't choose")
144
 
145
- def load_first_pair(user_id):
146
- if not user_id:
147
- return None, None, None, 0
148
- return labeler.get_current_pair(user_id, 0) + (0,)
149
-
150
- def judge(choice, user_id, user_index, item_id, left_text, right_text):
151
- return labeler.submit_judgment(user_id, user_index, item_id, left_text, right_text, choice)
152
-
153
- user_id.submit(load_first_pair, inputs=[user_id], outputs=[item_id, left_output, right_output, user_index])
154
- left_btn.click(judge, inputs=[gr.State("A is better"), user_id, user_index, item_id, left_output, right_output], outputs=[item_id, left_output, right_output, user_index])
155
- right_btn.click(judge, inputs=[gr.State("B is better"), user_id, user_index, item_id, left_output, right_output], outputs=[item_id, left_output, right_output, user_index])
156
- tie_btn.click(judge, inputs=[gr.State("Tie"), user_id, user_index, item_id, left_output, right_output], outputs=[item_id, left_output, right_output, user_index])
157
- cant_choose_btn.click(judge, inputs=[gr.State("Can't choose"), user_id, user_index, item_id, left_output, right_output], outputs=[item_id, left_output, right_output, user_index])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  if __name__ == "__main__":
160
- app.launch()
 
6
  import logging
7
  from huggingface_hub import hf_hub_download, upload_file, list_repo_tree
8
  from dotenv import load_dotenv
9
+ from collections import defaultdict
10
 
11
  load_dotenv()
12
 
 
16
  HF_INPUT_DATASET_ID_COLUMN = os.getenv("HF_INPUT_DATASET_ID_COLUMN")
17
  HF_INPUT_DATASET_COLUMN_A = os.getenv("HF_INPUT_DATASET_COLUMN_A")
18
  HF_INPUT_DATASET_COLUMN_B = os.getenv("HF_INPUT_DATASET_COLUMN_B")
19
+ HF_INPUT_DATASET_URL_COLUMN = os.getenv("HF_INPUT_DATASET_URL_COLUMN")
20
  HF_OUTPUT_DATASET = os.getenv("HF_OUTPUT_DATASET")
21
  HF_OUTPUT_DATASET_DIR = os.getenv("HF_OUTPUT_DATASET_DIR")
 
22
  INSTRUCTIONS = """
23
  # Pairwise Model Output Labeling
24
  Please compare the two model outputs shown below and select which one you think is better.
25
+ - Choose "A is better" if the output from Model A (left) is superior
26
+ - Choose "B is better" if the output from Model B (right) is superior
27
+ - Choose "Tie" if you think they are equally good or bad
28
  - Choose "Can't choose" if you cannot make a determination
29
  """
30
+ SAVE_EVERY_N_EXAMPLES = 5
31
+
32
+
33
  class PairwiseLabeler:
34
  def __init__(self):
35
+ self.current_index = defaultdict(int)
36
+ self.results = defaultdict(list)
37
  self.df = self.read_hf_dataset()
38
+ self.has_url_column = HF_INPUT_DATASET_URL_COLUMN and HF_INPUT_DATASET_URL_COLUMN in self.df.columns
39
 
40
  def __len__(self):
41
  return len(self.df)
 
46
  if local_file.endswith(".json"):
47
  return pd.read_json(local_file)
48
  elif local_file.endswith(".jsonl"):
49
+ return pd.read_json(local_file, orient="records",lines=True)
50
  elif local_file.endswith(".csv"):
51
  return pd.read_csv(local_file)
52
  elif local_file.endswith(".parquet"):
 
54
  else:
55
  raise ValueError(f"Unsupported file type: {local_file}")
56
  except Exception as e:
57
+ # Fallback to sample data if loading fails
58
  logging.error(f"Couldn't read HF dataset from {HF_INPUT_DATASET_PATH}. Using sample data instead.")
59
  sample_data = {
60
+ HF_INPUT_DATASET_ID_COLUMN: [f"sample_{i}" for i in range(SAVE_EVERY_N_EXAMPLES)],
61
+ HF_INPUT_DATASET_COLUMN_A: [f"This is sample generation A {i}" for i in range(SAVE_EVERY_N_EXAMPLES)],
62
+ HF_INPUT_DATASET_COLUMN_B: [f"This is sample generation B {i}" for i in range(SAVE_EVERY_N_EXAMPLES)],
63
  }
64
+
65
+ # Add URL column to sample data if specified
66
+ if HF_INPUT_DATASET_URL_COLUMN:
67
+ sample_data[HF_INPUT_DATASET_URL_COLUMN] = [f"https://example.com/sample_{i}" for i in range(SAVE_EVERY_N_EXAMPLES)]
68
+
69
  return pd.DataFrame(sample_data)
70
+
71
+ def get_current_pair(self, session_id):
72
+ if self.current_index[session_id] >= len(self.df):
73
+ if self.has_url_column:
74
+ return None, None, None, None
75
+ else:
76
+ return None, None, None
77
+
78
+ item = self.df.iloc[self.current_index[session_id]]
79
+ item_id = item.get(HF_INPUT_DATASET_ID_COLUMN, f"item_{self.current_index[session_id]}")
80
  left_text = item.get(HF_INPUT_DATASET_COLUMN_A, "")
81
  right_text = item.get(HF_INPUT_DATASET_COLUMN_B, "")
82
 
83
+ if self.has_url_column:
84
+ url = item.get(HF_INPUT_DATASET_URL_COLUMN, "")
85
+ return item_id, left_text, right_text, url
86
+ else:
87
+ return item_id, left_text, right_text
88
+
89
+ def submit_judgment(self, item_id, left_text, right_text, choice, session_id):
90
  if item_id is None:
91
+ if self.has_url_column:
92
+ return item_id, left_text, right_text, None, self.current_index[session_id]
93
+ else:
94
+ return item_id, left_text, right_text, self.current_index[session_id]
95
 
96
+ # Get the current URL if available
97
+ current_url = None
98
+ if self.has_url_column:
99
+ current_url = self.df.iloc[self.current_index[session_id]].get(HF_INPUT_DATASET_URL_COLUMN, "")
100
+
101
+ # Record the judgment
102
+ result = {
103
+ "item_id": item_id,
104
+ "judgment": choice,
105
+ "timestamp": datetime.datetime.now().isoformat(),
106
+ "labeler_id": session_id
107
+ }
108
+
109
+ self.results[session_id].append(result)
110
+
111
+ # Move to next item
112
+ self.current_index[session_id] += 1
113
 
114
+ # Save results periodically
115
+ if len(self.results[session_id]) % SAVE_EVERY_N_EXAMPLES == 0:
116
+ self.save_results(session_id)
117
+
118
+ # Get next pair
119
+ if self.has_url_column:
120
+ next_id, next_left, next_right, next_url = self.get_current_pair(session_id)
121
+ return next_id, next_left, next_right, next_url, self.current_index[session_id]
122
  else:
123
+ next_id, next_left, next_right = self.get_current_pair(session_id)
124
+ return next_id, next_left, next_right, self.current_index[session_id]
125
+
126
+ def save_results(self, session_id):
127
+ if not self.results[session_id]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  return
129
+
130
  try:
131
+ # Convert results to dataset format
132
+ results_df = pd.DataFrame(self.results[session_id])
133
+ results_df.to_json("temp.jsonl", orient="records", lines=True)
134
+
135
  # Push to Hugging Face Hub
136
+ try:
137
+ num_files = len([_ for _ in list_repo_tree(repo_id=HF_OUTPUT_DATASET, repo_type="dataset", path_in_repo=HF_OUTPUT_DATASET_DIR) if session_id in _.path])
138
+ except Exception as e:
139
+ num_files = 0
140
+
141
+ # Use session_id in filename to avoid conflicts
142
+ filename = f"results_{session_id}_{num_files+1}.jsonl"
143
+ upload_file(
144
+ repo_id=HF_OUTPUT_DATASET,
145
+ repo_type="dataset",
146
+ path_in_repo=os.path.join(HF_OUTPUT_DATASET_DIR, filename),
147
+ path_or_fileobj="temp.jsonl"
148
+ )
149
+ os.remove("temp.jsonl")
150
+
151
+ # Clear saved results
152
+ self.results[session_id] = []
153
+ logging.info(f"Saved results for session {session_id} to {HF_OUTPUT_DATASET}/{filename}")
154
  except Exception as e:
155
  logging.error(f"Error saving results: {e}")
156
+ # Keep results in memory to try saving again later
157
 
158
  # Initialize the labeler
159
  labeler = PairwiseLabeler()
160
 
161
+ # Create a unique session ID
162
+ def create_new_session():
163
+ return str(uuid.uuid4())[:8]
164
+
165
  with gr.Blocks() as app:
166
+ # State for the session ID
167
+ session_id = gr.State(value=None)
168
+
169
+ # The actual interface components will be created here
170
  gr.Markdown(INSTRUCTIONS)
171
+
172
+ # URL display component - only shown if URL column is defined
173
+ url_display = None
174
+ if labeler.has_url_column:
175
+ url_display = gr.HTML(label="Reference URL")
176
 
177
+ session_id_display = gr.Textbox(label="Session Information", interactive=False)
178
+
 
179
  with gr.Row():
180
  with gr.Column():
181
+ left_output = gr.Textbox(
182
+ label="Model A Output",
183
+ lines=10,
184
+ interactive=False
185
+ )
186
+
187
  with gr.Column():
188
+ right_output = gr.Textbox(
189
+ label="Model B Output",
190
+ lines=10,
191
+ interactive=False
192
+ )
193
 
194
  item_id = gr.Textbox(visible=False)
195
 
196
  with gr.Row():
197
+ left_btn = gr.Button("⬅️ A is better", variant="primary")
198
+ right_btn = gr.Button("➡️ B is better", variant="primary")
199
+ tie_btn = gr.Button("🤝 Tie", variant="primary")
200
  cant_choose_btn = gr.Button("🤔 Can't choose")
201
 
202
+ current_sample_sld = gr.Slider(minimum=0, maximum=len(labeler), step=1,
203
+ interactive=False,
204
+ label='sample_ind',
205
+ info=f"Samples labeled (out of {len(labeler)})",
206
+ show_label=False,
207
+ container=False,
208
+ scale=5)
209
+
210
+ # Initialize the session and get the first pair
211
+ def init_session():
212
+ new_session_id = create_new_session()
213
+
214
+ if labeler.has_url_column:
215
+ initial_id, initial_left, initial_right, initial_url = labeler.get_current_pair(new_session_id)
216
+ url_html = f'<a href="{initial_url}" target="_blank">{initial_url}</a>' if initial_url else ""
217
+
218
+ return (
219
+ new_session_id, # session_id state
220
+ f"Session ID: {new_session_id}", # session_id_display
221
+ url_html, # url_display
222
+ initial_left, # left_output
223
+ initial_right, # right_output
224
+ initial_id, # item_id
225
+ labeler.current_index[new_session_id] # current_sample_sld
226
+ )
227
+ else:
228
+ initial_id, initial_left, initial_right = labeler.get_current_pair(new_session_id)
229
+
230
+ return (
231
+ new_session_id, # session_id state
232
+ f"Session ID: {new_session_id}", # session_id_display
233
+ initial_left, # left_output
234
+ initial_right, # right_output
235
+ initial_id, # item_id
236
+ labeler.current_index[new_session_id] # current_sample_sld
237
+ )
238
+
239
+ # Run the initialization when the app loads
240
+ if labeler.has_url_column:
241
+ app.load(
242
+ init_session,
243
+ inputs=None,
244
+ outputs=[session_id, session_id_display, url_display, left_output, right_output, item_id, current_sample_sld]
245
+ )
246
+ else:
247
+ app.load(
248
+ init_session,
249
+ inputs=None,
250
+ outputs=[session_id, session_id_display, left_output, right_output, item_id, current_sample_sld]
251
+ )
252
+
253
+ def judge_left(session_id, item_id, left_text, right_text):
254
+ return judge("A is better", session_id, item_id, left_text, right_text)
255
+
256
+ def judge_right(session_id, item_id, left_text, right_text):
257
+ return judge("B is better", session_id, item_id, left_text, right_text)
258
+
259
+ def judge_tie(session_id, item_id, left_text, right_text):
260
+ return judge("Tie", session_id, item_id, left_text, right_text)
261
+
262
+ def judge_cant_choose(session_id, item_id, left_text, right_text):
263
+ return judge("Can't choose", session_id, item_id, left_text, right_text)
264
+
265
+ def judge(choice, session_id, item_id, left_text, right_text):
266
+ if labeler.has_url_column:
267
+ new_id, new_left, new_right, new_url, new_index = labeler.submit_judgment(
268
+ item_id, left_text, right_text, choice, session_id
269
+ )
270
+ url_html = f'<a href="{new_url}" target="_blank">{new_url}</a>' if new_url else ""
271
+ return new_id, new_left, new_right, url_html, new_index
272
+ else:
273
+ new_id, new_left, new_right, new_index = labeler.submit_judgment(
274
+ item_id, left_text, right_text, choice, session_id
275
+ )
276
+ return new_id, new_left, new_right, new_index
277
+
278
+ if labeler.has_url_column:
279
+ left_btn.click(
280
+ judge_left,
281
+ inputs=[session_id, item_id, left_output, right_output],
282
+ outputs=[item_id, left_output, right_output, url_display, current_sample_sld]
283
+ )
284
+
285
+ right_btn.click(
286
+ judge_right,
287
+ inputs=[session_id, item_id, left_output, right_output],
288
+ outputs=[item_id, left_output, right_output, url_display, current_sample_sld]
289
+ )
290
+
291
+ tie_btn.click(
292
+ judge_tie,
293
+ inputs=[session_id, item_id, left_output, right_output],
294
+ outputs=[item_id, left_output, right_output, url_display, current_sample_sld]
295
+ )
296
+
297
+ cant_choose_btn.click(
298
+ judge_cant_choose,
299
+ inputs=[session_id, item_id, left_output, right_output],
300
+ outputs=[item_id, left_output, right_output, url_display, current_sample_sld]
301
+ )
302
+ else:
303
+ left_btn.click(
304
+ judge_left,
305
+ inputs=[session_id, item_id, left_output, right_output],
306
+ outputs=[item_id, left_output, right_output, current_sample_sld]
307
+ )
308
+
309
+ right_btn.click(
310
+ judge_right,
311
+ inputs=[session_id, item_id, left_output, right_output],
312
+ outputs=[item_id, left_output, right_output, current_sample_sld]
313
+ )
314
+
315
+ tie_btn.click(
316
+ judge_tie,
317
+ inputs=[session_id, item_id, left_output, right_output],
318
+ outputs=[item_id, left_output, right_output, current_sample_sld]
319
+ )
320
+
321
+ cant_choose_btn.click(
322
+ judge_cant_choose,
323
+ inputs=[session_id, item_id, left_output, right_output],
324
+ outputs=[item_id, left_output, right_output, current_sample_sld]
325
+ )
326
 
327
  if __name__ == "__main__":
328
+ app.launch()