Ricercar commited on
Commit
de41d12
·
1 Parent(s): fb1a1d0

testing for new checkout page

Browse files
Files changed (2) hide show
  1. pages/Gallery.py +132 -73
  2. pages/Ranking.py +2 -2
pages/Gallery.py CHANGED
@@ -27,7 +27,7 @@ class GalleryApp:
27
 
28
  # init gallery state
29
  if 'gallery_state' not in st.session_state:
30
- st.session_state.gallery_state = {}
31
 
32
  # initialize selected_dict
33
  if 'selected_dict' not in st.session_state:
@@ -36,7 +36,7 @@ class GalleryApp:
36
  if 'gallery_focus' not in st.session_state:
37
  st.session_state.gallery_focus = {'tag': None, 'prompt': None}
38
 
39
- def gallery_standard(self, items, col_num, info):
40
  rows = len(items) // col_num + 1
41
  containers = [st.container() for _ in range(rows)]
42
  for idx in range(0, len(items), col_num):
@@ -59,8 +59,9 @@ class GalleryApp:
59
 
60
  # st.write("Position: ", idx + j)
61
 
62
- # show checkbox
63
- st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
 
64
 
65
  # show selected info
66
  for key in info:
@@ -278,6 +279,7 @@ class GalleryApp:
278
  # return prompt_tags, tag, prompt_id, items
279
 
280
  def app(self):
 
281
  st.write('### Model Visualization and Retrieval')
282
  # st.write('This is a gallery of images generated by the models')
283
 
@@ -286,65 +288,66 @@ class GalleryApp:
286
  # sort tags by alphabetical order
287
  prompt_tags = np.sort(prompt_tags)[::1].tolist()
288
 
289
- # chosen_data = [stx.TabBarItemData(id=tag, title=tag, description='') for tag in prompt_tags]
290
- # tag = stx.tab_bar(chosen_data, key='tag', default='food')
 
 
 
291
 
292
  # save tag to session state on change
293
- tag = st.radio('Select a tag', prompt_tags, index=5, horizontal=True, key='tag', label_visibility='collapsed')
294
-
295
- # tabs = st.tabs(prompt_tags)
296
- # for i in range(len(prompt_tags)):
297
- # with tabs[i]:
298
- # tag = prompt_tags[i]
299
- items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
300
-
301
- prompts = np.sort(items['prompt'].unique())[::1].tolist()
302
-
303
- # st.caption('Select a prompt')
304
- subset_selector = st.columns([3, 1])
305
- with subset_selector[0]:
306
- selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=0)
307
- # st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
308
 
309
- if selected_prompt is None:
310
- # st.markdown(':orange[Please select a prompt above👆]')
311
- st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
312
 
313
- with subset_selector[-1]:
314
- st.write(':orange[👈 **Please select a prompt**]')
315
 
316
- else:
317
- items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
318
- prompt_id = items['prompt_id'].unique()[0]
319
- note = items['note'].unique()[0]
320
-
321
- # add state to session state
322
- if prompt_id not in st.session_state.gallery_state:
323
- st.session_state.gallery_state[prompt_id] = 'graph'
324
 
325
- # add focus to session state
326
- st.session_state.gallery_focus['tag'] = tag
327
- st.session_state.gallery_focus['prompt'] = selected_prompt
328
 
329
- # add safety check for some prompts
330
- safety_check = True
 
 
 
331
 
332
- # load unsafe prompts
333
- unsafe_prompts = json.load(open('./data/unsafe_prompts.json', 'r'))
334
- for prompt_tag in prompt_tags:
335
- if prompt_tag not in unsafe_prompts:
336
- unsafe_prompts[prompt_tag] = []
337
- # # manually add unsafe prompts
338
- # unsafe_prompts['world knowledge'] = [83]
339
- # unsafe_prompts['abstract'] = [1, 3]
340
 
341
- if int(prompt_id.item()) in unsafe_prompts[tag]:
342
- st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
343
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
344
 
345
- print('current state: ', st.session_state.gallery_state[prompt_id])
 
346
 
347
- if st.session_state.gallery_state[prompt_id] == 'graph':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  if safety_check:
349
  self.graph_mode(prompt_id, items)
350
  with subset_selector[-1]:
@@ -359,35 +362,48 @@ class GalleryApp:
359
  checkout = st.button('Check out selections', use_container_width=True, type='primary')
360
  if checkout:
361
  print('checkout')
 
 
 
362
 
363
- st.session_state.gallery_state[prompt_id] = 'gallery'
364
- print(st.session_state.gallery_state[prompt_id])
365
  st.experimental_rerun()
366
  else:
367
  st.write(':orange[👇 **Select images you like below**]')
 
 
 
 
368
 
369
- elif st.session_state.gallery_state[prompt_id] == 'gallery':
370
- items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
371
- drop=True)
372
- self.gallery_mode(prompt_id, items)
 
 
 
373
 
374
- with subset_selector[-1]:
375
- state_operations = st.columns([1, 1])
376
- with state_operations[0]:
377
- back = st.button('Back to 🖼️', use_container_width=True)
378
- if back:
379
- st.session_state.gallery_state[prompt_id] = 'graph'
380
- st.experimental_rerun()
381
 
382
- with state_operations[1]:
383
- forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
384
- if forward:
385
- switch_page('ranking')
386
 
387
- try:
388
- self.sidebar(items, prompt_id, note)
389
- except:
390
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
  def graph_mode(self, prompt_id, items):
393
  graph_cols = st.columns([3, 1])
@@ -498,6 +514,49 @@ class GalleryApp:
498
  with gallery_space.container():
499
  self.gallery_standard(items, col_num, info)
500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  def submit_actions(self, status, prompt_id):
502
  # remove counter from session state
503
  # st.session_state.pop('counter', None)
 
27
 
28
  # init gallery state
29
  if 'gallery_state' not in st.session_state:
30
+ st.session_state.gallery_state = 'graph'
31
 
32
  # initialize selected_dict
33
  if 'selected_dict' not in st.session_state:
 
36
  if 'gallery_focus' not in st.session_state:
37
  st.session_state.gallery_focus = {'tag': None, 'prompt': None}
38
 
39
+ def gallery_standard(self, items, col_num, info, show_checkbox=True):
40
  rows = len(items) // col_num + 1
41
  containers = [st.container() for _ in range(rows)]
42
  for idx in range(0, len(items), col_num):
 
59
 
60
  # st.write("Position: ", idx + j)
61
 
62
+ if show_checkbox:
63
+ # show checkbox
64
+ st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
65
 
66
  # show selected info
67
  for key in info:
 
279
  # return prompt_tags, tag, prompt_id, items
280
 
281
  def app(self):
282
+ print(st.session_state.gallery_focus)
283
  st.write('### Model Visualization and Retrieval')
284
  # st.write('This is a gallery of images generated by the models')
285
 
 
288
  # sort tags by alphabetical order
289
  prompt_tags = np.sort(prompt_tags)[::1].tolist()
290
 
291
+ # set focus tag and prompt index if exists
292
+ if st.session_state.gallery_focus['tag'] is None:
293
+ tag_focus_idx = 5
294
+ else:
295
+ tag_focus_idx = prompt_tags.index(st.session_state.gallery_focus['tag'])
296
 
297
  # save tag to session state on change
298
+ tag = st.radio('Select a tag', prompt_tags, index=tag_focus_idx, horizontal=True, key='tag', label_visibility='collapsed')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
+ print('current state: ', st.session_state.gallery_state)
 
 
301
 
302
+ if st.session_state.gallery_state == 'graph':
 
303
 
304
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
 
 
 
 
 
 
 
305
 
306
+ prompts = np.sort(items['prompt'].unique())[::1].tolist()
 
 
307
 
308
+ # selt focus prompt index if exists
309
+ if st.session_state.gallery_focus['prompt'] is None:
310
+ prompt_focus_idx = 0
311
+ else:
312
+ prompt_focus_idx = 1 + prompts.index(st.session_state.gallery_focus['prompt'])
313
 
314
+ # st.caption('Select a prompt')
315
+ subset_selector = st.columns([3, 1])
316
+ with subset_selector[0]:
317
+ selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---', label_visibility='collapsed', index=prompt_focus_idx)
318
+ # st.session_state.prompt_idx_last_time = prompts.index(selected_prompt) if selected_prompt else 0
 
 
 
319
 
320
+ if selected_prompt is None:
321
+ # st.markdown(':orange[Please select a prompt above👆]')
322
+ st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
323
 
324
+ with subset_selector[-1]:
325
+ st.write(':orange[👈 **Please select a prompt**]')
326
 
327
+ else:
328
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
329
+ prompt_id = items['prompt_id'].unique()[0]
330
+ note = items['note'].unique()[0]
331
+
332
+ # add safety check for some prompts
333
+ safety_check = True
334
+
335
+ # load unsafe prompts
336
+ unsafe_prompts = json.load(open('./data/unsafe_prompts.json', 'r'))
337
+ for prompt_tag in prompt_tags:
338
+ if prompt_tag not in unsafe_prompts:
339
+ unsafe_prompts[prompt_tag] = []
340
+ # # manually add unsafe prompts
341
+ # unsafe_prompts['world knowledge'] = [83]
342
+ # unsafe_prompts['abstract'] = [1, 3]
343
+
344
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
345
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
346
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
347
+
348
+ # print('current state: ', st.session_state.gallery_state)
349
+ #
350
+ # if st.session_state.gallery_state == 'graph':
351
  if safety_check:
352
  self.graph_mode(prompt_id, items)
353
  with subset_selector[-1]:
 
362
  checkout = st.button('Check out selections', use_container_width=True, type='primary')
363
  if checkout:
364
  print('checkout')
365
+ # add focus to session state
366
+ st.session_state.gallery_focus['tag'] = tag
367
+ st.session_state.gallery_focus['prompt'] = selected_prompt
368
 
369
+ st.session_state.gallery_state = 'check out'
370
+ print(st.session_state.gallery_state)
371
  st.experimental_rerun()
372
  else:
373
  st.write(':orange[👇 **Select images you like below**]')
374
+ try:
375
+ self.sidebar(items, prompt_id, note)
376
+ except:
377
+ pass
378
 
379
+ elif st.session_state.gallery_state == 'check out':
380
+ # select items under the current tag, while model_id in selected_dict keys with corresponding modelVersion_ids
381
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
382
+ temp_items = pd.DataFrame()
383
+ for prompt_id, selected_models in st.session_state.selected_dict.items():
384
+ temp_items = temp_items.append(items[items['modelVersion_id'].isin(selected_models) & (items['prompt_id'] == prompt_id)])
385
+ items = temp_items.reset_index(drop=True)
386
 
387
+ self.checkout_mode(tag, items)
 
 
 
 
 
 
388
 
 
 
 
 
389
 
390
+ # items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
391
+ # drop=True)
392
+ # self.gallery_mode(prompt_id, items)
393
+ #
394
+ # with subset_selector[-1]:
395
+ # state_operations = st.columns([1, 1])
396
+ # with state_operations[0]:
397
+ # back = st.button('Back to 🖼️', use_container_width=True)
398
+ # if back:
399
+ # st.session_state.gallery_state[prompt_id] = 'graph'
400
+ # st.experimental_rerun()
401
+ #
402
+ # with state_operations[1]:
403
+ # forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
404
+ # if forward:
405
+ # switch_page('ranking')
406
+
407
 
408
  def graph_mode(self, prompt_id, items):
409
  graph_cols = st.columns([3, 1])
 
514
  with gallery_space.container():
515
  self.gallery_standard(items, col_num, info)
516
 
517
+ def checkout_mode(self, tag, items):
518
+ # st.write(items)
519
+ if len(items) > 0:
520
+ for prompt_id in items['prompt_id'].unique():
521
+ prompt = items[items['prompt_id'] == prompt_id]['prompt'].unique()[0]
522
+ default_expand = True if st.session_state.gallery_focus['prompt'] == prompt else False
523
+ with st.expander(f'**{prompt}**', expanded=default_expand):
524
+ # st.caption('select info to show')
525
+ checkout_panel = st.columns([6, 1, 1])
526
+ with checkout_panel[0]:
527
+ pass
528
+ info = st.multiselect('Show Info',
529
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
530
+ 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
531
+ 'nsfw_score', 'norm_nsfw'],
532
+ label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select info to show')
533
+ with checkout_panel[1]:
534
+ back = st.button('Back to 🖼️', key=f'checkout_back_{prompt_id}', use_container_width=True)
535
+ if back:
536
+ st.session_state.gallery_focus['tag'] = tag
537
+ st.session_state.gallery_focus['prompt'] = prompt
538
+ print(st.session_state.gallery_focus)
539
+ st.session_state.gallery_state = 'graph'
540
+ st.experimental_rerun()
541
+ with checkout_panel[2]:
542
+ proceed = st.button('Proceed ➡️', key=f'checkout_proceed_{prompt_id}', use_container_width=True,
543
+ type='primary')
544
+ if proceed:
545
+ st.session_state.gallery_focus['tag'] = tag
546
+ st.session_state.gallery_focus['prompt'] = prompt
547
+ switch_page('ranking')
548
+
549
+ self.gallery_standard(items[items['prompt_id'] == prompt_id].reset_index(drop=True), 4, info, show_checkbox=False)
550
+ else:
551
+ # with st.form(key=f'checkout_{tag}'):
552
+ st.info('No selection under this tag')
553
+ back = st.button('🖼️ Back to gallery and select something you like', key=f'checkout_{tag}', type='primary')
554
+ if back:
555
+ st.session_state.gallery_focus['tag'] = tag
556
+ st.session_state.gallery_focus['prompt'] = None
557
+ st.session_state.gallery_state = 'graph'
558
+ st.experimental_rerun()
559
+
560
  def submit_actions(self, status, prompt_id):
561
  # remove counter from session state
562
  # st.session_state.pop('counter', None)
pages/Ranking.py CHANGED
@@ -149,10 +149,10 @@ class RankingApp:
149
  with control:
150
  if st.session_state.counter[prompt_id] < batch_num - 1:
151
  st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
152
- kwargs={'prompt_id': prompt_id}, use_container_width=True)
153
  else:
154
  st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
155
- kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True)
156
 
157
  if st.session_state.counter[prompt_id] > 0:
158
  st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
 
149
  with control:
150
  if st.session_state.counter[prompt_id] < batch_num - 1:
151
  st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
152
+ kwargs={'prompt_id': prompt_id}, use_container_width=True, type='primary')
153
  else:
154
  st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
155
+ kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True, type='primary')
156
 
157
  if st.session_state.counter[prompt_id] > 0:
158
  st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',