Spaces:
Running
Running
ranking is not finished!!!
Browse files- pages/Gallery.py +12 -22
- pages/Ranking.py +101 -9
- pages/__pycache__/Gallery.cpython-39.pyc +0 -0
pages/Gallery.py
CHANGED
@@ -45,19 +45,7 @@ class GalleryApp:
|
|
45 |
st.write("Position: ", idx + j)
|
46 |
|
47 |
# show checkbox
|
48 |
-
|
49 |
-
|
50 |
-
#
|
51 |
-
# if checked:
|
52 |
-
# if prompt_id not in st.session_state.selected_dict:
|
53 |
-
# st.session_state.selected_dict[prompt_id] = []
|
54 |
-
# if modelVersion_id not in st.session_state.selected_dict[prompt_id]:
|
55 |
-
# st.session_state.selected_dict[prompt_id].append(modelVersion_id)
|
56 |
-
# else:
|
57 |
-
# try:
|
58 |
-
# st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
|
59 |
-
# except:
|
60 |
-
# pass
|
61 |
|
62 |
# show selected info
|
63 |
for key in info:
|
@@ -65,7 +53,6 @@ class GalleryApp:
|
|
65 |
|
66 |
def selection_panel(self, items):
|
67 |
# temperal function
|
68 |
-
preprocessor = st.radio('Preprocess Method', ['crop', 'resize'], horizontal=True)
|
69 |
|
70 |
selecters = st.columns([1, 4])
|
71 |
|
@@ -101,7 +88,7 @@ class GalleryApp:
|
|
101 |
with sub_selecters[2]:
|
102 |
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
|
103 |
|
104 |
-
items.loc[:, 'weighted_score_sum'] = round(items[f'
|
105 |
'norm_pop'] * pop_weight, 4)
|
106 |
|
107 |
continue_idx = 3
|
@@ -168,7 +155,7 @@ class GalleryApp:
|
|
168 |
# select number of columns
|
169 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
170 |
|
171 |
-
return items, info, col_num
|
172 |
|
173 |
def sidebar(self):
|
174 |
with st.sidebar:
|
@@ -226,7 +213,7 @@ class GalleryApp:
|
|
226 |
st.title('Model Visualization and Retrieval')
|
227 |
st.write('This is a gallery of images generated by the models')
|
228 |
|
229 |
-
prompt_tags, tag, prompt_id, items= self.sidebar()
|
230 |
|
231 |
# add safety check for some prompts
|
232 |
safety_check = True
|
@@ -245,7 +232,7 @@ class GalleryApp:
|
|
245 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
|
246 |
|
247 |
if safety_check:
|
248 |
-
items, info, col_num
|
249 |
|
250 |
if 'selected_dict' in st.session_state:
|
251 |
st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
|
@@ -260,7 +247,7 @@ class GalleryApp:
|
|
260 |
for i in range(len(dynamic_weight_options)):
|
261 |
method = dynamic_weight_options[i]
|
262 |
with dynamic_weight_panel[i]:
|
263 |
-
btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items,
|
264 |
|
265 |
with st.form(key=f'{prompt_id}'):
|
266 |
# buttons = st.columns([1, 1, 1])
|
@@ -311,7 +298,7 @@ class GalleryApp:
|
|
311 |
print(st.session_state.selected_dict, 'continue')
|
312 |
st.experimental_rerun()
|
313 |
|
314 |
-
def dynamic_weight(self, prompt_id, items,
|
315 |
selected = items[
|
316 |
items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
|
317 |
optimal_weight = [0, 0, 0]
|
@@ -324,10 +311,10 @@ class GalleryApp:
|
|
324 |
for mcos_weight in np.arange(-1, 1, 0.1):
|
325 |
for pop_weight in np.arange(-1, 1, 0.1):
|
326 |
|
327 |
-
weight_all = clip_weight*items[f'
|
328 |
weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
|
329 |
# print('weight_all_sorted:', weight_all_sorted)
|
330 |
-
weight_selected = clip_weight*selected[f'
|
331 |
|
332 |
# get the index of values of weight_selected in weight_all_sorted
|
333 |
rankings = []
|
@@ -438,6 +425,9 @@ def load_hf_dataset():
|
|
438 |
if __name__ == "__main__":
|
439 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|
440 |
|
|
|
|
|
|
|
441 |
if 'user_id' not in st.session_state:
|
442 |
st.warning('Please log in first.')
|
443 |
home_btn = st.button('Go to Home Page')
|
|
|
45 |
st.write("Position: ", idx + j)
|
46 |
|
47 |
# show checkbox
|
48 |
+
st.checkbox('Select', key=f'select_{prompt_id}_{modelVersion_id}', value=check_init)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# show selected info
|
51 |
for key in info:
|
|
|
53 |
|
54 |
def selection_panel(self, items):
|
55 |
# temperal function
|
|
|
56 |
|
57 |
selecters = st.columns([1, 4])
|
58 |
|
|
|
88 |
with sub_selecters[2]:
|
89 |
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
|
90 |
|
91 |
+
items.loc[:, 'weighted_score_sum'] = round(items[f'norm_clip'] * clip_weight + items[f'norm_mcos'] * mcos_weight + items[
|
92 |
'norm_pop'] * pop_weight, 4)
|
93 |
|
94 |
continue_idx = 3
|
|
|
155 |
# select number of columns
|
156 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
157 |
|
158 |
+
return items, info, col_num
|
159 |
|
160 |
def sidebar(self):
|
161 |
with st.sidebar:
|
|
|
213 |
st.title('Model Visualization and Retrieval')
|
214 |
st.write('This is a gallery of images generated by the models')
|
215 |
|
216 |
+
prompt_tags, tag, prompt_id, items = self.sidebar()
|
217 |
|
218 |
# add safety check for some prompts
|
219 |
safety_check = True
|
|
|
232 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
|
233 |
|
234 |
if safety_check:
|
235 |
+
items, info, col_num = self.selection_panel(items)
|
236 |
|
237 |
if 'selected_dict' in st.session_state:
|
238 |
st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
|
|
|
247 |
for i in range(len(dynamic_weight_options)):
|
248 |
method = dynamic_weight_options[i]
|
249 |
with dynamic_weight_panel[i]:
|
250 |
+
btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
|
251 |
|
252 |
with st.form(key=f'{prompt_id}'):
|
253 |
# buttons = st.columns([1, 1, 1])
|
|
|
298 |
print(st.session_state.selected_dict, 'continue')
|
299 |
st.experimental_rerun()
|
300 |
|
301 |
+
def dynamic_weight(self, prompt_id, items, method='Grid Search'):
|
302 |
selected = items[
|
303 |
items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
|
304 |
optimal_weight = [0, 0, 0]
|
|
|
311 |
for mcos_weight in np.arange(-1, 1, 0.1):
|
312 |
for pop_weight in np.arange(-1, 1, 0.1):
|
313 |
|
314 |
+
weight_all = clip_weight*items[f'norm_clip'] + mcos_weight*items[f'norm_mcos'] + pop_weight*items['norm_pop']
|
315 |
weight_all_sorted = weight_all.sort_values(ascending=False).reset_index(drop=True)
|
316 |
# print('weight_all_sorted:', weight_all_sorted)
|
317 |
+
weight_selected = clip_weight*selected[f'norm_clip'] + mcos_weight*selected[f'norm_mcos'] + pop_weight*selected['norm_pop']
|
318 |
|
319 |
# get the index of values of weight_selected in weight_all_sorted
|
320 |
rankings = []
|
|
|
425 |
if __name__ == "__main__":
|
426 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|
427 |
|
428 |
+
# remove ranking in the session state if it is created in Ranking.py
|
429 |
+
st.session_state.pop('ranking', None)
|
430 |
+
|
431 |
if 'user_id' not in st.session_state:
|
432 |
st.warning('Please log in first.')
|
433 |
home_btn = st.button('Go to Home Page')
|
pages/Ranking.py
CHANGED
@@ -8,24 +8,105 @@ from streamlit_extras.switch_page_button import switch_page
|
|
8 |
from pages.Gallery import load_hf_dataset
|
9 |
|
10 |
|
11 |
-
class RankingApp
|
12 |
-
def __init__(self, promptBook,
|
13 |
self.promptBook = promptBook
|
14 |
-
self.
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
|
19 |
def sidebar(self):
|
20 |
with st.sidebar:
|
21 |
prompt_tags = self.promptBook['tag'].unique()
|
|
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
def app(self):
|
25 |
st.title('Personal Image Ranking')
|
26 |
st.write('Here you can test out your selected images with any prompt you like.')
|
|
|
27 |
|
28 |
-
prompt_tags, tag, prompt_id, items= self.sidebar()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
if __name__ == "__main__":
|
@@ -50,11 +131,22 @@ if __name__ == "__main__":
|
|
50 |
if gallery_btn:
|
51 |
switch_page('gallery')
|
52 |
else:
|
53 |
-
st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
|
54 |
roster, promptBook, images_ds = load_hf_dataset()
|
55 |
-
st.
|
56 |
-
st.write(roster
|
|
|
57 |
# st.write(roster)
|
58 |
# st.write("## promptBook")
|
59 |
# st.write(promptBook)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from pages.Gallery import load_hf_dataset
|
9 |
|
10 |
|
11 |
+
class RankingApp:
|
12 |
+
def __init__(self, promptBook, images_endpoint, batch_size=4):
|
13 |
self.promptBook = promptBook
|
14 |
+
self.images_endpoint = images_endpoint
|
15 |
+
self.batch_size = batch_size
|
16 |
+
# self.batch_num = len(self.promptBook) // self.batch_size
|
17 |
+
# self.batch_num += 1 if len(self.promptBook) % self.batch_size != 0 else 0
|
18 |
|
19 |
+
if 'counter' not in st.session_state:
|
20 |
+
st.session_state.counter = 0
|
21 |
|
22 |
def sidebar(self):
|
23 |
with st.sidebar:
|
24 |
prompt_tags = self.promptBook['tag'].unique()
|
25 |
+
prompt_tags = np.sort(prompt_tags)
|
26 |
|
27 |
+
tag = st.selectbox('Select a prompt tag', prompt_tags)
|
28 |
+
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
29 |
+
prompts = np.sort(items['prompt'].unique())[::-1]
|
30 |
+
|
31 |
+
selected_prompt = st.selectbox('Select a prompt', prompts)
|
32 |
+
|
33 |
+
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
34 |
+
prompt_id = items['prompt_id'].unique()[0]
|
35 |
+
|
36 |
+
with st.form(key='prompt_form'):
|
37 |
+
# input image metadata
|
38 |
+
prompt = st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
|
39 |
+
negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
|
40 |
+
st.form_submit_button('Generate Images', type='primary', use_container_width=True)
|
41 |
+
|
42 |
+
return prompt_tags, tag, prompt_id, items
|
43 |
+
|
44 |
+
def draggable_images(self, items, layout='portrait'):
|
45 |
+
# init ranking by the order of items
|
46 |
+
if 'ranking' not in st.session_state:
|
47 |
+
st.session_state.ranking = {}
|
48 |
+
for i in range(len(items)):
|
49 |
+
st.session_state.ranking[str(items['image_id'][i])] = i
|
50 |
+
|
51 |
+
print(items)
|
52 |
+
with elements('dashboard'):
|
53 |
+
if layout == 'portrait':
|
54 |
+
col_num = 4
|
55 |
+
layout = [dashboard.Item(str(items['image_id'][i]), i % col_num, i//col_num, 1, 2, isResizable=False) for i in range(len(items))]
|
56 |
+
|
57 |
+
elif layout == 'landscape':
|
58 |
+
col_num = 2
|
59 |
+
layout = [
|
60 |
+
dashboard.Item(str(items['image_id'][i]), i % col_num * 2, i // col_num, 2, 1.4, isResizable=False) for
|
61 |
+
i in range(len(items))
|
62 |
+
]
|
63 |
+
|
64 |
+
with dashboard.Grid(layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
|
65 |
+
for i in range(len(layout)):
|
66 |
+
with mui.Card(key=str(items['image_id'][i]), variant="outlined"):
|
67 |
+
rank = st.session_state.ranking[str(items['image_id'][i])] + 1
|
68 |
+
|
69 |
+
mui.Chip(label=rank,
|
70 |
+
# variant="outlined" if rank!=1 else "default",
|
71 |
+
color="primary" if rank == 1 else "warning" if rank == 2 else "info",
|
72 |
+
size="small",
|
73 |
+
sx={"position": "absolute", "left": "-0.3rem", "top": "-0.3rem"})
|
74 |
+
|
75 |
+
img_url = self.images_endpoint + str(items['image_id'][i]) + '.png'
|
76 |
+
|
77 |
+
mui.CardMedia(
|
78 |
+
component="img",
|
79 |
+
# image={"data:image/png;base64", img_str},
|
80 |
+
image=img_url,
|
81 |
+
alt="There should be an image",
|
82 |
+
sx={"height": "100%", "object-fit": "fit", 'bgcolor': 'black'},
|
83 |
+
)
|
84 |
+
|
85 |
+
def handle_layout_change(self, updated_layout):
|
86 |
+
# print(updated_layout)
|
87 |
+
sorted_list = sorted(updated_layout, key=lambda x: (x['y'], x['x']))
|
88 |
+
sorted_list = [str(item['i']) for item in sorted_list]
|
89 |
+
|
90 |
+
for k in st.session_state.ranking.keys():
|
91 |
+
st.session_state.ranking[k] = sorted_list.index(k)
|
92 |
|
93 |
def app(self):
|
94 |
st.title('Personal Image Ranking')
|
95 |
st.write('Here you can test out your selected images with any prompt you like.')
|
96 |
+
# st.write(self.promptBook)
|
97 |
|
98 |
+
prompt_tags, tag, prompt_id, items = self.sidebar()
|
99 |
+
|
100 |
+
sorting, control = st.columns((11, 1), gap='large')
|
101 |
+
with sorting:
|
102 |
+
# st.write('## Sorting')
|
103 |
+
# st.write('Please drag the images to sort them.')
|
104 |
+
st.progress((st.session_state.counter + 1) / self.batch_num, text=f"Batch {st.session_state.counter + 1} / {self.batch_num}")
|
105 |
+
self.draggable_images(items.iloc[self.batch_size*st.session_state.counter: self.batch_size*(st.session_state.counter+1)], layout='portrait')
|
106 |
+
|
107 |
+
with control:
|
108 |
+
st.button(":arrow_right:")
|
109 |
+
st.button(":slightly_frowning_face:")
|
110 |
|
111 |
|
112 |
if __name__ == "__main__":
|
|
|
131 |
if gallery_btn:
|
132 |
switch_page('gallery')
|
133 |
else:
|
134 |
+
# st.write('You have checked ' + str(len(selected_modelVersions)) + ' images.')
|
135 |
roster, promptBook, images_ds = load_hf_dataset()
|
136 |
+
print(st.session_state.selected_dict)
|
137 |
+
# st.write("## roster")
|
138 |
+
# st.write(roster[roster['modelVersion_id'].isin(selected_modelVersions)])
|
139 |
# st.write(roster)
|
140 |
# st.write("## promptBook")
|
141 |
# st.write(promptBook)
|
142 |
|
143 |
+
# only select the part of the promptbook where tag is the same as st.session_state.selected_dict.keys(), while model version ids are the same as corresponding values to each key
|
144 |
+
promptBook_selected = pd.DataFrame()
|
145 |
+
for key, value in st.session_state.selected_dict.items():
|
146 |
+
promptBook_selected = promptBook_selected.append(promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))])
|
147 |
+
promptBook_selected = promptBook_selected.reset_index(drop=True)
|
148 |
+
images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
|
149 |
+
|
150 |
+
app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
|
151 |
+
app.app()
|
152 |
+
|
pages/__pycache__/Gallery.cpython-39.pyc
CHANGED
Binary files a/pages/__pycache__/Gallery.cpython-39.pyc and b/pages/__pycache__/Gallery.cpython-39.pyc differ
|
|