tdurbor commited on
Commit
c098942
·
1 Parent(s): 23f5054

Remember last 10 indices used and sample from the other ones

Browse files
Files changed (1) hide show
  1. app.py +58 -32
app.py CHANGED
@@ -6,6 +6,7 @@ import logging
6
  import threading
7
  from pathlib import Path
8
  from datetime import datetime, timedelta
 
9
 
10
  import numpy as np
11
  import gradio as gr
@@ -80,18 +81,24 @@ def update_rankings_table():
80
  return []
81
  return rankings
82
 
83
-
84
- def select_new_image():
85
  """Select a new image and its segmented versions."""
86
  max_attempts = 10
87
- last_image_index = None
88
-
 
 
 
89
  for _ in range(max_attempts):
90
- available_indices = [i for i in range(len(dataset)) if i != last_image_index]
 
 
 
 
91
 
92
  if not available_indices:
93
  logging.error("No available images to select from.")
94
- return None
95
 
96
  random_index = random.choice(available_indices)
97
  sample = dataset[random_index]
@@ -102,23 +109,26 @@ def select_new_image():
102
 
103
  if segmented_images.count(None) > 2:
104
  logging.error("Not enough segmented images found for: %s. Resampling another image.", sample['original_filename'])
105
- last_image_index = random_index
106
  continue
107
 
108
  try:
109
  selected_indices = random.sample([i for i, img in enumerate(segmented_images) if img is not None], 2)
110
  model_a_index, model_b_index = selected_indices
 
 
 
 
111
  return (
112
- sample['original_filename'], input_image,
113
  segmented_images[model_a_index], segmented_images[model_b_index],
114
- segmented_sources[model_a_index], segmented_sources[model_b_index]
 
115
  )
116
  except Exception as e:
117
  logging.error("Error processing images: %s. Resampling another image.", str(e))
118
- last_image_index = random_index
119
 
120
  logging.error("Failed to select a new image after %d attempts.", max_attempts)
121
- return None
122
 
123
  def get_notice_markdown():
124
  """Generate the notice markdown with dynamic vote count."""
@@ -225,8 +235,9 @@ def get_default_username(profile: gr.OAuthProfile | None) -> str:
225
  def gradio_interface():
226
  """Create and return the Gradio interface."""
227
  with gr.Blocks(js=js, head=head, fill_width=True) as demo:
228
-
229
-
 
230
  button_name = "Difference between masks"
231
 
232
  with gr.Tabs() as tabs:
@@ -278,11 +289,14 @@ def gradio_interface():
278
  input_image_display = gr.AnnotatedImage(label="Input Image", width=image_width, height=image_height)
279
  image_b = gr.Image(label="Image B", width=image_width, height=image_height)
280
 
281
- # Refresh states to load new image data
282
-
283
- def refresh_states(state_filename, state_model_a_name, state_model_b_name):
284
  # Call select_new_image to get new image data
285
- filename, input_image, segmented_a, segmented_b, model_a_name, model_b_name = select_new_image()
 
 
 
 
 
286
  mask_difference = compute_mask_difference(segmented_a, segmented_b)
287
 
288
  # Update states with new data
@@ -299,11 +313,10 @@ def gradio_interface():
299
  height=image_height
300
  )
301
 
302
- outputs = [
303
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
304
- input_image_display
305
  ]
306
- return outputs
307
 
308
 
309
  with gr.Row():
@@ -311,7 +324,7 @@ def gradio_interface():
311
  vote_tie_button = gr.Button("🤝 Tie")
312
  vote_b_button = gr.Button("👉 B is better")
313
 
314
- def vote_for_model(choice, original_filename, model_a_name, model_b_name, user_username):
315
  """Submit a vote for a model and return updated images and model names."""
316
 
317
 
@@ -355,34 +368,40 @@ def gradio_interface():
355
  except Exception as e:
356
  logging.error("Error recording vote: %s", str(e))
357
 
358
- outputs = refresh_states(state_filename, state_model_a_name, state_model_b_name)
359
  new_notice_markdown = get_notice_markdown()
360
 
361
  return outputs + [new_notice_markdown]
362
 
363
  notice_markdown = gr.Markdown(get_notice_markdown(), elem_id="notice_markdown")
364
  vote_a_button.click(
365
- fn=lambda username: vote_for_model("model_a", state_filename, state_model_a_name, state_model_b_name, username),
366
- inputs=[username_input],
 
 
367
  outputs=[
368
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
369
- input_image_display, notice_markdown
370
  ]
371
  )
372
  vote_b_button.click(
373
- fn=lambda username: vote_for_model("model_b", state_filename, state_model_a_name, state_model_b_name, username),
374
- inputs=[username_input],
 
 
375
  outputs=[
376
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
377
- input_image_display, notice_markdown
378
  ]
379
  )
380
  vote_tie_button.click(
381
- fn=lambda username: vote_for_model("tie", state_filename, state_model_a_name, state_model_b_name, username),
382
- inputs=[username_input],
 
 
383
  outputs=[
384
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
385
- input_image_display, notice_markdown
386
  ]
387
  )
388
 
@@ -473,7 +492,14 @@ def gradio_interface():
473
  fn=lambda: get_weekly_user_leaderboard(),
474
  outputs=user_leaderboard_table
475
  )
476
- demo.load(lambda: refresh_states(state_filename, state_model_a_name, state_model_b_name), inputs=None, outputs=[state_filename, image_a, image_b, state_model_a_name, state_model_b_name, input_image_display])
 
 
 
 
 
 
 
477
  return demo
478
 
479
  def dump_database_to_json():
 
6
  import threading
7
  from pathlib import Path
8
  from datetime import datetime, timedelta
9
+ from collections import deque
10
 
11
  import numpy as np
12
  import gradio as gr
 
81
  return []
82
  return rankings
83
 
84
+ def select_new_image(last_used_indices):
 
85
  """Select a new image and its segmented versions."""
86
  max_attempts = 10
87
+
88
+ # Initialize empty deque if None
89
+ if last_used_indices is None:
90
+ last_used_indices = deque(maxlen=10)
91
+
92
  for _ in range(max_attempts):
93
+ # Filter out recently used indices
94
+ available_indices = [
95
+ i for i in range(len(dataset))
96
+ if i not in last_used_indices
97
+ ]
98
 
99
  if not available_indices:
100
  logging.error("No available images to select from.")
101
+ return None, last_used_indices
102
 
103
  random_index = random.choice(available_indices)
104
  sample = dataset[random_index]
 
109
 
110
  if segmented_images.count(None) > 2:
111
  logging.error("Not enough segmented images found for: %s. Resampling another image.", sample['original_filename'])
 
112
  continue
113
 
114
  try:
115
  selected_indices = random.sample([i for i, img in enumerate(segmented_images) if img is not None], 2)
116
  model_a_index, model_b_index = selected_indices
117
+
118
+ # Add the used index to our history
119
+ last_used_indices.append(random_index)
120
+
121
  return (
122
+ (sample['original_filename'], input_image,
123
  segmented_images[model_a_index], segmented_images[model_b_index],
124
+ segmented_sources[model_a_index], segmented_sources[model_b_index]),
125
+ last_used_indices
126
  )
127
  except Exception as e:
128
  logging.error("Error processing images: %s. Resampling another image.", str(e))
 
129
 
130
  logging.error("Failed to select a new image after %d attempts.", max_attempts)
131
+ return None, last_used_indices
132
 
133
  def get_notice_markdown():
134
  """Generate the notice markdown with dynamic vote count."""
 
235
  def gradio_interface():
236
  """Create and return the Gradio interface."""
237
  with gr.Blocks(js=js, head=head, fill_width=True) as demo:
238
+ # Initialize session state for last used indices
239
+ last_used_indices_state = gr.State()
240
+
241
  button_name = "Difference between masks"
242
 
243
  with gr.Tabs() as tabs:
 
289
  input_image_display = gr.AnnotatedImage(label="Input Image", width=image_width, height=image_height)
290
  image_b = gr.Image(label="Image B", width=image_width, height=image_height)
291
 
292
+ def refresh_states(state_filename, state_model_a_name, state_model_b_name, last_used_indices):
 
 
293
  # Call select_new_image to get new image data
294
+ result, new_last_used_indices = select_new_image(last_used_indices)
295
+ if result is None:
296
+ return [state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
297
+ input_image_display, new_last_used_indices]
298
+
299
+ filename, input_image, segmented_a, segmented_b, model_a_name, model_b_name = result
300
  mask_difference = compute_mask_difference(segmented_a, segmented_b)
301
 
302
  # Update states with new data
 
313
  height=image_height
314
  )
315
 
316
+ return [
317
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
318
+ input_image_display, new_last_used_indices
319
  ]
 
320
 
321
 
322
  with gr.Row():
 
324
  vote_tie_button = gr.Button("🤝 Tie")
325
  vote_b_button = gr.Button("👉 B is better")
326
 
327
+ def vote_for_model(choice, original_filename, model_a_name, model_b_name, user_username, last_used_indices):
328
  """Submit a vote for a model and return updated images and model names."""
329
 
330
 
 
368
  except Exception as e:
369
  logging.error("Error recording vote: %s", str(e))
370
 
371
+ outputs = refresh_states(state_filename, state_model_a_name, state_model_b_name, last_used_indices)
372
  new_notice_markdown = get_notice_markdown()
373
 
374
  return outputs + [new_notice_markdown]
375
 
376
  notice_markdown = gr.Markdown(get_notice_markdown(), elem_id="notice_markdown")
377
  vote_a_button.click(
378
+ fn=lambda username, last_used_indices: vote_for_model(
379
+ "model_a", state_filename, state_model_a_name, state_model_b_name, username, last_used_indices
380
+ ),
381
+ inputs=[username_input, last_used_indices_state],
382
  outputs=[
383
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
384
+ input_image_display, last_used_indices_state, notice_markdown
385
  ]
386
  )
387
  vote_b_button.click(
388
+ fn=lambda username, last_used_indices: vote_for_model(
389
+ "model_b", state_filename, state_model_a_name, state_model_b_name, username, last_used_indices
390
+ ),
391
+ inputs=[username_input, last_used_indices_state],
392
  outputs=[
393
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
394
+ input_image_display, last_used_indices_state, notice_markdown
395
  ]
396
  )
397
  vote_tie_button.click(
398
+ fn=lambda username, last_used_indices: vote_for_model(
399
+ "tie", state_filename, state_model_a_name, state_model_b_name, username, last_used_indices
400
+ ),
401
+ inputs=[username_input, last_used_indices_state],
402
  outputs=[
403
  state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
404
+ input_image_display, last_used_indices_state, notice_markdown
405
  ]
406
  )
407
 
 
492
  fn=lambda: get_weekly_user_leaderboard(),
493
  outputs=user_leaderboard_table
494
  )
495
+ demo.load(
496
+ lambda: refresh_states(state_filename, state_model_a_name, state_model_b_name, None),
497
+ inputs=None,
498
+ outputs=[
499
+ state_filename, image_a, image_b, state_model_a_name, state_model_b_name,
500
+ input_image_display, last_used_indices_state
501
+ ]
502
+ )
503
  return demo
504
 
505
  def dump_database_to_json():