Update Gallery.py

pages/Gallery.py  (+9 -11)
@@ -82,11 +82,11 @@ class GalleryApp:
         sub_selecters = st.columns([1, 1, 1, 1])
 
         with sub_selecters[0]:
-            clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=
+            clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1, help='the weight for normalized clip score')
         with sub_selecters[1]:
-            mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=
+            mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=0.8, step=0.1, help='the weight for m(ean) s(imilarity) q(uantile) score for measuring distinctiveness')
         with sub_selecters[2]:
-            pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=
+            pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=0.2, step=0.1, help='the weight for normalized popularity score')
 
         items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
             'norm_pop'] * pop_weight, 4)
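For reference, the ranking this hunk parameterizes is a plain weighted sum over pre-normalized score columns. A minimal standalone sketch of the same arithmetic, assuming items is a pandas DataFrame with norm_clip, norm_mcos, and norm_pop columns as the diff implies (the toy values below are illustrative only, not data from the app):

import pandas as pd

# Toy stand-in for `items`; the real DataFrame is built elsewhere in the app.
items = pd.DataFrame({
    'norm_clip': [0.9, 0.4],   # normalized CLIP score
    'norm_mcos': [0.2, 0.7],   # normalized dissimilarity score
    'norm_pop':  [0.5, 0.1],   # normalized popularity score
})

# Default weights from the widgets above.
clip_weight, mcos_weight, pop_weight = 1.0, 0.8, 0.2

# Same formula as the diff: weighted sum, rounded to 4 decimal places.
items.loc[:, 'weighted_score_sum'] = round(
    items['norm_clip'] * clip_weight
    + items['norm_mcos'] * mcos_weight
    + items['norm_pop'] * pop_weight,
    4,
)
print(items['weighted_score_sum'].tolist())  # [1.16, 0.98]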
@@ -94,13 +94,13 @@ class GalleryApp:
         continue_idx = 3
 
         # save latest weights
-        st.session_state.score_weights[0] = clip_weight
-        st.session_state.score_weights[1] = mcos_weight
-        st.session_state.score_weights[2] = pop_weight
+        st.session_state.score_weights[0] = round(clip_weight, 2)
+        st.session_state.score_weights[1] = round(mcos_weight, 2)
+        st.session_state.score_weights[2] = round(pop_weight, 2)
 
         # select threshold
         with sub_selecters[continue_idx]:
-            nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=
+            nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=0.8, step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
             items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
 
         # save latest threshold
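The round(..., 2) added here matters because score_weights is read back on later reruns, presumably as the widgets' value= defaults; rounding before saving keeps float noise from drifting those defaults. A hedged sketch of that round-trip, assuming a one-time session init (the seeding shown is an assumption, not code from this file):

import streamlit as st

# Assumed one-time init; the app's real defaults may differ.
if 'score_weights' not in st.session_state:
    st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]

clip_weight = st.number_input(
    'Clip Score Weight',
    min_value=-100.0,
    max_value=100.0,
    value=st.session_state.score_weights[0],  # re-seed from the saved weight
    step=0.1,
)

# Round before persisting so repeated reruns don't accumulate float noise.
st.session_state.score_weights[0] = round(clip_weight, 2)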
@@ -214,6 +214,7 @@ class GalleryApp:
         st.write('This is a gallery of images generated by the models')
 
         prompt_tags, tag, prompt_id, items = self.sidebar()
+        items, info, col_num = self.selection_panel(items)
 
         # add safety check for some prompts
         safety_check = True
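Read together with the removal in the next hunk, this one-line addition hoists selection_panel out of the safety gate: the weight and threshold widgets now render for every prompt, and the safety checkbox only gates the image grid. Sketched control flow (selection_panel's return shape is taken from the diff; the grid-rendering step is assumed):

# Before: filters appeared only after the user passed the safety check.
# if safety_check:
#     items, info, col_num = self.selection_panel(items)

# After: filters always render; safety_check gates only the display step.
items, info, col_num = self.selection_panel(items)
if safety_check:
    ...  # render the image grid (assumed; happens further down in the method)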
@@ -223,16 +224,13 @@ class GalleryApp:
             unsafe_prompts[prompt_tag] = []
         # manually add unsafe prompts
         unsafe_prompts['world knowledge'] = [83]
-        # unsafe_prompts['art'] = [23]
         unsafe_prompts['abstract'] = [1, 3]
-        # unsafe_prompts['food'] = [34]
 
         if int(prompt_id.item()) in unsafe_prompts[tag]:
             st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
-            safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
+            safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
 
         if safety_check:
-            items, info, col_num = self.selection_panel(items)
 
             if 'selected_dict' in st.session_state:
                 # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
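A likely motivation for the key change in this hunk: Streamlit widget keys are global to the page, so a bare key=f'{prompt_id}' collides with any other widget keyed on the same id and raises a duplicate-widget-key error. Prefixing gives the checkbox its own namespace. Minimal sketch (the second widget is hypothetical, only there to show the collision the prefix avoids):

import streamlit as st

prompt_id = 83

# With bare keys, both lines below would register key='83' and crash.
# Namespaced keys keep them distinct:
safety_check = st.checkbox('Show these images anyway.', key=f'safety_{prompt_id}')
pin_prompt = st.checkbox('Pin this prompt', key=f'pin_{prompt_id}')  # hypothetical widget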