Ricercar committed on
Commit
2c3dcf3
·
1 Parent(s): ab37b94

fix safety check

Browse files
Files changed (1) hide show
  1. app.py +9 -18
app.py CHANGED
@@ -195,6 +195,7 @@ class GalleryApp:
195
  return items, info, col_num
196
 
197
  def selection_panel_2(self, items):
 
198
  selecters = st.columns([1, 4])
199
 
200
  # select sort type
@@ -219,17 +220,17 @@ class GalleryApp:
219
  # add custom weights
220
  sub_selecters = st.columns([1, 1, 1, 1])
221
 
222
- if 'default_weights' not in st.session_state:
223
- st.session_state.default_weights = [1.0, 0.8, 0.2]
224
 
225
  with sub_selecters[0]:
226
- clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[0], step=0.1, help='the weight for normalized clip score')
227
  with sub_selecters[1]:
228
- rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[1], step=0.1, help='the weight for average rank')
229
  with sub_selecters[2]:
230
- pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[2], step=0.1, help='the weight for normalized popularity score')
231
 
232
- st.session_state.default_weights = [clip_weight, rank_weight, pop_weight]
233
 
234
  items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
235
  'norm_pop'] * pop_weight, 4)
@@ -241,17 +242,6 @@ class GalleryApp:
241
  dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=0.84, step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
242
  items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
243
 
244
- # filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
245
- # print('filter', filter)
246
- # # initialize unsafe_modelVersion_ids
247
- # if filter == 'Safe':
248
- # # return unchecked items
249
- # items = items[items['checked'] == False].reset_index(drop=True)
250
- #
251
- # elif filter == 'Unsafe':
252
- # # return checked items
253
- # items = items[items['checked'] == True].reset_index(drop=True)
254
-
255
  # draw a distribution histogram
256
  if sort_type == 'Scores':
257
  try:
@@ -369,7 +359,7 @@ class GalleryApp:
369
 
370
  if int(prompt_id.item()) in unsafe_prompts[tag]:
371
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
372
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.')
373
 
374
  if safety_check:
375
  items, info, col_num = self.selection_panel_2(items)
@@ -392,6 +382,7 @@ class GalleryApp:
392
  with buttons_space:
393
  st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True, type='primary')
394
 
 
395
  def reset_current_prompt(self, prompt_id):
396
  # reset current prompt
397
  self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
 
195
  return items, info, col_num
196
 
197
  def selection_panel_2(self, items):
198
+
199
  selecters = st.columns([1, 4])
200
 
201
  # select sort type
 
220
  # add custom weights
221
  sub_selecters = st.columns([1, 1, 1, 1])
222
 
223
+ if 'score_weights' not in st.session_state:
224
+ st.session_state.score_weights = [1.0, 0.8, 0.2]
225
 
226
  with sub_selecters[0]:
227
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
228
  with sub_selecters[1]:
229
+ rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for average rank')
230
  with sub_selecters[2]:
231
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
232
 
233
+ st.session_state.score_weights = [clip_weight, rank_weight, pop_weight]
234
 
235
  items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
236
  'norm_pop'] * pop_weight, 4)
 
242
  dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=0.84, step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
243
  items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
244
 
 
 
 
 
 
 
 
 
 
 
 
245
  # draw a distribution histogram
246
  if sort_type == 'Scores':
247
  try:
 
359
 
360
  if int(prompt_id.item()) in unsafe_prompts[tag]:
361
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
362
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
363
 
364
  if safety_check:
365
  items, info, col_num = self.selection_panel_2(items)
 
382
  with buttons_space:
383
  st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True, type='primary')
384
 
385
+
386
  def reset_current_prompt(self, prompt_id):
387
  # reset current prompt
388
  self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False