Spaces:
Running
Running
beta for battle ranking
Browse files- pages/Gallery.py +35 -4
- pages/Ranking.py +129 -57
pages/Gallery.py
CHANGED
@@ -327,6 +327,7 @@ class GalleryApp:
|
|
327 |
deselect = st.form_submit_button('Deselect', use_container_width=True)
|
328 |
if deselect:
|
329 |
st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
|
|
|
330 |
st.experimental_rerun()
|
331 |
|
332 |
else:
|
@@ -334,6 +335,7 @@ class GalleryApp:
|
|
334 |
select = st.form_submit_button('Select', use_container_width=True, type='primary')
|
335 |
if select:
|
336 |
st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
|
|
|
337 |
st.experimental_rerun()
|
338 |
|
339 |
# st.write(item)
|
@@ -408,7 +410,8 @@ class GalleryApp:
|
|
408 |
|
409 |
def submit_actions(self, status, prompt_id):
|
410 |
# remove counter from session state
|
411 |
-
st.session_state.pop('counter', None)
|
|
|
412 |
if status == 'Select':
|
413 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
414 |
st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
|
@@ -498,6 +501,30 @@ class GalleryApp:
|
|
498 |
st.session_state.score_weights[0: 3] = optimal_weight
|
499 |
|
500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
|
502 |
@st.cache_resource
|
503 |
def altair_histogram(hist_data, sort_by, mini, maxi):
|
@@ -551,6 +578,13 @@ def load_hf_dataset():
|
|
551 |
# add column to record current row index
|
552 |
promptBook.loc[:, 'row_idx'] = promptBook.index
|
553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
554 |
return roster, promptBook, images_ds
|
555 |
|
556 |
@st.cache_data
|
@@ -569,9 +603,6 @@ def load_tsne_coordinates(items):
|
|
569 |
if __name__ == "__main__":
|
570 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|
571 |
|
572 |
-
# remove ranking in the session state if it is created in Ranking.py
|
573 |
-
st.session_state.pop('ranking', None)
|
574 |
-
|
575 |
if 'user_id' not in st.session_state:
|
576 |
st.warning('Please log in first.')
|
577 |
home_btn = st.button('Go to Home Page')
|
|
|
327 |
deselect = st.form_submit_button('Deselect', use_container_width=True)
|
328 |
if deselect:
|
329 |
st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
|
330 |
+
self.remove_ranking_states(item['prompt_id'])
|
331 |
st.experimental_rerun()
|
332 |
|
333 |
else:
|
|
|
335 |
select = st.form_submit_button('Select', use_container_width=True, type='primary')
|
336 |
if select:
|
337 |
st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
|
338 |
+
self.remove_ranking_states(item['prompt_id'])
|
339 |
st.experimental_rerun()
|
340 |
|
341 |
# st.write(item)
|
|
|
410 |
|
411 |
def submit_actions(self, status, prompt_id):
|
412 |
# remove counter from session state
|
413 |
+
# st.session_state.pop('counter', None)
|
414 |
+
self.remove_ranking_states('prompt_id')
|
415 |
if status == 'Select':
|
416 |
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
|
417 |
st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
|
|
|
501 |
st.session_state.score_weights[0: 3] = optimal_weight
|
502 |
|
503 |
|
504 |
+
def remove_ranking_states(self, prompt_id):
|
505 |
+
# for drag sort
|
506 |
+
try:
|
507 |
+
st.session_state.counter[prompt_id] = 0
|
508 |
+
st.session_state.ranking[prompt_id] = {}
|
509 |
+
print('remove ranking states')
|
510 |
+
except:
|
511 |
+
print('no sort ranking states to remove')
|
512 |
+
|
513 |
+
# for battles
|
514 |
+
try:
|
515 |
+
st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
|
516 |
+
print('remove battles states')
|
517 |
+
except:
|
518 |
+
print('no battles states to remove')
|
519 |
+
|
520 |
+
# for page progress
|
521 |
+
try:
|
522 |
+
st.session_state.progress[prompt_id] = 'ranking'
|
523 |
+
print('reset page progress states')
|
524 |
+
except:
|
525 |
+
print('no page progress states to be reset')
|
526 |
+
|
527 |
+
|
528 |
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
|
529 |
@st.cache_resource
|
530 |
def altair_histogram(hist_data, sort_by, mini, maxi):
|
|
|
578 |
# add column to record current row index
|
579 |
promptBook.loc[:, 'row_idx'] = promptBook.index
|
580 |
|
581 |
+
# apply a nsfw filter
|
582 |
+
promptBook = promptBook[promptBook['nsfw_score'] <= 0.84].reset_index(drop=True)
|
583 |
+
|
584 |
+
# add a column that adds up 'norm_clip', 'norm_mcos', and 'norm_pop'
|
585 |
+
score_weights = [1.0, 0.8, 0.2]
|
586 |
+
promptBook.loc[:, 'total_score'] = round(promptBook['norm_clip'] * score_weights[0] + promptBook['norm_mcos'] * score_weights[1] + promptBook['norm_pop'] * score_weights[2], 4)
|
587 |
+
|
588 |
return roster, promptBook, images_ds
|
589 |
|
590 |
@st.cache_data
|
|
|
603 |
if __name__ == "__main__":
|
604 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|
605 |
|
|
|
|
|
|
|
606 |
if 'user_id' not in st.session_state:
|
607 |
st.warning('Please log in first.')
|
608 |
home_btn = st.button('Go to Home Page')
|
pages/Ranking.py
CHANGED
@@ -34,6 +34,8 @@ class RankingApp:
|
|
34 |
|
35 |
selected_prompt = st.selectbox('Select a prompt', prompts)
|
36 |
|
|
|
|
|
37 |
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
38 |
prompt_id = items['prompt_id'].unique()[0]
|
39 |
|
@@ -43,7 +45,7 @@ class RankingApp:
|
|
43 |
negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
|
44 |
st.form_submit_button('Generate Images [Coming Soon]', type='primary', use_container_width=True, disabled=True)
|
45 |
|
46 |
-
return prompt_tags, tag, prompt_id, items
|
47 |
|
48 |
def draggable_images(self, items, prompt_id, layout='portrait'):
|
49 |
# init ranking by the order of items
|
@@ -60,7 +62,6 @@ class RankingApp:
|
|
60 |
st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(items['image_id'][i])] = i
|
61 |
else:
|
62 |
# set the index of items to the corresponding ranking value of the image_id
|
63 |
-
print(items['image_id'])
|
64 |
items.index = items['image_id'].apply(lambda x: st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(x)])
|
65 |
|
66 |
with elements('dashboard'):
|
@@ -110,6 +111,121 @@ class RankingApp:
|
|
110 |
for k in st.session_state.ranking[prompt_id][batch_idx].keys():
|
111 |
st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
def app(self):
|
114 |
st.title('Personal Image Ranking')
|
115 |
st.write('Here you can test out your selected images with any prompt you like.')
|
@@ -120,11 +236,11 @@ class RankingApp:
|
|
120 |
st.session_state.progress = {}
|
121 |
print('current progress: ', st.session_state.progress)
|
122 |
|
123 |
-
prompt_tags, tag, prompt_id, items = self.sidebar()
|
124 |
batch_num = len(items) // self.batch_size
|
125 |
batch_num += 1 if len(items) % self.batch_size != 0 else 0
|
126 |
|
127 |
-
st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else st.session_state.counter[prompt_id]
|
128 |
|
129 |
# save prompt_id in session state
|
130 |
st.session_state.prompt_id_tmp = prompt_id
|
@@ -133,27 +249,10 @@ class RankingApp:
|
|
133 |
st.session_state.progress[prompt_id] = 'ranking'
|
134 |
|
135 |
if st.session_state.progress[prompt_id] == 'ranking':
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
st.progress((st.session_state.counter[prompt_id] + 1) / batch_num, text=f"Batch {st.session_state.counter[prompt_id] + 1} / {batch_num}")
|
141 |
-
# st.write(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)])
|
142 |
-
|
143 |
-
width, height = items.loc[0, 'size'].split('x')
|
144 |
-
if int(height) >= int(width):
|
145 |
-
self.draggable_images(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)].reset_index(drop=True), prompt_id=prompt_id, layout='portrait')
|
146 |
-
else:
|
147 |
-
self.draggable_images(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)].reset_index(drop=True), prompt_id=prompt_id, layout='landscape')
|
148 |
-
# st.write(str(st.session_state.ranking))
|
149 |
-
with control:
|
150 |
-
if st.session_state.counter[prompt_id] < batch_num - 1:
|
151 |
-
st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch', kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
152 |
-
else:
|
153 |
-
st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished', kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True)
|
154 |
-
|
155 |
-
if st.session_state.counter[prompt_id] > 0:
|
156 |
-
st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch', kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
157 |
|
158 |
elif st.session_state.progress[prompt_id] == 'finished':
|
159 |
st.write('## You have ranked all models for this tag!')
|
@@ -171,40 +270,10 @@ class RankingApp:
|
|
171 |
if restart_btn:
|
172 |
st.session_state.progress[prompt_id] = 'ranking'
|
173 |
st.session_state.counter[prompt_id] = 0
|
|
|
174 |
st.experimental_rerun()
|
175 |
|
176 |
|
177 |
-
def next_batch(self, prompt_id, progress=None):
|
178 |
-
# save ranking to dataset
|
179 |
-
# print(st.session_state.ranking)
|
180 |
-
# ranking_dataset = datasets.load_dataset('MAPS-research/GEMRec-Ranking', split='train')
|
181 |
-
curser = RANKING_CONN.cursor()
|
182 |
-
for image_id in st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]].keys():
|
183 |
-
modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
|
184 |
-
ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][image_id]
|
185 |
-
# print({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
|
186 |
-
# ranking_dataset = ranking_dataset.add_item({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
|
187 |
-
|
188 |
-
# remove the old ranking if exists
|
189 |
-
query = "DELETE FROM rankings WHERE image_id = %s AND user_name = %s AND timestamp = %s"
|
190 |
-
curser.execute(query, (image_id, st.session_state.user_id[0], st.session_state.user_id[1]))
|
191 |
-
|
192 |
-
query = "INSERT INTO rankings (image_id, modelVersion_id, ranking, user_name, timestamp) VALUES (%s, %s, %s, %s, %s)"
|
193 |
-
curser.execute(query, (image_id, modelVersion_id, ranking, st.session_state.user_id[0], st.session_state.user_id[1]))
|
194 |
-
|
195 |
-
curser.close()
|
196 |
-
RANKING_CONN.commit()
|
197 |
-
# ranking_dataset.push_to_hub('MAPS-research/GEMRec-Ranking', split='train')
|
198 |
-
|
199 |
-
if progress == 'finished':
|
200 |
-
st.session_state.progress[prompt_id] = 'finished'
|
201 |
-
else:
|
202 |
-
st.session_state.counter[prompt_id] += 1
|
203 |
-
|
204 |
-
def prev_batch(self, prompt_id):
|
205 |
-
st.session_state.counter[prompt_id] -= 1
|
206 |
-
|
207 |
-
|
208 |
def connect_to_db():
|
209 |
conn = pymysql.connect(
|
210 |
host=os.environ.get('RANKING_DB_HOST'),
|
@@ -261,11 +330,14 @@ if __name__ == "__main__":
|
|
261 |
residual = len(user_selections) % 4
|
262 |
if residual != 0:
|
263 |
# select 4-residual items from the promptbook outside the user_selections
|
264 |
-
npc = promptBook[(promptBook['prompt_id'] == key) & (~promptBook['modelVersion_id'].isin(value))].sort_values(by=['
|
265 |
user_selections = pd.concat([user_selections, npc])
|
266 |
|
267 |
promptBook_selected = pd.concat([promptBook_selected, user_selections])
|
268 |
promptBook_selected = promptBook_selected.reset_index(drop=True)
|
|
|
|
|
|
|
269 |
# st.write(promptBook_selected)
|
270 |
images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
|
271 |
|
|
|
34 |
|
35 |
selected_prompt = st.selectbox('Select a prompt', prompts)
|
36 |
|
37 |
+
mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
|
38 |
+
|
39 |
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
40 |
prompt_id = items['prompt_id'].unique()[0]
|
41 |
|
|
|
45 |
negative_prompt = st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
|
46 |
st.form_submit_button('Generate Images [Coming Soon]', type='primary', use_container_width=True, disabled=True)
|
47 |
|
48 |
+
return prompt_tags, tag, prompt_id, items, mode
|
49 |
|
50 |
def draggable_images(self, items, prompt_id, layout='portrait'):
|
51 |
# init ranking by the order of items
|
|
|
62 |
st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(items['image_id'][i])] = i
|
63 |
else:
|
64 |
# set the index of items to the corresponding ranking value of the image_id
|
|
|
65 |
items.index = items['image_id'].apply(lambda x: st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][str(x)])
|
66 |
|
67 |
with elements('dashboard'):
|
|
|
111 |
for k in st.session_state.ranking[prompt_id][batch_idx].keys():
|
112 |
st.session_state.ranking[prompt_id][batch_idx][k] = sorted_list.index(k)
|
113 |
|
114 |
+
def dragsort_mode(self, items, prompt_id, batch_num):
|
115 |
+
st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else \
|
116 |
+
st.session_state.counter[prompt_id]
|
117 |
+
|
118 |
+
sorting, control = st.columns((11, 1), gap='large')
|
119 |
+
with sorting:
|
120 |
+
# st.write('## Sorting')
|
121 |
+
# st.write('Please drag the images to sort them.')
|
122 |
+
st.progress((st.session_state.counter[prompt_id] + 1) / batch_num,
|
123 |
+
text=f"Batch {st.session_state.counter[prompt_id] + 1} / {batch_num}")
|
124 |
+
# st.write(items.iloc[self.batch_size*st.session_state.counter[prompt_id]: self.batch_size*(st.session_state.counter[prompt_id]+1)])
|
125 |
+
|
126 |
+
width, height = items.loc[0, 'size'].split('x')
|
127 |
+
if int(height) >= int(width):
|
128 |
+
self.draggable_images(items.iloc[
|
129 |
+
self.batch_size * st.session_state.counter[prompt_id]: self.batch_size * (
|
130 |
+
st.session_state.counter[prompt_id] + 1)].reset_index(drop=True),
|
131 |
+
prompt_id=prompt_id, layout='portrait')
|
132 |
+
else:
|
133 |
+
self.draggable_images(items.iloc[
|
134 |
+
self.batch_size * st.session_state.counter[prompt_id]: self.batch_size * (
|
135 |
+
st.session_state.counter[prompt_id] + 1)].reset_index(drop=True),
|
136 |
+
prompt_id=prompt_id, layout='landscape')
|
137 |
+
# st.write(str(st.session_state.ranking))
|
138 |
+
with control:
|
139 |
+
if st.session_state.counter[prompt_id] < batch_num - 1:
|
140 |
+
st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
|
141 |
+
kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
142 |
+
else:
|
143 |
+
st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
|
144 |
+
kwargs={'prompt_id': prompt_id, 'progress': 'finished'}, use_container_width=True)
|
145 |
+
|
146 |
+
if st.session_state.counter[prompt_id] > 0:
|
147 |
+
st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
|
148 |
+
kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
149 |
+
|
150 |
+
def next_batch(self, prompt_id, progress=None):
|
151 |
+
# save ranking to dataset
|
152 |
+
# print(st.session_state.ranking)
|
153 |
+
# ranking_dataset = datasets.load_dataset('MAPS-research/GEMRec-Ranking', split='train')
|
154 |
+
curser = RANKING_CONN.cursor()
|
155 |
+
for image_id in st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]].keys():
|
156 |
+
modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
|
157 |
+
ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]][image_id]
|
158 |
+
# print({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
|
159 |
+
# ranking_dataset = ranking_dataset.add_item({'image_id': image_id, 'modelVersion_id': modelVersion_id, 'ranking': ranking, "user_name": st.session_state.user_id[0], "timestamp": st.session_state.user_id[1]})
|
160 |
+
|
161 |
+
# remove the old ranking if exists
|
162 |
+
query = "DELETE FROM rankings WHERE image_id = %s AND user_name = %s AND timestamp = %s"
|
163 |
+
curser.execute(query, (image_id, st.session_state.user_id[0], st.session_state.user_id[1]))
|
164 |
+
|
165 |
+
query = "INSERT INTO rankings (image_id, modelVersion_id, ranking, user_name, timestamp) VALUES (%s, %s, %s, %s, %s)"
|
166 |
+
curser.execute(query, (image_id, modelVersion_id, ranking, st.session_state.user_id[0], st.session_state.user_id[1]))
|
167 |
+
|
168 |
+
curser.close()
|
169 |
+
RANKING_CONN.commit()
|
170 |
+
# ranking_dataset.push_to_hub('MAPS-research/GEMRec-Ranking', split='train')
|
171 |
+
|
172 |
+
if progress == 'finished':
|
173 |
+
st.session_state.progress[prompt_id] = 'finished'
|
174 |
+
else:
|
175 |
+
st.session_state.counter[prompt_id] += 1
|
176 |
+
|
177 |
+
def prev_batch(self, prompt_id):
|
178 |
+
st.session_state.counter[prompt_id] -= 1
|
179 |
+
|
180 |
+
def battle_images(self, items, prompt_id):
|
181 |
+
if 'pointer' not in st.session_state:
|
182 |
+
st.session_state.pointer = {}
|
183 |
+
|
184 |
+
if prompt_id not in st.session_state.pointer:
|
185 |
+
st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
|
186 |
+
|
187 |
+
curr_position = max(st.session_state.pointer[prompt_id]['left'], st.session_state.pointer[prompt_id]['right'])
|
188 |
+
progress = st.progress(curr_position / (len(items)-1), text=f"Progress {curr_position} / {len(items)-1}")
|
189 |
+
|
190 |
+
# if curr_position == len(items) - 1:
|
191 |
+
# st.session_state.progress[prompt_id] = 'finished'
|
192 |
+
#
|
193 |
+
# else:
|
194 |
+
left, right = st.columns(2)
|
195 |
+
with left:
|
196 |
+
image_id = items['image_id'][st.session_state.pointer[prompt_id]['left']]
|
197 |
+
img_url = self.images_endpoint + str(image_id) + '.png'
|
198 |
+
st.image(img_url, use_column_width=True)
|
199 |
+
|
200 |
+
# write the total score of this image
|
201 |
+
total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
|
202 |
+
st.write(f'Total Score: {total_score}')
|
203 |
+
|
204 |
+
btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
205 |
+
|
206 |
+
with right:
|
207 |
+
image_id = items['image_id'][st.session_state.pointer[prompt_id]['right']]
|
208 |
+
img_url = self.images_endpoint + str(image_id) + '.png'
|
209 |
+
st.image(img_url, use_column_width=True)
|
210 |
+
|
211 |
+
# write the total score of this image
|
212 |
+
total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
|
213 |
+
st.write(f'Total Score: {total_score}')
|
214 |
+
|
215 |
+
btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
216 |
+
|
217 |
+
def next_battle(self, prompt_id, winner, curr_position, total_num):
|
218 |
+
loser = 'left' if winner == 'right' else 'right'
|
219 |
+
|
220 |
+
if curr_position == total_num - 1:
|
221 |
+
st.session_state.progress[prompt_id] = 'finished'
|
222 |
+
# st.experimental_rerun()
|
223 |
+
else:
|
224 |
+
st.session_state.pointer[prompt_id][loser] = curr_position + 1
|
225 |
+
|
226 |
+
def battle_mode(self, items, prompt_id):
|
227 |
+
self.battle_images(items, prompt_id)
|
228 |
+
|
229 |
def app(self):
|
230 |
st.title('Personal Image Ranking')
|
231 |
st.write('Here you can test out your selected images with any prompt you like.')
|
|
|
236 |
st.session_state.progress = {}
|
237 |
print('current progress: ', st.session_state.progress)
|
238 |
|
239 |
+
prompt_tags, tag, prompt_id, items, mode = self.sidebar()
|
240 |
batch_num = len(items) // self.batch_size
|
241 |
batch_num += 1 if len(items) % self.batch_size != 0 else 0
|
242 |
|
243 |
+
# st.session_state.counter[prompt_id] = 0 if prompt_id not in st.session_state.counter else st.session_state.counter[prompt_id]
|
244 |
|
245 |
# save prompt_id in session state
|
246 |
st.session_state.prompt_id_tmp = prompt_id
|
|
|
249 |
st.session_state.progress[prompt_id] = 'ranking'
|
250 |
|
251 |
if st.session_state.progress[prompt_id] == 'ranking':
|
252 |
+
if mode == 'Drag and Sort':
|
253 |
+
self.dragsort_mode(items, prompt_id, batch_num)
|
254 |
+
elif mode == 'Battle':
|
255 |
+
self.battle_mode(items, prompt_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
elif st.session_state.progress[prompt_id] == 'finished':
|
258 |
st.write('## You have ranked all models for this tag!')
|
|
|
270 |
if restart_btn:
|
271 |
st.session_state.progress[prompt_id] = 'ranking'
|
272 |
st.session_state.counter[prompt_id] = 0
|
273 |
+
st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
|
274 |
st.experimental_rerun()
|
275 |
|
276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
def connect_to_db():
|
278 |
conn = pymysql.connect(
|
279 |
host=os.environ.get('RANKING_DB_HOST'),
|
|
|
330 |
residual = len(user_selections) % 4
|
331 |
if residual != 0:
|
332 |
# select 4-residual items from the promptbook outside the user_selections
|
333 |
+
npc = promptBook[(promptBook['prompt_id'] == key) & (~promptBook['modelVersion_id'].isin(value))].sort_values(by=['total_score'], ascending=False).reset_index(drop=True).iloc[:4-residual]
|
334 |
user_selections = pd.concat([user_selections, npc])
|
335 |
|
336 |
promptBook_selected = pd.concat([promptBook_selected, user_selections])
|
337 |
promptBook_selected = promptBook_selected.reset_index(drop=True)
|
338 |
+
# sort promptBook by total_score
|
339 |
+
promptBook_selected = promptBook_selected.sort_values(by=['total_score'], ascending=True).reset_index(drop=True)
|
340 |
+
|
341 |
# st.write(promptBook_selected)
|
342 |
images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
|
343 |
|