Ricercar commited on
Commit
81e6943
·
1 Parent(s): fac8866

clear up unused code

Browse files
Files changed (1) hide show
  1. pages/Gallery.py +2 -283
pages/Gallery.py CHANGED
@@ -18,8 +18,6 @@ from streamlit_extras.tags import tagger_component
18
  from streamlit_extras.no_default_selectbox import selectbox
19
  from sklearn.svm import LinearSVC
20
 
21
- SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
22
-
23
 
24
  class GalleryApp:
25
  def __init__(self, promptBook, images_ds):
@@ -123,113 +121,6 @@ class GalleryApp:
123
  config=config,
124
  )
125
 
126
- def selection_panel(self, items):
127
- # temperal function
128
-
129
- selecters = st.columns([1, 4])
130
-
131
- if 'score_weights' not in st.session_state:
132
- # st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
133
- st.session_state.score_weights = [1.0, 0.8, 0.2]
134
-
135
- # select sort type
136
- with selecters[0]:
137
- sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
138
- if sort_type == 'Scores':
139
- sort_by = 'weighted_score_sum'
140
-
141
- # select other options
142
- with selecters[1]:
143
- if sort_type == 'IDs and Names':
144
- sub_selecters = st.columns([3])
145
- # select sort by
146
- with sub_selecters[0]:
147
- sort_by = st.selectbox('Sort by',
148
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
149
- label_visibility='hidden')
150
-
151
- continue_idx = 1
152
-
153
- else:
154
- # add custom weights
155
- sub_selecters = st.columns([1, 1, 1])
156
-
157
- with sub_selecters[0]:
158
- clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1, help='the weight for normalized clip score')
159
- with sub_selecters[1]:
160
- mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=0.8, step=0.1, help='the weight for m(eam) s(imilarity) q(antile) score for measuring distinctiveness')
161
- with sub_selecters[2]:
162
- pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=0.2, step=0.1, help='the weight for normalized popularity score')
163
-
164
- items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
165
- 'norm_pop'] * pop_weight, 4)
166
-
167
- continue_idx = 3
168
-
169
- # save latest weights
170
- st.session_state.score_weights[0] = round(clip_weight, 2)
171
- st.session_state.score_weights[1] = round(mcos_weight, 2)
172
- st.session_state.score_weights[2] = round(pop_weight, 2)
173
-
174
- # # select threshold
175
- # with sub_selecters[continue_idx]:
176
- # nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=0.8, step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
177
- # items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
178
- #
179
- # # save latest threshold
180
- # st.session_state.score_weights[3] = nsfw_threshold
181
-
182
- # # draw a distribution histogram
183
- # if sort_type == 'Scores':
184
- # try:
185
- # with st.expander('Show score distribution histogram and select score range'):
186
- # st.write('**Score distribution histogram**')
187
- # chart_space = st.container()
188
- # # st.write('Select the range of scores to show')
189
- # hist_data = pd.DataFrame(items[sort_by])
190
- # mini = hist_data[sort_by].min().item()
191
- # mini = mini//0.1 * 0.1
192
- # maxi = hist_data[sort_by].max().item()
193
- # maxi = maxi//0.1 * 0.1 + 0.1
194
- # st.write('**Select the range of scores to show**')
195
- # r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
196
- # with chart_space:
197
- # st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
198
- # # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
199
- # # r = event_dict.get(sort_by)
200
- # if r:
201
- # items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
202
- # # st.write(r)
203
- # except:
204
- # pass
205
-
206
- display_options = st.columns([1, 4])
207
-
208
- with display_options[0]:
209
- # select order
210
- order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
211
- if order == 'Ascending':
212
- order = True
213
- else:
214
- order = False
215
-
216
- with display_options[1]:
217
-
218
- # select info to show
219
- info = st.multiselect('Show Info',
220
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
221
- 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
222
- 'nsfw_score', 'norm_nsfw'],
223
- default=sort_by)
224
-
225
- # apply sorting to dataframe
226
- items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
227
-
228
- # select number of columns
229
- col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
230
-
231
- return items, info, col_num
232
-
233
  def sidebar(self, items, prompt_id, note):
234
  with st.sidebar:
235
  # show source
@@ -476,50 +367,6 @@ class GalleryApp:
476
  else:
477
  st.info('Please click on an image to show')
478
 
479
- def gallery_mode(self, prompt_id, items):
480
- items, info, col_num = self.selection_panel(items)
481
-
482
- # if 'selected_dict' in st.session_state:
483
- # # st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
484
- # dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
485
- # dynamic_weight_panel = st.columns(len(dynamic_weight_options))
486
- #
487
- # if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
488
- # btn_disable = False
489
- # else:
490
- # btn_disable = True
491
- #
492
- # for i in range(len(dynamic_weight_options)):
493
- # method = dynamic_weight_options[i]
494
- # with dynamic_weight_panel[i]:
495
- # btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
496
-
497
- # prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
498
- # if prompt:
499
- # switch_page("ranking")
500
-
501
- # with st.form(key=f'{prompt_id}'):
502
- # buttons = st.columns([1, 1, 1])
503
- # buttons_space = st.columns([1, 1, 1])
504
- gallery_space = st.empty()
505
-
506
- # with buttons_space[0]:
507
- # continue_btn = st.button('Proceed selections to ranking', use_container_width=True, type='primary')
508
- # if continue_btn:
509
- # # self.submit_actions('Continue', prompt_id)
510
- # switch_page("ranking")
511
- #
512
- # with buttons_space[1]:
513
- # deselect_btn = st.button('Deselect All', use_container_width=True)
514
- # if deselect_btn:
515
- # self.submit_actions('Deselect', prompt_id)
516
- #
517
- # with buttons_space[2]:
518
- # refresh_btn = st.button('Refresh', on_click=gallery_space.empty, use_container_width=True)
519
-
520
- with gallery_space.container():
521
- self.gallery_standard(items, col_num, info)
522
-
523
  def checkout_mode(self, tag, items):
524
  # st.write(items)
525
  if len(items) > 0:
@@ -533,8 +380,8 @@ class GalleryApp:
533
  pass
534
  info = st.multiselect('Show Info',
535
  ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
536
- 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
537
- 'nsfw_score', 'norm_nsfw'],
538
  label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select what infos to show')
539
 
540
  with checkout_panel[-1]:
@@ -593,100 +440,6 @@ class GalleryApp:
593
  st.session_state.gallery_state = 'graph'
594
  st.experimental_rerun()
595
 
596
- def submit_actions(self, status, prompt_id):
597
- # remove counter from session state
598
- # st.session_state.pop('counter', None)
599
- self.remove_ranking_states('prompt_id')
600
- if status == 'Select':
601
- modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
602
- st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
603
- print(st.session_state.selected_dict, 'select')
604
- st.experimental_rerun()
605
- elif status == 'Deselect':
606
- st.session_state.selected_dict[prompt_id] = []
607
- print(st.session_state.selected_dict, 'deselect')
608
- st.experimental_rerun()
609
- # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
610
- elif status == 'Continue':
611
- st.session_state.selected_dict[prompt_id] = []
612
- for key in st.session_state:
613
- keys = key.split('_')
614
- if keys[0] == 'select' and keys[1] == str(prompt_id):
615
- if st.session_state[key]:
616
- st.session_state.selected_dict[prompt_id].append(int(keys[2]))
617
- # switch_page("ranking")
618
- print(st.session_state.selected_dict, 'continue')
619
- # st.experimental_rerun()
620
-
621
- def dynamic_weight(self, prompt_id, items, method='Grid Search'):
622
- selected = items[
623
- items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
624
- optimal_weight = [0, 0, 0]
625
-
626
- if method == 'Grid Search':
627
- # grid search method
628
- top_ranking = len(items) * len(selected)
629
-
630
- for clip_weight in np.arange(-1, 1, 0.1):
631
- for mcos_weight in np.arange(-1, 1, 0.1):
632
- for pop_weight in np.arange(-1, 1, 0.1):
633
-
634
- weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
635
- weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
636
- # print('weight_all_sorted:', weight_all_sorted)
637
- weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
638
-
639
- # get the index of values of weight_selected in weight_all_sorted
640
- rankings = []
641
- for weight in weight_selected:
642
- rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
643
- if sum(rankings) <= top_ranking:
644
- top_ranking = sum(rankings)
645
- print('current top ranking:', top_ranking, rankings)
646
- optimal_weight = [clip_weight, mcos_weight, pop_weight]
647
- print('optimal weight:', optimal_weight)
648
-
649
- elif method == 'SVM':
650
- # svm method
651
- print('start svm method')
652
- # get residual dataframe that contains models not selected
653
- residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
654
- residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
655
- residual = residual.to_numpy()
656
- selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
657
- selected = selected.to_numpy()
658
-
659
- y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
660
- X = np.concatenate((selected, residual), axis=0)
661
-
662
- # fit svm model, and get parameters for the hyperplane
663
- clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
664
- clf.fit(X, y)
665
- optimal_weight = clf.coef_[0].tolist()
666
- print('optimal weight:', optimal_weight)
667
- pass
668
-
669
- elif method == 'Greedy':
670
- for idx in selected.index:
671
- # find which score is the highest, clip, mcos, or pop
672
- clip_score = selected.loc[idx, 'norm_clip_crop']
673
- mcos_score = selected.loc[idx, 'norm_mcos_crop']
674
- pop_score = selected.loc[idx, 'norm_pop']
675
- if clip_score >= mcos_score and clip_score >= pop_score:
676
- optimal_weight[0] += 1
677
- elif mcos_score >= clip_score and mcos_score >= pop_score:
678
- optimal_weight[1] += 1
679
- elif pop_score >= clip_score and pop_score >= mcos_score:
680
- optimal_weight[2] += 1
681
-
682
- # normalize optimal_weight
683
- optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
684
- print('optimal weight:', optimal_weight)
685
- print('optimal weight:', optimal_weight)
686
-
687
- st.session_state.score_weights[0: 3] = optimal_weight
688
-
689
-
690
  def remove_ranking_states(self, prompt_id):
691
  # for drag sort
692
  try:
@@ -710,34 +463,6 @@ class GalleryApp:
710
  except:
711
  print('no page progress states to be reset')
712
 
713
-
714
- # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
715
- @st.cache_resource
716
- def altair_histogram(hist_data, sort_by, mini, maxi):
717
- brushed = alt.selection_interval(encodings=['x'], name="brushed")
718
-
719
- chart = (
720
- alt.Chart(hist_data)
721
- .mark_bar(opacity=0.7, cornerRadius=2)
722
- .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
723
- # .add_selection(brushed)
724
- # .properties(width=800, height=300)
725
- )
726
-
727
- # Create a transparent rectangle for highlighting the range
728
- highlight = (
729
- alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
730
- .mark_rect(opacity=0.3)
731
- .encode(x='x1', x2='x2')
732
- # .properties(width=800, height=300)
733
- )
734
-
735
- # Layer the chart and the highlight rectangle
736
- layered_chart = alt.layer(chart, highlight)
737
-
738
- return layered_chart
739
-
740
-
741
  @st.cache_data
742
  def load_hf_dataset(show_NSFW=False):
743
  # login to huggingface
@@ -797,13 +522,7 @@ if __name__ == "__main__":
797
  if home_btn:
798
  switch_page("home")
799
  else:
800
- # st.write('You have already logged in as ' + st.session_state.user_id[0])
801
  roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
802
- # print(promptBook.columns)
803
-
804
- # # initialize selected_dict
805
- # if 'selected_dict' not in st.session_state:
806
- # st.session_state['selected_dict'] = {}
807
 
808
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
809
  app.app()
 
18
  from streamlit_extras.no_default_selectbox import selectbox
19
  from sklearn.svm import LinearSVC
20
 
 
 
21
 
22
  class GalleryApp:
23
  def __init__(self, promptBook, images_ds):
 
121
  config=config,
122
  )
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def sidebar(self, items, prompt_id, note):
125
  with st.sidebar:
126
  # show source
 
367
  else:
368
  st.info('Please click on an image to show')
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  def checkout_mode(self, tag, items):
371
  # st.write(items)
372
  if len(items) > 0:
 
380
  pass
381
  info = st.multiselect('Show Info',
382
  ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
383
+ 'total_score', 'model_download_count', 'clip_score', 'mcos_score',
384
+ 'norm_nsfw'],
385
  label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select what infos to show')
386
 
387
  with checkout_panel[-1]:
 
440
  st.session_state.gallery_state = 'graph'
441
  st.experimental_rerun()
442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  def remove_ranking_states(self, prompt_id):
444
  # for drag sort
445
  try:
 
463
  except:
464
  print('no page progress states to be reset')
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  @st.cache_data
467
  def load_hf_dataset(show_NSFW=False):
468
  # login to huggingface
 
522
  if home_btn:
523
  switch_page("home")
524
  else:
 
525
  roster, promptBook, images_ds = load_hf_dataset(st.session_state.show_NSFW)
 
 
 
 
 
526
 
527
  app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
528
  app.app()