Spaces:
Running
Running
fix safety check
Browse files
app.py
CHANGED
@@ -195,6 +195,7 @@ class GalleryApp:
|
|
195 |
return items, info, col_num
|
196 |
|
197 |
def selection_panel_2(self, items):
|
|
|
198 |
selecters = st.columns([1, 4])
|
199 |
|
200 |
# select sort type
|
@@ -219,17 +220,17 @@ class GalleryApp:
|
|
219 |
# add custom weights
|
220 |
sub_selecters = st.columns([1, 1, 1, 1])
|
221 |
|
222 |
-
if '
|
223 |
-
st.session_state.
|
224 |
|
225 |
with sub_selecters[0]:
|
226 |
-
clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.
|
227 |
with sub_selecters[1]:
|
228 |
-
rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.
|
229 |
with sub_selecters[2]:
|
230 |
-
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.
|
231 |
|
232 |
-
st.session_state.
|
233 |
|
234 |
items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
|
235 |
'norm_pop'] * pop_weight, 4)
|
@@ -241,17 +242,6 @@ class GalleryApp:
|
|
241 |
dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=0.84, step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
|
242 |
items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
|
243 |
|
244 |
-
# filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
|
245 |
-
# print('filter', filter)
|
246 |
-
# # initialize unsafe_modelVersion_ids
|
247 |
-
# if filter == 'Safe':
|
248 |
-
# # return unchecked items
|
249 |
-
# items = items[items['checked'] == False].reset_index(drop=True)
|
250 |
-
#
|
251 |
-
# elif filter == 'Unsafe':
|
252 |
-
# # return checked items
|
253 |
-
# items = items[items['checked'] == True].reset_index(drop=True)
|
254 |
-
|
255 |
# draw a distribution histogram
|
256 |
if sort_type == 'Scores':
|
257 |
try:
|
@@ -369,7 +359,7 @@ class GalleryApp:
|
|
369 |
|
370 |
if int(prompt_id.item()) in unsafe_prompts[tag]:
|
371 |
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
372 |
-
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.')
|
373 |
|
374 |
if safety_check:
|
375 |
items, info, col_num = self.selection_panel_2(items)
|
@@ -392,6 +382,7 @@ class GalleryApp:
|
|
392 |
with buttons_space:
|
393 |
st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True, type='primary')
|
394 |
|
|
|
395 |
def reset_current_prompt(self, prompt_id):
|
396 |
# reset current prompt
|
397 |
self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
|
|
|
195 |
return items, info, col_num
|
196 |
|
197 |
def selection_panel_2(self, items):
|
198 |
+
|
199 |
selecters = st.columns([1, 4])
|
200 |
|
201 |
# select sort type
|
|
|
220 |
# add custom weights
|
221 |
sub_selecters = st.columns([1, 1, 1, 1])
|
222 |
|
223 |
+
if 'score_weights' not in st.session_state:
|
224 |
+
st.session_state.score_weights = [1.0, 0.8, 0.2]
|
225 |
|
226 |
with sub_selecters[0]:
|
227 |
+
clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
|
228 |
with sub_selecters[1]:
|
229 |
+
rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for average rank')
|
230 |
with sub_selecters[2]:
|
231 |
+
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
|
232 |
|
233 |
+
st.session_state.score_weights = [clip_weight, rank_weight, pop_weight]
|
234 |
|
235 |
items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
|
236 |
'norm_pop'] * pop_weight, 4)
|
|
|
242 |
dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=0.84, step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
|
243 |
items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
# draw a distribution histogram
|
246 |
if sort_type == 'Scores':
|
247 |
try:
|
|
|
359 |
|
360 |
if int(prompt_id.item()) in unsafe_prompts[tag]:
|
361 |
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
362 |
+
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
|
363 |
|
364 |
if safety_check:
|
365 |
items, info, col_num = self.selection_panel_2(items)
|
|
|
382 |
with buttons_space:
|
383 |
st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True, type='primary')
|
384 |
|
385 |
+
|
386 |
def reset_current_prompt(self, prompt_id):
|
387 |
# reset current prompt
|
388 |
self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
|