Spaces:
Sleeping
Sleeping
testing for new checkout page
Browse files- pages/Gallery.py +132 -73
- pages/Ranking.py +2 -2
pages/Gallery.py
CHANGED
@@ -27,7 +27,7 @@ class GalleryApp:
|
|
27 |
|
28 |
# init gallery state
|
29 |
if 'gallery_state' not in st.session_state:
|
30 |
-
st.session_state.gallery_state =
|
31 |
|
32 |
# initialize selected_dict
|
33 |
if 'selected_dict' not in st.session_state:
|
@@ -36,7 +36,7 @@ class GalleryApp:
|
|
36 |
if 'gallery_focus' not in st.session_state:
|
37 |
st.session_state.gallery_focus = {'tag': None, 'prompt': None}
|
38 |
|
39 |
-
def gallery_standard(self, items, col_num, info):
|
40 |
rows = len(items) // col_num + 1
|
41 |
containers = [st.container() for _ in range(rows)]
|
42 |
for idx in range(0, len(items), col_num):
|
@@ -59,8 +59,9 @@ class GalleryApp:
|
|
59 |
|
60 |
# st.write("Position: ", idx + j)
|
61 |
|
62 |
-
|
63 |
-
|
|
|
64 |
|
65 |
# show selected info
|
66 |
for key in info:
|
@@ -278,6 +279,7 @@ class GalleryApp:
|
|
278 |
# return prompt_tags, tag, prompt_id, items
|
279 |
|
280 |
def app(self):
|
|
|
281 |
st.write('### Model Visualization and Retrieval')
|
282 |
# st.write('This is a gallery of images generated by the models')
|
283 |
|
@@ -286,65 +288,66 @@ class GalleryApp:
|
|
286 |
# sort tags by alphabetical order
|
287 |
prompt_tags = np.sort(prompt_tags)[::1].tolist()
|
288 |
|
289 |
-
#
|
290 |
-
|
|
|
|
|
|
|
291 |
|
292 |
# save tag to session state on change
|
293 |
-
tag = st.radio('Select a tag', prompt_tags, index=
|
294 |
-
|
295 |
-
# tabs = st.tabs(prompt_tags)
|
296 |
-
# for i in range(len(prompt_tags)):
|
297 |
-
# with tabs[i]:
|
298 |
-
# tag = prompt_tags[i]
|
299 |
-
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
300 |
-
|
301 |
-
prompts = np.sort(items['prompt'].unique())[::1].tolist()
|
302 |
-
|
303 |
-
# st.caption('Select a prompt')
|
304 |
-
subset_selector = st.columns([3, 1])
|
305 |
-
with subset_selector[0]:
|
306 |
-
selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=0)
|
307 |
-
# st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
|
308 |
|
309 |
-
|
310 |
-
# st.markdown(':orange[Please select a prompt above👆]')
|
311 |
-
st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
|
312 |
|
313 |
-
|
314 |
-
st.write(':orange[👈 **Please select a prompt**]')
|
315 |
|
316 |
-
|
317 |
-
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
318 |
-
prompt_id = items['prompt_id'].unique()[0]
|
319 |
-
note = items['note'].unique()[0]
|
320 |
-
|
321 |
-
# add state to session state
|
322 |
-
if prompt_id not in st.session_state.gallery_state:
|
323 |
-
st.session_state.gallery_state[prompt_id] = 'graph'
|
324 |
|
325 |
-
|
326 |
-
st.session_state.gallery_focus['tag'] = tag
|
327 |
-
st.session_state.gallery_focus['prompt'] = selected_prompt
|
328 |
|
329 |
-
#
|
330 |
-
|
|
|
|
|
|
|
331 |
|
332 |
-
#
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
# # manually add unsafe prompts
|
338 |
-
# unsafe_prompts['world knowledge'] = [83]
|
339 |
-
# unsafe_prompts['abstract'] = [1, 3]
|
340 |
|
341 |
-
if
|
342 |
-
st.
|
343 |
-
|
344 |
|
345 |
-
|
|
|
346 |
|
347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
if safety_check:
|
349 |
self.graph_mode(prompt_id, items)
|
350 |
with subset_selector[-1]:
|
@@ -359,35 +362,48 @@ class GalleryApp:
|
|
359 |
checkout = st.button('Check out selections', use_container_width=True, type='primary')
|
360 |
if checkout:
|
361 |
print('checkout')
|
|
|
|
|
|
|
362 |
|
363 |
-
st.session_state.gallery_state
|
364 |
-
print(st.session_state.gallery_state
|
365 |
st.experimental_rerun()
|
366 |
else:
|
367 |
st.write(':orange[👇 **Select images you like below**]')
|
|
|
|
|
|
|
|
|
368 |
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
|
|
|
|
373 |
|
374 |
-
|
375 |
-
state_operations = st.columns([1, 1])
|
376 |
-
with state_operations[0]:
|
377 |
-
back = st.button('Back to 🖼️', use_container_width=True)
|
378 |
-
if back:
|
379 |
-
st.session_state.gallery_state[prompt_id] = 'graph'
|
380 |
-
st.experimental_rerun()
|
381 |
|
382 |
-
with state_operations[1]:
|
383 |
-
forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
|
384 |
-
if forward:
|
385 |
-
switch_page('ranking')
|
386 |
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
|
392 |
def graph_mode(self, prompt_id, items):
|
393 |
graph_cols = st.columns([3, 1])
|
@@ -498,6 +514,49 @@ class GalleryApp:
|
|
498 |
with gallery_space.container():
|
499 |
self.gallery_standard(items, col_num, info)
|
500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
def submit_actions(self, status, prompt_id):
|
502 |
# remove counter from session state
|
503 |
# st.session_state.pop('counter', None)
|
|
|
27 |
|
28 |
# init gallery state
|
29 |
if 'gallery_state' not in st.session_state:
|
30 |
+
st.session_state.gallery_state = 'graph'
|
31 |
|
32 |
# initialize selected_dict
|
33 |
if 'selected_dict' not in st.session_state:
|
|
|
36 |
if 'gallery_focus' not in st.session_state:
|
37 |
st.session_state.gallery_focus = {'tag': None, 'prompt': None}
|
38 |
|
39 |
+
def gallery_standard(self, items, col_num, info, show_checkbox=True):
|
40 |
rows = len(items) // col_num + 1
|
41 |
containers = [st.container() for _ in range(rows)]
|
42 |
for idx in range(0, len(items), col_num):
|
|
|
59 |
|
60 |
# st.write("Position: ", idx + j)
|
61 |
|
62 |
+
if show_checkbox:
|
63 |
+
# show checkbox
|
64 |
+
st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
|
65 |
|
66 |
# show selected info
|
67 |
for key in info:
|
|
|
279 |
# return prompt_tags, tag, prompt_id, items
|
280 |
|
281 |
def app(self):
|
282 |
+
print(st.session_state.gallery_focus)
|
283 |
st.write('### Model Visualization and Retrieval')
|
284 |
# st.write('This is a gallery of images generated by the models')
|
285 |
|
|
|
288 |
# sort tags by alphabetical order
|
289 |
prompt_tags = np.sort(prompt_tags)[::1].tolist()
|
290 |
|
291 |
+
# set focus tag and prompt index if exists
|
292 |
+
if st.session_state.gallery_focus['tag'] is None:
|
293 |
+
tag_focus_idx = 5
|
294 |
+
else:
|
295 |
+
tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
|
296 |
|
297 |
# save tag to session state on change
|
298 |
+
tag = st.radio('Select a tag', prompt_tags, index=tag_focus_idx, horizontal=True, key='tag', label_visibility='collapsed')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
|
300 |
+
print('current state: ', st.session_state.gallery_state)
|
|
|
|
|
301 |
|
302 |
+
if st.session_state.gallery_state == 'graph':
|
|
|
303 |
|
304 |
+
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
|
306 |
+
prompts = np.sort(items['prompt'].unique())[::1].tolist()
|
|
|
|
|
307 |
|
308 |
+
# selt focus prompt index if exists
|
309 |
+
if st.session_state.gallery_focus['prompt'] is None:
|
310 |
+
prompt_focus_idx = 0
|
311 |
+
else:
|
312 |
+
prompt_focus_idx = 1 + prompts.index(st.session_state.gallery_focus['prompt'])
|
313 |
|
314 |
+
# st.caption('Select a prompt')
|
315 |
+
subset_selector = st.columns([3, 1])
|
316 |
+
with subset_selector[0]:
|
317 |
+
selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=prompt_focus_idx)
|
318 |
+
# st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
|
|
|
|
|
|
|
319 |
|
320 |
+
if selected_prompt is None:
|
321 |
+
# st.markdown(':orange[Please select a prompt above👆]')
|
322 |
+
st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
|
323 |
|
324 |
+
with subset_selector[-1]:
|
325 |
+
st.write(':orange[👈 **Please select a prompt**]')
|
326 |
|
327 |
+
else:
|
328 |
+
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
329 |
+
prompt_id = items['prompt_id'].unique()[0]
|
330 |
+
note = items['note'].unique()[0]
|
331 |
+
|
332 |
+
# add safety check for some prompts
|
333 |
+
safety_check = True
|
334 |
+
|
335 |
+
# load unsafe prompts
|
336 |
+
unsafe_prompts = json.load(open('./data/unsafe_prompts.json', 'r'))
|
337 |
+
for prompt_tag in prompt_tags:
|
338 |
+
if prompt_tag not in unsafe_prompts:
|
339 |
+
unsafe_prompts[prompt_tag] = []
|
340 |
+
# # manually add unsafe prompts
|
341 |
+
# unsafe_prompts['world knowledge'] = [83]
|
342 |
+
# unsafe_prompts['abstract'] = [1, 3]
|
343 |
+
|
344 |
+
if int(prompt_id.item()) in unsafe_prompts[tag]:
|
345 |
+
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
346 |
+
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
|
347 |
+
|
348 |
+
# print('current state: ', st.session_state.gallery_state)
|
349 |
+
#
|
350 |
+
# if st.session_state.gallery_state == 'graph':
|
351 |
if safety_check:
|
352 |
self.graph_mode(prompt_id, items)
|
353 |
with subset_selector[-1]:
|
|
|
362 |
checkout = st.button('Check out selections', use_container_width=True, type='primary')
|
363 |
if checkout:
|
364 |
print('checkout')
|
365 |
+
# add focus to session state
|
366 |
+
st.session_state.gallery_focus['tag'] = tag
|
367 |
+
st.session_state.gallery_focus['prompt'] = selected_prompt
|
368 |
|
369 |
+
st.session_state.gallery_state = 'check out'
|
370 |
+
print(st.session_state.gallery_state)
|
371 |
st.experimental_rerun()
|
372 |
else:
|
373 |
st.write(':orange[👇 **Select images you like below**]')
|
374 |
+
try:
|
375 |
+
self.sidebar(items, prompt_id, note)
|
376 |
+
except:
|
377 |
+
pass
|
378 |
|
379 |
+
elif st.session_state.gallery_state == 'check out':
|
380 |
+
# select items under the current tag, while model_id in selected_dict keys with corresponding modelVersion_ids
|
381 |
+
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
382 |
+
temp_items = pd.DataFrame()
|
383 |
+
for prompt_id, selected_models in st.session_state.selected_dict.items():
|
384 |
+
temp_items = temp_items.append(items[items['modelVersion_id'].isin(selected_models) & (items['prompt_id'] == prompt_id)])
|
385 |
+
items = temp_items.reset_index(drop=True)
|
386 |
|
387 |
+
self.checkout_mode(tag, items)
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
|
|
|
|
|
|
|
|
|
389 |
|
390 |
+
# items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
|
391 |
+
# drop=True)
|
392 |
+
# self.gallery_mode(prompt_id, items)
|
393 |
+
#
|
394 |
+
# with subset_selector[-1]:
|
395 |
+
# state_operations = st.columns([1, 1])
|
396 |
+
# with state_operations[0]:
|
397 |
+
# back = st.button('Back to 🖼️', use_container_width=True)
|
398 |
+
# if back:
|
399 |
+
# st.session_state.gallery_state[prompt_id] = 'graph'
|
400 |
+
# st.experimental_rerun()
|
401 |
+
#
|
402 |
+
# with state_operations[1]:
|
403 |
+
# forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
|
404 |
+
# if forward:
|
405 |
+
# switch_page('ranking')
|
406 |
+
|
407 |
|
408 |
def graph_mode(self, prompt_id, items):
|
409 |
graph_cols = st.columns([3, 1])
|
|
|
514 |
with gallery_space.container():
|
515 |
self.gallery_standard(items, col_num, info)
|
516 |
|
517 |
+
def checkout_mode(self, tag, items):
|
518 |
+
# st.write(items)
|
519 |
+
if len(items) > 0:
|
520 |
+
for prompt_id in items['prompt_id'].unique():
|
521 |
+
prompt = items[items['prompt_id'] == prompt_id]['prompt'].unique()[0]
|
522 |
+
default_expand = True if st.session_state.gallery_focus['prompt'] == prompt else False
|
523 |
+
with st.expander(f'**{prompt}**', expanded=default_expand):
|
524 |
+
# st.caption('select info to show')
|
525 |
+
checkout_panel = st.columns([6, 1, 1])
|
526 |
+
with checkout_panel[0]:
|
527 |
+
pass
|
528 |
+
info = st.multiselect('Show Info',
|
529 |
+
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
|
530 |
+
'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
|
531 |
+
'nsfw_score', 'norm_nsfw'],
|
532 |
+
label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select info to show')
|
533 |
+
with checkout_panel[1]:
|
534 |
+
back = st.button('Back to 🖼️', key=f'checkout_back_{prompt_id}', use_container_width=True)
|
535 |
+
if back:
|
536 |
+
st.session_state.gallery_focus['tag'] = tag
|
537 |
+
st.session_state.gallery_focus['prompt'] = prompt
|
538 |
+
print(st.session_state.gallery_focus)
|
539 |
+
st.session_state.gallery_state = 'graph'
|
540 |
+
st.experimental_rerun()
|
541 |
+
with checkout_panel[2]:
|
542 |
+
proceed = st.button('Proceed ➡️', key=f'checkout_proceed_{prompt_id}', use_container_width=True,
|
543 |
+
type='primary')
|
544 |
+
if proceed:
|
545 |
+
st.session_state.gallery_focus['tag'] = tag
|
546 |
+
st.session_state.gallery_focus['prompt'] = prompt
|
547 |
+
switch_page('ranking')
|
548 |
+
|
549 |
+
self.gallery_standard(items[items['prompt_id'] == prompt_id].reset_index(drop=True), 4, info, show_checkbox=False)
|
550 |
+
else:
|
551 |
+
# with st.form(key=f'checkout_{tag}'):
|
552 |
+
st.info('No selection under this tag')
|
553 |
+
back = st.button('🖼️ Back to gallery and select something you like', key=f'checkout_{tag}', type='primary')
|
554 |
+
if back:
|
555 |
+
st.session_state.gallery_focus['tag'] = tag
|
556 |
+
st.session_state.gallery_focus['prompt'] = None
|
557 |
+
st.session_state.gallery_state = 'graph'
|
558 |
+
st.experimental_rerun()
|
559 |
+
|
560 |
def submit_actions(self, status, prompt_id):
|
561 |
# remove counter from session state
|
562 |
# st.session_state.pop('counter', None)
|
pages/Ranking.py
CHANGED
@@ -149,10 +149,10 @@ class RankingApp:
|
|
149 |
with control:
|
150 |
if st.session_state.counter[prompt_id] < batch_num - 1:
|
151 |
st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
|
152 |
-
kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
153 |
else:
|
154 |
st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
|
155 |
-
kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True)
|
156 |
|
157 |
if st.session_state.counter[prompt_id] > 0:
|
158 |
st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
|
|
|
149 |
with control:
|
150 |
if st.session_state.counter[prompt_id] < batch_num - 1:
|
151 |
st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
|
152 |
+
kwargs={'prompt_id': prompt_id}, use_container_width=True, type='primary')
|
153 |
else:
|
154 |
st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
|
155 |
+
kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True, type='primary')
|
156 |
|
157 |
if st.session_state.counter[prompt_id] > 0:
|
158 |
st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
|