Ricercar commited on
Commit
f3de8ec
·
1 Parent(s): 0b0509d

remove class

Browse files
Files changed (2) hide show
  1. Archive/Gallery_archive_8_5.py +446 -0
  2. pages/Gallery.py +337 -346
Archive/Gallery_archive_8_5.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+
4
+ import altair as alt
5
+ import numpy as np
6
+ import pandas as pd
7
+ import streamlit as st
8
+
9
+ from bs4 import BeautifulSoup
10
+ from datasets import load_dataset, Dataset, load_from_disk
11
+ from huggingface_hub import login
12
+ from streamlit_extras.switch_page_button import switch_page
13
+ from sklearn.svm import LinearSVC
14
+
15
+ SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
16
+
17
+
18
+ class GalleryApp:
19
+ def __init__(self, promptBook, images_ds):
20
+ self.promptBook = promptBook
21
+ self.images_ds = images_ds
22
+
23
+ def gallery_standard(self, items, col_num, info):
24
+ rows = len(items) // col_num + 1
25
+ containers = [st.container() for _ in range(rows)]
26
+ for idx in range(0, len(items), col_num):
27
+ row_idx = idx // col_num
28
+ with containers[row_idx]:
29
+ cols = st.columns(col_num)
30
+ for j in range(col_num):
31
+ if idx + j < len(items):
32
+ with cols[j]:
33
+ # show image
34
+ # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
35
+ # image = f"https://modelcofferbucket.s3.us-east-2.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
36
+ image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
37
+ st.image(image, use_column_width=True)
38
+
39
+ # handel checkbox information
40
+ prompt_id = items.iloc[idx + j]['prompt_id']
41
+ modelVersion_id = items.iloc[idx + j]['modelVersion_id']
42
+
43
+ check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
44
+
45
+ st.write("Position: ", idx + j)
46
+
47
+ # show checkbox
48
+ st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
49
+
50
+ # show selected info
51
+ for key in info:
52
+ st.write(f"**{key}**: {items.iloc[idx + j][key]}")
53
+
54
+ def selection_panel(self, items):
55
+ # temperal function
56
+
57
+ selecters = st.columns([1, 4])
58
+
59
+ if 'score_weights' not in st.session_state:
60
+ st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
61
+
62
+ # select sort type
63
+ with selecters[0]:
64
+ sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
65
+ if sort_type == 'Scores':
66
+ sort_by = 'weighted_score_sum'
67
+
68
+ # select other options
69
+ with selecters[1]:
70
+ if sort_type == 'IDs and Names':
71
+ sub_selecters = st.columns([3, 1])
72
+ # select sort by
73
+ with sub_selecters[0]:
74
+ sort_by = st.selectbox('Sort by',
75
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
76
+ label_visibility='hidden')
77
+
78
+ continue_idx = 1
79
+
80
+ else:
81
+ # add custom weights
82
+ sub_selecters = st.columns([1, 1, 1, 1])
83
+
84
+ with sub_selecters[0]:
85
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
86
+ with sub_selecters[1]:
87
+ mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for m(eam) s(imilarity) q(antile) score for measuring distinctiveness')
88
+ with sub_selecters[2]:
89
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
90
+
91
+ items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
92
+ 'norm_pop'] * pop_weight, 4)
93
+
94
+ continue_idx = 3
95
+
96
+ # save latest weights
97
+ st.session_state.score_weights[0] = clip_weight
98
+ st.session_state.score_weights[1] = mcos_weight
99
+ st.session_state.score_weights[2] = pop_weight
100
+
101
+ # select threshold
102
+ with sub_selecters[continue_idx]:
103
+ nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
104
+ items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
105
+
106
+ # save latest threshold
107
+ st.session_state.score_weights[3] = nsfw_threshold
108
+
109
+ # draw a distribution histogram
110
+ if sort_type == 'Scores':
111
+ try:
112
+ with st.expander('Show score distribution histogram and select score range'):
113
+ st.write('**Score distribution histogram**')
114
+ chart_space = st.container()
115
+ # st.write('Select the range of scores to show')
116
+ hist_data = pd.DataFrame(items[sort_by])
117
+ mini = hist_data[sort_by].min().item()
118
+ mini = mini//0.1 * 0.1
119
+ maxi = hist_data[sort_by].max().item()
120
+ maxi = maxi//0.1 * 0.1 + 0.1
121
+ st.write('**Select the range of scores to show**')
122
+ r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
123
+ with chart_space:
124
+ st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
125
+ # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
126
+ # r = event_dict.get(sort_by)
127
+ if r:
128
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
129
+ # st.write(r)
130
+ except:
131
+ pass
132
+
133
+ display_options = st.columns([1, 4])
134
+
135
+ with display_options[0]:
136
+ # select order
137
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
138
+ if order == 'Ascending':
139
+ order = True
140
+ else:
141
+ order = False
142
+
143
+ with display_options[1]:
144
+
145
+ # select info to show
146
+ info = st.multiselect('Show Info',
147
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
148
+ 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
149
+ 'nsfw_score', 'norm_nsfw'],
150
+ default=sort_by)
151
+
152
+ # apply sorting to dataframe
153
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
154
+
155
+ # select number of columns
156
+ col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
157
+
158
+ return items, info, col_num
159
+
160
+ def sidebar(self):
161
+ with st.sidebar:
162
+ prompt_tags = self.promptBook['tag'].unique()
163
+ # sort tags by alphabetical order
164
+ prompt_tags = np.sort(prompt_tags)[::-1]
165
+
166
+ tag = st.selectbox('Select a tag', prompt_tags)
167
+
168
+ items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
169
+
170
+ prompts = np.sort(items['prompt'].unique())[::-1]
171
+
172
+ selected_prompt = st.selectbox('Select prompt', prompts)
173
+
174
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
175
+ prompt_id = items['prompt_id'].unique()[0]
176
+ note = items['note'].unique()[0]
177
+
178
+ # show source
179
+ if isinstance(note, str):
180
+ if note.isdigit():
181
+ st.caption(f"`Source: civitai`")
182
+ else:
183
+ st.caption(f"`Source: {note}`")
184
+ else:
185
+ st.caption("`Source: Parti-prompts`")
186
+
187
+ # show image metadata
188
+ image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
189
+ for key in image_metadatas:
190
+ label = ' '.join(key.split('_')).capitalize()
191
+ st.write(f"**{label}**")
192
+ if items[key][0] == ' ':
193
+ st.write('`None`')
194
+ else:
195
+ st.caption(f"{items[key][0]}")
196
+
197
+ # for note as civitai image id, add civitai reference
198
+ if isinstance(note, str) and note.isdigit():
199
+ try:
200
+ st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
201
+ res = requests.get(f'https://civitai.com/images/{note}')
202
+ # st.write(res.text)
203
+ soup = BeautifulSoup(res.text, 'html.parser')
204
+ image_section = soup.find('div', {'class': 'mantine-12rlksp'})
205
+ image_url = image_section.find('img')['src']
206
+ st.image(image_url, use_column_width=True)
207
+ except:
208
+ pass
209
+
210
+ return prompt_tags, tag, prompt_id, items
211
+
212
+ def app(self):
213
+ st.title('Model Visualization and Retrieval')
214
+ st.write('This is a gallery of images generated by the models')
215
+
216
+ prompt_tags, tag, prompt_id, items = self.sidebar()
217
+
218
+ # add safety check for some prompts
219
+ safety_check = True
220
+ unsafe_prompts = {}
221
+ # initialize unsafe prompts
222
+ for prompt_tag in prompt_tags:
223
+ unsafe_prompts[prompt_tag] = []
224
+ # manually add unsafe prompts
225
+ unsafe_prompts['world knowledge'] = [83]
226
+ # unsafe_prompts['art'] = [23]
227
+ unsafe_prompts['abstract'] = [1, 3]
228
+ # unsafe_prompts['food'] = [34]
229
+
230
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
231
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
232
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
233
+
234
+ if safety_check:
235
+ items, info, col_num = self.selection_panel(items)
236
+
237
+ if 'selected_dict' in st.session_state:
238
+ st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
239
+ dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
240
+ dynamic_weight_panel = st.columns(len(dynamic_weight_options))
241
+
242
+ if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
243
+ btn_disable = False
244
+ else:
245
+ btn_disable = True
246
+
247
+ for i in range(len(dynamic_weight_options)):
248
+ method = dynamic_weight_options[i]
249
+ with dynamic_weight_panel[i]:
250
+ btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
251
+
252
+ with st.form(key=f'{prompt_id}'):
253
+ # buttons = st.columns([1, 1, 1])
254
+ buttons_space = st.columns([1, 1, 1, 1])
255
+ gallery_space = st.empty()
256
+
257
+ with buttons_space[0]:
258
+ continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
259
+ if continue_btn:
260
+ self.submit_actions('Continue', prompt_id)
261
+
262
+ with buttons_space[1]:
263
+ select_btn = st.form_submit_button('Select All', use_container_width=True)
264
+ if select_btn:
265
+ self.submit_actions('Select', prompt_id)
266
+
267
+ with buttons_space[2]:
268
+ deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
269
+ if deselect_btn:
270
+ self.submit_actions('Deselect', prompt_id)
271
+
272
+ with buttons_space[3]:
273
+ refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
274
+
275
+ with gallery_space.container():
276
+ with st.spinner('Loading images...'):
277
+ self.gallery_standard(items, col_num, info)
278
+
279
+ def submit_actions(self, status, prompt_id):
280
+ if status == 'Select':
281
+ modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
282
+ st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
283
+ print(st.session_state.selected_dict, 'select')
284
+ st.experimental_rerun()
285
+ elif status == 'Deselect':
286
+ st.session_state.selected_dict[prompt_id] = []
287
+ print(st.session_state.selected_dict, 'deselect')
288
+ st.experimental_rerun()
289
+ # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
290
+ elif status == 'Continue':
291
+ st.session_state.selected_dict[prompt_id] = []
292
+ for key in st.session_state:
293
+ keys = key.split('_')
294
+ if keys[0] == 'select' and keys[1] == str(prompt_id):
295
+ if st.session_state[key]:
296
+ st.session_state.selected_dict[prompt_id].append(int(keys[2]))
297
+ # switch_page("ranking")
298
+ print(st.session_state.selected_dict, 'continue')
299
+ st.experimental_rerun()
300
+
301
+ def dynamic_weight(self, prompt_id, items, method='Grid Search'):
302
+ selected = items[
303
+ items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
304
+ optimal_weight = [0, 0, 0]
305
+
306
+ if method == 'Grid Search':
307
+ # grid search method
308
+ top_ranking = len(items) * len(selected)
309
+
310
+ for clip_weight in np.arange(-1, 1, 0.1):
311
+ for mcos_weight in np.arange(-1, 1, 0.1):
312
+ for pop_weight in np.arange(-1, 1, 0.1):
313
+
314
+ weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
315
+ weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
316
+ # print('weight_all_sorted:', weight_all_sorted)
317
+ weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
318
+
319
+ # get the index of values of weight_selected in weight_all_sorted
320
+ rankings = []
321
+ for weight in weight_selected:
322
+ rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
323
+ if sum(rankings) <= top_ranking:
324
+ top_ranking = sum(rankings)
325
+ print('current top ranking:', top_ranking, rankings)
326
+ optimal_weight = [clip_weight, mcos_weight, pop_weight]
327
+ print('optimal weight:', optimal_weight)
328
+
329
+ elif method == 'SVM':
330
+ # svm method
331
+ print('start svm method')
332
+ # get residual dataframe that contains models not selected
333
+ residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
334
+ residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
335
+ residual = residual.to_numpy()
336
+ selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
337
+ selected = selected.to_numpy()
338
+
339
+ y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
340
+ X = np.concatenate((selected, residual), axis=0)
341
+
342
+ # fit svm model, and get parameters for the hyperplane
343
+ clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
344
+ clf.fit(X, y)
345
+ optimal_weight = clf.coef_[0].tolist()
346
+ print('optimal weight:', optimal_weight)
347
+ pass
348
+
349
+ elif method == 'Greedy':
350
+ for idx in selected.index:
351
+ # find which score is the highest, clip, mcos, or pop
352
+ clip_score = selected.loc[idx, 'norm_clip_crop']
353
+ mcos_score = selected.loc[idx, 'norm_mcos_crop']
354
+ pop_score = selected.loc[idx, 'norm_pop']
355
+ if clip_score >= mcos_score and clip_score >= pop_score:
356
+ optimal_weight[0] += 1
357
+ elif mcos_score >= clip_score and mcos_score >= pop_score:
358
+ optimal_weight[1] += 1
359
+ elif pop_score >= clip_score and pop_score >= mcos_score:
360
+ optimal_weight[2] += 1
361
+
362
+ # normalize optimal_weight
363
+ optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
364
+ print('optimal weight:', optimal_weight)
365
+
366
+ st.session_state.score_weights[0: 3] = optimal_weight
367
+
368
+
369
+ # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
370
+ @st.cache_resource
371
+ def altair_histogram(hist_data, sort_by, mini, maxi):
372
+ brushed = alt.selection_interval(encodings=['x'], name="brushed")
373
+
374
+ chart = (
375
+ alt.Chart(hist_data)
376
+ .mark_bar(opacity=0.7, cornerRadius=2)
377
+ .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
378
+ # .add_selection(brushed)
379
+ # .properties(width=800, height=300)
380
+ )
381
+
382
+ # Create a transparent rectangle for highlighting the range
383
+ highlight = (
384
+ alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
385
+ .mark_rect(opacity=0.3)
386
+ .encode(x='x1', x2='x2')
387
+ # .properties(width=800, height=300)
388
+ )
389
+
390
+ # Layer the chart and the highlight rectangle
391
+ layered_chart = alt.layer(chart, highlight)
392
+
393
+ return layered_chart
394
+
395
+
396
+ @st.cache_data
397
+ def load_hf_dataset():
398
+ # login to huggingface
399
+ login(token=os.environ.get("HF_TOKEN"))
400
+
401
+ # load from huggingface
402
+ roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
403
+ promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
404
+ # images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
405
+ images_ds = None # set to None for now since we use s3 bucket to store images
406
+
407
+ # process dataset
408
+ roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
409
+ 'model_download_count']].drop_duplicates().reset_index(drop=True)
410
+
411
+ # add 'custom_score_weights' column to promptBook if not exist
412
+ if 'weighted_score_sum' not in promptBook.columns:
413
+ promptBook.loc[:, 'weighted_score_sum'] = 0
414
+
415
+ # merge roster and promptbook
416
+ promptBook = promptBook.merge(roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']],
417
+ on=['model_id', 'modelVersion_id'], how='left')
418
+
419
+ # add column to record current row index
420
+ promptBook.loc[:, 'row_idx'] = promptBook.index
421
+
422
+ return roster, promptBook, images_ds
423
+
424
+
425
+ if __name__ == "__main__":
426
+ st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
427
+
428
+ # remove ranking in the session state if it is created in Ranking.py
429
+ st.session_state.pop('ranking', None)
430
+
431
+ if 'user_id' not in st.session_state:
432
+ st.warning('Please log in first.')
433
+ home_btn = st.button('Go to Home Page')
434
+ if home_btn:
435
+ switch_page("home")
436
+ else:
437
+ st.write('You have already logged in as ' + st.session_state.user_id[0])
438
+ roster, promptBook, images_ds = load_hf_dataset()
439
+ # print(promptBook.columns)
440
+
441
+ # initialize selected_dict
442
+ if 'selected_dict' not in st.session_state:
443
+ st.session_state['selected_dict'] = {}
444
+
445
+ app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
446
+ app.app()
pages/Gallery.py CHANGED
@@ -14,356 +14,350 @@ from sklearn.svm import LinearSVC
14
 
15
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- class GalleryApp:
19
- def __init__(self, promptBook, images_ds):
20
- self.promptBook = promptBook
21
- self.images_ds = images_ds
22
-
23
- def gallery_standard(self, items, col_num, info):
24
- rows = len(items) // col_num + 1
25
- containers = [st.container() for _ in range(rows)]
26
- for idx in range(0, len(items), col_num):
27
- row_idx = idx // col_num
28
- with containers[row_idx]:
29
- cols = st.columns(col_num)
30
- for j in range(col_num):
31
- if idx + j < len(items):
32
- with cols[j]:
33
- # show image
34
- # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
35
- # image = f"https://modelcofferbucket.s3.us-east-2.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
36
- image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
37
- st.image(image, use_column_width=True)
38
-
39
- # handel checkbox information
40
- prompt_id = items.iloc[idx + j]['prompt_id']
41
- modelVersion_id = items.iloc[idx + j]['modelVersion_id']
42
-
43
- check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
44
-
45
- st.write("Position: ", idx + j)
46
-
47
- # show checkbox
48
- st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
49
-
50
- # show selected info
51
- for key in info:
52
- st.write(f"**{key}**: {items.iloc[idx + j][key]}")
53
-
54
- def selection_panel(self, items):
55
- # temperal function
56
-
57
- selecters = st.columns([1, 4])
58
-
59
- if 'score_weights' not in st.session_state:
60
- st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
61
-
62
- # select sort type
63
- with selecters[0]:
64
- sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
65
- if sort_type == 'Scores':
66
- sort_by = 'weighted_score_sum'
67
-
68
- # select other options
69
- with selecters[1]:
70
- if sort_type == 'IDs and Names':
71
- sub_selecters = st.columns([3, 1])
72
- # select sort by
73
- with sub_selecters[0]:
74
- sort_by = st.selectbox('Sort by',
75
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
76
- label_visibility='hidden')
77
-
78
- continue_idx = 1
79
 
80
- else:
81
- # add custom weights
82
- sub_selecters = st.columns([1, 1, 1, 1])
 
 
 
 
83
 
84
- with sub_selecters[0]:
85
- clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
86
- with sub_selecters[1]:
87
- mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for m(eam) s(imilarity) q(antile) score for measuring distinctiveness')
88
- with sub_selecters[2]:
89
- pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
90
 
91
- items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
92
- 'norm_pop'] * pop_weight, 4)
 
 
 
 
93
 
94
- continue_idx = 3
 
95
 
96
- # save latest weights
97
- st.session_state.score_weights[0] = clip_weight
98
- st.session_state.score_weights[1] = mcos_weight
99
- st.session_state.score_weights[2] = pop_weight
100
 
101
- # select threshold
102
- with sub_selecters[continue_idx]:
103
- nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
104
- items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
105
 
106
- # save latest threshold
107
- st.session_state.score_weights[3] = nsfw_threshold
 
 
 
108
 
109
- # draw a distribution histogram
110
- if sort_type == 'Scores':
111
- try:
112
- with st.expander('Show score distribution histogram and select score range'):
113
- st.write('**Score distribution histogram**')
114
- chart_space = st.container()
115
- # st.write('Select the range of scores to show')
116
- hist_data = pd.DataFrame(items[sort_by])
117
- mini = hist_data[sort_by].min().item()
118
- mini = mini//0.1 * 0.1
119
- maxi = hist_data[sort_by].max().item()
120
- maxi = maxi//0.1 * 0.1 + 0.1
121
- st.write('**Select the range of scores to show**')
122
- r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
123
- with chart_space:
124
- st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
125
- # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
126
- # r = event_dict.get(sort_by)
127
- if r:
128
- items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
129
- # st.write(r)
130
- except:
131
- pass
132
 
133
- display_options = st.columns([1, 4])
134
 
135
- with display_options[0]:
136
- # select order
137
- order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
138
- if order == 'Ascending':
139
- order = True
140
- else:
141
- order = False
142
 
143
- with display_options[1]:
144
 
145
- # select info to show
146
- info = st.multiselect('Show Info',
147
- ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
148
- 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
149
- 'nsfw_score', 'norm_nsfw'],
150
- default=sort_by)
151
 
152
- # apply sorting to dataframe
153
- items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- # select number of columns
156
- col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
 
 
 
 
 
 
 
 
 
 
157
 
158
- return items, info, col_num
159
 
160
- def sidebar(self):
161
- with st.sidebar:
162
- prompt_tags = self.promptBook['tag'].unique()
163
- # sort tags by alphabetical order
164
- prompt_tags = np.sort(prompt_tags)[::-1]
165
 
166
- tag = st.selectbox('Select a tag', prompt_tags)
167
 
168
- items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
 
 
 
 
 
 
 
 
 
 
169
 
170
- prompts = np.sort(items['prompt'].unique())[::-1]
 
 
171
 
172
- selected_prompt = st.selectbox('Select prompt', prompts)
 
173
 
174
- items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
175
- prompt_id = items['prompt_id'].unique()[0]
176
- note = items['note'].unique()[0]
 
177
 
178
- # show source
179
- if isinstance(note, str):
180
- if note.isdigit():
181
- st.caption(f"`Source: civitai`")
182
- else:
183
- st.caption(f"`Source: {note}`")
184
  else:
185
- st.caption("`Source: Parti-prompts`")
186
-
187
- # show image metadata
188
- image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
189
- for key in image_metadatas:
190
- label = ' '.join(key.split('_')).capitalize()
191
- st.write(f"**{label}**")
192
- if items[key][0] == ' ':
193
- st.write('`None`')
194
- else:
195
- st.caption(f"{items[key][0]}")
196
-
197
- # for note as civitai image id, add civitai reference
198
- if isinstance(note, str) and note.isdigit():
199
- try:
200
- st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
201
- res = requests.get(f'https://civitai.com/images/{note}')
202
- # st.write(res.text)
203
- soup = BeautifulSoup(res.text, 'html.parser')
204
- image_section = soup.find('div', {'class': 'mantine-12rlksp'})
205
- image_url = image_section.find('img')['src']
206
- st.image(image_url, use_column_width=True)
207
- except:
208
- pass
209
-
210
- return prompt_tags, tag, prompt_id, items
211
-
212
- def app(self):
213
- st.title('Model Visualization and Retrieval')
214
- st.write('This is a gallery of images generated by the models')
215
-
216
- prompt_tags, tag, prompt_id, items = self.sidebar()
217
-
218
- # add safety check for some prompts
219
- safety_check = True
220
- unsafe_prompts = {}
221
- # initialize unsafe prompts
222
- for prompt_tag in prompt_tags:
223
- unsafe_prompts[prompt_tag] = []
224
- # manually add unsafe prompts
225
- unsafe_prompts['world knowledge'] = [83]
226
- # unsafe_prompts['art'] = [23]
227
- unsafe_prompts['abstract'] = [1, 3]
228
- # unsafe_prompts['food'] = [34]
229
-
230
- if int(prompt_id.item()) in unsafe_prompts[tag]:
231
- st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
232
- safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
233
-
234
- if safety_check:
235
- items, info, col_num = self.selection_panel(items)
236
-
237
- if 'selected_dict' in st.session_state:
238
- st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
239
- dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
240
- dynamic_weight_panel = st.columns(len(dynamic_weight_options))
241
-
242
- if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
243
- btn_disable = False
244
- else:
245
- btn_disable = True
246
-
247
- for i in range(len(dynamic_weight_options)):
248
- method = dynamic_weight_options[i]
249
- with dynamic_weight_panel[i]:
250
- btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
251
-
252
- with st.form(key=f'{prompt_id}'):
253
- # buttons = st.columns([1, 1, 1])
254
- buttons_space = st.columns([1, 1, 1, 1])
255
- gallery_space = st.empty()
256
-
257
- with buttons_space[0]:
258
- continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
259
- if continue_btn:
260
- self.submit_actions('Continue', prompt_id)
261
-
262
- with buttons_space[1]:
263
- select_btn = st.form_submit_button('Select All', use_container_width=True)
264
- if select_btn:
265
- self.submit_actions('Select', prompt_id)
266
-
267
- with buttons_space[2]:
268
- deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
269
- if deselect_btn:
270
- self.submit_actions('Deselect', prompt_id)
271
-
272
- with buttons_space[3]:
273
- refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
274
-
275
- with gallery_space.container():
276
- with st.spinner('Loading images...'):
277
- self.gallery_standard(items, col_num, info)
278
-
279
- def submit_actions(self, status, prompt_id):
280
- if status == 'Select':
281
- modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
282
- st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
283
- print(st.session_state.selected_dict, 'select')
284
- st.experimental_rerun()
285
- elif status == 'Deselect':
286
- st.session_state.selected_dict[prompt_id] = []
287
- print(st.session_state.selected_dict, 'deselect')
288
- st.experimental_rerun()
289
- # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
290
- elif status == 'Continue':
291
- st.session_state.selected_dict[prompt_id] = []
292
- for key in st.session_state:
293
- keys = key.split('_')
294
- if keys[0] == 'select' and keys[1] == str(prompt_id):
295
- if st.session_state[key]:
296
- st.session_state.selected_dict[prompt_id].append(int(keys[2]))
297
- # switch_page("ranking")
298
- print(st.session_state.selected_dict, 'continue')
299
- st.experimental_rerun()
300
-
301
- def dynamic_weight(self, prompt_id, items, method='Grid Search'):
302
- selected = items[
303
- items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
304
- optimal_weight = [0, 0, 0]
305
-
306
- if method == 'Grid Search':
307
- # grid search method
308
- top_ranking = len(items) * len(selected)
309
-
310
- for clip_weight in np.arange(-1, 1, 0.1):
311
- for mcos_weight in np.arange(-1, 1, 0.1):
312
- for pop_weight in np.arange(-1, 1, 0.1):
313
-
314
- weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
315
- weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
316
- # print('weight_all_sorted:', weight_all_sorted)
317
- weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
318
-
319
- # get the index of values of weight_selected in weight_all_sorted
320
- rankings = []
321
- for weight in weight_selected:
322
- rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
323
- if sum(rankings) <= top_ranking:
324
- top_ranking = sum(rankings)
325
- print('current top ranking:', top_ranking, rankings)
326
- optimal_weight = [clip_weight, mcos_weight, pop_weight]
327
- print('optimal weight:', optimal_weight)
328
-
329
- elif method == 'SVM':
330
- # svm method
331
- print('start svm method')
332
- # get residual dataframe that contains models not selected
333
- residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
334
- residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
335
- residual = residual.to_numpy()
336
- selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
337
- selected = selected.to_numpy()
338
-
339
- y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
340
- X = np.concatenate((selected, residual), axis=0)
341
-
342
- # fit svm model, and get parameters for the hyperplane
343
- clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
344
- clf.fit(X, y)
345
- optimal_weight = clf.coef_[0].tolist()
346
- print('optimal weight:', optimal_weight)
347
- pass
348
-
349
- elif method == 'Greedy':
350
- for idx in selected.index:
351
- # find which score is the highest, clip, mcos, or pop
352
- clip_score = selected.loc[idx, 'norm_clip_crop']
353
- mcos_score = selected.loc[idx, 'norm_mcos_crop']
354
- pop_score = selected.loc[idx, 'norm_pop']
355
- if clip_score >= mcos_score and clip_score >= pop_score:
356
- optimal_weight[0] += 1
357
- elif mcos_score >= clip_score and mcos_score >= pop_score:
358
- optimal_weight[1] += 1
359
- elif pop_score >= clip_score and pop_score >= mcos_score:
360
- optimal_weight[2] += 1
361
-
362
- # normalize optimal_weight
363
- optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
364
- print('optimal weight:', optimal_weight)
365
-
366
- st.session_state.score_weights[0: 3] = optimal_weight
367
 
368
 
369
  # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
@@ -422,27 +416,24 @@ def load_hf_dataset():
422
  return roster, promptBook, images_ds
423
 
424
 
425
- # if __name__ == "__main__":
426
-
427
- # start the app
428
- st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
429
 
430
- # remove ranking in the session state if it is created in Ranking.py
431
- st.session_state.pop('ranking', None)
432
 
433
- if 'user_id' not in st.session_state:
434
- st.warning('Please log in first.')
435
- home_btn = st.button('Go to Home Page')
436
- if home_btn:
437
- switch_page("home")
438
- else:
439
- st.write('You have already logged in as ' + st.session_state.user_id[0])
440
- roster, promptBook, images_ds = load_hf_dataset()
441
- # print(promptBook.columns)
442
 
443
- # initialize selected_dict
444
- if 'selected_dict' not in st.session_state:
445
- st.session_state['selected_dict'] = {}
446
 
447
- app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
448
- app.app()
 
14
 
15
  SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
16
 
17
+ def gallery_standard(items, col_num, info):
18
+ rows = len(items) // col_num + 1
19
+ containers = [st.container() for _ in range(rows)]
20
+ for idx in range(0, len(items), col_num):
21
+ row_idx = idx // col_num
22
+ with containers[row_idx]:
23
+ cols = st.columns(col_num)
24
+ for j in range(col_num):
25
+ if idx + j < len(items):
26
+ with cols[j]:
27
+ # show image
28
+ # image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
29
+ # image = f"https://modelcofferbucket.s3.us-east-2.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
30
+ image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.iloc[idx + j]['image_id']}.png"
31
+ st.image(image, use_column_width=True)
32
+
33
+ # handel checkbox information
34
+ prompt_id = items.iloc[idx + j]['prompt_id']
35
+ modelVersion_id = items.iloc[idx + j]['modelVersion_id']
36
+
37
+ check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
38
+
39
+ st.write("Position: ", idx + j)
40
+
41
+ # show checkbox
42
+ st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
43
+
44
+ # show selected info
45
+ for key in info:
46
+ st.write(f"**{key}**: {items.iloc[idx + j][key]}")
47
+
48
+ def selection_panel(items):
49
+ # temperal function
50
+
51
+ selecters = st.columns([1, 4])
52
+
53
+ if 'score_weights' not in st.session_state:
54
+ st.session_state.score_weights = [1.0, 0.8, 0.2, 0.8]
55
+
56
+ # select sort type
57
+ with selecters[0]:
58
+ sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
59
+ if sort_type == 'Scores':
60
+ sort_by = 'weighted_score_sum'
61
+
62
+ # select other options
63
+ with selecters[1]:
64
+ if sort_type == 'IDs and Names':
65
+ sub_selecters = st.columns([3, 1])
66
+ # select sort by
67
+ with sub_selecters[0]:
68
+ sort_by = st.selectbox('Sort by',
69
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id', 'norm_nsfw'],
70
+ label_visibility='hidden')
71
+
72
+ continue_idx = 1
73
+
74
+ else:
75
+ # add custom weights
76
+ sub_selecters = st.columns([1, 1, 1, 1])
77
+
78
+ with sub_selecters[0]:
79
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
80
+ with sub_selecters[1]:
81
+ mcos_weight = st.number_input('Dissimilarity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for m(eam) s(imilarity) q(antile) score for measuring distinctiveness')
82
+ with sub_selecters[2]:
83
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
84
+
85
+ items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
86
+ 'norm_pop'] * pop_weight, 4)
87
+
88
+ continue_idx = 3
89
+
90
+ # save latest weights
91
+ st.session_state.score_weights[0] = clip_weight
92
+ st.session_state.score_weights[1] = mcos_weight
93
+ st.session_state.score_weights[2] = pop_weight
94
+
95
+ # select threshold
96
+ with sub_selecters[continue_idx]:
97
+ nsfw_threshold = st.number_input('NSFW Score Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with nsfw score lower than this threshold, set 1.0 to show all images')
98
+ items = items[items['norm_nsfw'] <= nsfw_threshold].reset_index(drop=True)
99
+
100
+ # save latest threshold
101
+ st.session_state.score_weights[3] = nsfw_threshold
102
+
103
+ # draw a distribution histogram
104
+ if sort_type == 'Scores':
105
+ try:
106
+ with st.expander('Show score distribution histogram and select score range'):
107
+ st.write('**Score distribution histogram**')
108
+ chart_space = st.container()
109
+ # st.write('Select the range of scores to show')
110
+ hist_data = pd.DataFrame(items[sort_by])
111
+ mini = hist_data[sort_by].min().item()
112
+ mini = mini//0.1 * 0.1
113
+ maxi = hist_data[sort_by].max().item()
114
+ maxi = maxi//0.1 * 0.1 + 0.1
115
+ st.write('**Select the range of scores to show**')
116
+ r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
117
+ with chart_space:
118
+ st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
119
+ # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
120
+ # r = event_dict.get(sort_by)
121
+ if r:
122
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
123
+ # st.write(r)
124
+ except:
125
+ pass
126
 
127
+ display_options = st.columns([1, 4])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ with display_options[0]:
130
+ # select order
131
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
132
+ if order == 'Ascending':
133
+ order = True
134
+ else:
135
+ order = False
136
 
137
+ with display_options[1]:
 
 
 
 
 
138
 
139
+ # select info to show
140
+ info = st.multiselect('Show Info',
141
+ ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
142
+ 'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
143
+ 'nsfw_score', 'norm_nsfw'],
144
+ default=sort_by)
145
 
146
+ # apply sorting to dataframe
147
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
148
 
149
+ # select number of columns
150
+ col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
 
 
151
 
152
+ return items, info, col_num
 
 
 
153
 
154
+ def sidebar(promptbook, images_db):
155
+ with st.sidebar:
156
+ prompt_tags = promptBook['tag'].unique()
157
+ # sort tags by alphabetical order
158
+ prompt_tags = np.sort(prompt_tags)[::-1]
159
 
160
+ tag = st.selectbox('Select a tag', prompt_tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ items = promptBook[promptBook['tag'] == tag].reset_index(drop=True)
163
 
164
+ prompts = np.sort(items['prompt'].unique())[::-1]
 
 
 
 
 
 
165
 
166
+ selected_prompt = st.selectbox('Select prompt', prompts)
167
 
168
+ items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
169
+ prompt_id = items['prompt_id'].unique()[0]
170
+ note = items['note'].unique()[0]
 
 
 
171
 
172
+ # show source
173
+ if isinstance(note, str):
174
+ if note.isdigit():
175
+ st.caption(f"`Source: civitai`")
176
+ else:
177
+ st.caption(f"`Source: {note}`")
178
+ else:
179
+ st.caption("`Source: Parti-prompts`")
180
+
181
+ # show image metadata
182
+ image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
183
+ for key in image_metadatas:
184
+ label = ' '.join(key.split('_')).capitalize()
185
+ st.write(f"**{label}**")
186
+ if items[key][0] == ' ':
187
+ st.write('`None`')
188
+ else:
189
+ st.caption(f"{items[key][0]}")
190
 
191
+ # for note as civitai image id, add civitai reference
192
+ if isinstance(note, str) and note.isdigit():
193
+ try:
194
+ st.write(f'**[Civitai Reference](https://civitai.com/images/{note})**')
195
+ res = requests.get(f'https://civitai.com/images/{note}')
196
+ # st.write(res.text)
197
+ soup = BeautifulSoup(res.text, 'html.parser')
198
+ image_section = soup.find('div', {'class': 'mantine-12rlksp'})
199
+ image_url = image_section.find('img')['src']
200
+ st.image(image_url, use_column_width=True)
201
+ except:
202
+ pass
203
 
204
+ return prompt_tags, tag, prompt_id, items
205
 
206
+ def app(promptbook, images_db):
207
+ st.title('Model Visualization and Retrieval')
208
+ st.write('This is a gallery of images generated by the models')
 
 
209
 
210
+ prompt_tags, tag, prompt_id, items = sidebar(promptbook, images_db)
211
 
212
+ # add safety check for some prompts
213
+ safety_check = True
214
+ unsafe_prompts = {}
215
+ # initialize unsafe prompts
216
+ for prompt_tag in prompt_tags:
217
+ unsafe_prompts[prompt_tag] = []
218
+ # manually add unsafe prompts
219
+ unsafe_prompts['world knowledge'] = [83]
220
+ # unsafe_prompts['art'] = [23]
221
+ unsafe_prompts['abstract'] = [1, 3]
222
+ # unsafe_prompts['food'] = [34]
223
 
224
+ if int(prompt_id.item()) in unsafe_prompts[tag]:
225
+ st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
226
+ safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
227
 
228
+ if safety_check:
229
+ items, info, col_num = selection_panel(items)
230
 
231
+ if 'selected_dict' in st.session_state:
232
+ st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
233
+ dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
234
+ dynamic_weight_panel = st.columns(len(dynamic_weight_options))
235
 
236
+ if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
237
+ btn_disable = False
 
 
 
 
238
  else:
239
+ btn_disable = True
240
+
241
+ for i in range(len(dynamic_weight_options)):
242
+ method = dynamic_weight_options[i]
243
+ with dynamic_weight_panel[i]:
244
+ btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=dynamic_weight, args=(prompt_id, items, method))
245
+
246
+ with st.form(key=f'{prompt_id}'):
247
+ # buttons = st.columns([1, 1, 1])
248
+ buttons_space = st.columns([1, 1, 1, 1])
249
+ gallery_space = st.empty()
250
+
251
+ with buttons_space[0]:
252
+ continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
253
+ if continue_btn:
254
+ submit_actions('Continue', prompt_id)
255
+
256
+ with buttons_space[1]:
257
+ select_btn = st.form_submit_button('Select All', use_container_width=True)
258
+ if select_btn:
259
+ submit_actions('Select', prompt_id)
260
+
261
+ with buttons_space[2]:
262
+ deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
263
+ if deselect_btn:
264
+ submit_actions('Deselect', prompt_id)
265
+
266
+ with buttons_space[3]:
267
+ refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
268
+
269
+ with gallery_space.container():
270
+ with st.spinner('Loading images...'):
271
+ gallery_standard(items, col_num, info)
272
+
273
+ def submit_actions(status, prompt_id):
274
+ if status == 'Select':
275
+ modelVersions = promptBook[promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
276
+ st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
277
+ print(st.session_state.selected_dict, 'select')
278
+ st.experimental_rerun()
279
+ elif status == 'Deselect':
280
+ st.session_state.selected_dict[prompt_id] = []
281
+ print(st.session_state.selected_dict, 'deselect')
282
+ st.experimental_rerun()
283
+ # self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
284
+ elif status == 'Continue':
285
+ st.session_state.selected_dict[prompt_id] = []
286
+ for key in st.session_state:
287
+ keys = key.split('_')
288
+ if keys[0] == 'select' and keys[1] == str(prompt_id):
289
+ if st.session_state[key]:
290
+ st.session_state.selected_dict[prompt_id].append(int(keys[2]))
291
+ # switch_page("ranking")
292
+ print(st.session_state.selected_dict, 'continue')
293
+ st.experimental_rerun()
294
+
295
+ def dynamic_weight(prompt_id, items, method='Grid Search'):
296
+ selected = items[
297
+ items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
298
+ optimal_weight = [0, 0, 0]
299
+
300
+ if method == 'Grid Search':
301
+ # grid search method
302
+ top_ranking = len(items) * len(selected)
303
+
304
+ for clip_weight in np.arange(-1, 1, 0.1):
305
+ for mcos_weight in np.arange(-1, 1, 0.1):
306
+ for pop_weight in np.arange(-1, 1, 0.1):
307
+
308
+ weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
309
+ weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
310
+ # print('weight_all_sorted:', weight_all_sorted)
311
+ weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
312
+
313
+ # get the index of values of weight_selected in weight_all_sorted
314
+ rankings = []
315
+ for weight in weight_selected:
316
+ rankings.append(weight_all_sorted.index[weight_all_sorted == weight].tolist()[0])
317
+ if sum(rankings) <= top_ranking:
318
+ top_ranking = sum(rankings)
319
+ print('current top ranking:', top_ranking, rankings)
320
+ optimal_weight = [clip_weight, mcos_weight, pop_weight]
321
+ print('optimal weight:', optimal_weight)
322
+
323
+ elif method == 'SVM':
324
+ # svm method
325
+ print('start svm method')
326
+ # get residual dataframe that contains models not selected
327
+ residual = items[~items['modelVersion_id'].isin(selected['modelVersion_id'])].reset_index(drop=True)
328
+ residual = residual[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
329
+ residual = residual.to_numpy()
330
+ selected = selected[['norm_clip_crop', 'norm_mcos_crop', 'norm_pop']]
331
+ selected = selected.to_numpy()
332
+
333
+ y = np.concatenate((np.full((len(selected), 1), -1), np.full((len(residual), 1), 1)), axis=0).ravel()
334
+ X = np.concatenate((selected, residual), axis=0)
335
+
336
+ # fit svm model, and get parameters for the hyperplane
337
+ clf = LinearSVC(random_state=0, C=1.0, fit_intercept=False, dual='auto')
338
+ clf.fit(X, y)
339
+ optimal_weight = clf.coef_[0].tolist()
340
+ print('optimal weight:', optimal_weight)
341
+ pass
342
+
343
+ elif method == 'Greedy':
344
+ for idx in selected.index:
345
+ # find which score is the highest, clip, mcos, or pop
346
+ clip_score = selected.loc[idx, 'norm_clip_crop']
347
+ mcos_score = selected.loc[idx, 'norm_mcos_crop']
348
+ pop_score = selected.loc[idx, 'norm_pop']
349
+ if clip_score >= mcos_score and clip_score >= pop_score:
350
+ optimal_weight[0] += 1
351
+ elif mcos_score >= clip_score and mcos_score >= pop_score:
352
+ optimal_weight[1] += 1
353
+ elif pop_score >= clip_score and pop_score >= mcos_score:
354
+ optimal_weight[2] += 1
355
+
356
+ # normalize optimal_weight
357
+ optimal_weight = [round(weight/len(selected), 2) for weight in optimal_weight]
358
+ print('optimal weight:', optimal_weight)
359
+
360
+ st.session_state.score_weights[0: 3] = optimal_weight
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
 
363
  # hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
 
416
  return roster, promptBook, images_ds
417
 
418
 
419
+ if __name__ == "__main__":
420
+ st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
 
 
421
 
422
+ # remove ranking in the session state if it is created in Ranking.py
423
+ st.session_state.pop('ranking', None)
424
 
425
+ if 'user_id' not in st.session_state:
426
+ st.warning('Please log in first.')
427
+ home_btn = st.button('Go to Home Page')
428
+ if home_btn:
429
+ switch_page("home")
430
+ else:
431
+ st.write('You have already logged in as ' + st.session_state.user_id[0])
432
+ roster, promptBook, images_ds = load_hf_dataset()
433
+ # print(promptBook.columns)
434
 
435
+ # initialize selected_dict
436
+ if 'selected_dict' not in st.session_state:
437
+ st.session_state['selected_dict'] = {}
438
 
439
+ app(promptBook, images_ds)