Spaces:
Running
Running
big update for database structure
Browse filesadd two seperate databases: battle_results and sort_results
- pages/Ranking.py +41 -16
pages/Ranking.py
CHANGED
@@ -148,26 +148,34 @@ class RankingApp:
|
|
148 |
kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
149 |
|
150 |
def next_batch(self, prompt_id, progress=None):
|
151 |
-
|
152 |
-
# print(st.session_state.ranking)
|
153 |
-
# ranking_dataset = datasets.load_dataset('MAPS-research/GEMRec-Ranking', split='train')
|
154 |
curser = RANKING_CONN.cursor()
|
155 |
-
|
|
|
|
|
|
|
156 |
modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
164 |
|
165 |
-
|
166 |
-
|
167 |
|
168 |
curser.close()
|
169 |
RANKING_CONN.commit()
|
170 |
-
# ranking_dataset.push_to_hub('MAPS-research/GEMRec-Ranking', split='train')
|
171 |
|
172 |
if progress == 'finished':
|
173 |
st.session_state.progress[prompt_id] = 'finished'
|
@@ -201,7 +209,7 @@ class RankingApp:
|
|
201 |
total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
|
202 |
st.write(f'Total Score: {total_score}')
|
203 |
|
204 |
-
btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
205 |
|
206 |
with right:
|
207 |
image_id = items['image_id'][st.session_state.pointer[prompt_id]['right']]
|
@@ -212,11 +220,28 @@ class RankingApp:
|
|
212 |
total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
|
213 |
st.write(f'Total Score: {total_score}')
|
214 |
|
215 |
-
btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
216 |
|
217 |
-
def next_battle(self, prompt_id, winner, curr_position, total_num):
|
218 |
loser = 'left' if winner == 'right' else 'right'
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
if curr_position == total_num - 1:
|
221 |
st.session_state.progress[prompt_id] = 'finished'
|
222 |
# st.experimental_rerun()
|
|
|
148 |
kwargs={'prompt_id': prompt_id}, use_container_width=True)
|
149 |
|
150 |
def next_batch(self, prompt_id, progress=None):
|
151 |
+
|
|
|
|
|
152 |
curser = RANKING_CONN.cursor()
|
153 |
+
|
154 |
+
# a not so elegant way to get the modelVersion_id of each image, but it works
|
155 |
+
position_version_dict = {}
|
156 |
+
for image_id, position in st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]].items():
|
157 |
modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_id]['modelVersion_id'].values[0]
|
158 |
+
position_version_dict[position] = modelVersion_id
|
159 |
+
|
160 |
+
# get all records of this user and prompt
|
161 |
+
query = "SELECT * FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s"
|
162 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id))
|
163 |
+
results = curser.fetchall()
|
164 |
+
print(results)
|
165 |
|
166 |
+
# remove the old ranking with the same modelVersion_id if exists
|
167 |
+
for result in results:
|
168 |
+
prev_ids = [result['position1'], result['position2'], result['position3'], result['position4']]
|
169 |
+
curr_ids = [position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3]]
|
170 |
+
if len(set(prev_ids).intersection(set(curr_ids))) == 4:
|
171 |
+
query = "DELETE FROM sort_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND position1 = %s AND position2 = %s AND position3 = %s AND position4 = %s"
|
172 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, result['position1'], result['position2'], result['position3'], result['position4']))
|
173 |
|
174 |
+
query = "INSERT INTO sort_results (username, timestamp, tag, prompt_id, position1, position2, position3, position4) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)"
|
175 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, position_version_dict[0], position_version_dict[1], position_version_dict[2], position_version_dict[3]))
|
176 |
|
177 |
curser.close()
|
178 |
RANKING_CONN.commit()
|
|
|
179 |
|
180 |
if progress == 'finished':
|
181 |
st.session_state.progress[prompt_id] = 'finished'
|
|
|
209 |
total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
|
210 |
st.write(f'Total Score: {total_score}')
|
211 |
|
212 |
+
btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
213 |
|
214 |
with right:
|
215 |
image_id = items['image_id'][st.session_state.pointer[prompt_id]['right']]
|
|
|
220 |
total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
|
221 |
st.write(f'Total Score: {total_score}')
|
222 |
|
223 |
+
btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
224 |
|
225 |
+
def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
|
226 |
loser = 'left' if winner == 'right' else 'right'
|
227 |
|
228 |
+
curser = RANKING_CONN.cursor()
|
229 |
+
|
230 |
+
winner_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][winner]]]['modelVersion_id'].values[0]
|
231 |
+
loser_modelVersion_id = self.promptBook[self.promptBook['image_id'] == image_ids[st.session_state.pointer[prompt_id][loser]]]['modelVersion_id'].values[0]
|
232 |
+
|
233 |
+
# remove the old battle result if exists
|
234 |
+
query = "DELETE FROM battle_results WHERE username = %s AND timestamp = %s AND prompt_id = %s AND winner = %s AND loser = %s"
|
235 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
|
236 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
|
237 |
+
|
238 |
+
# insert the battle result into the database
|
239 |
+
query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser) VALUES (%s, %s, %s, %s, %s, %s)"
|
240 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
|
241 |
+
|
242 |
+
curser.close()
|
243 |
+
RANKING_CONN.commit()
|
244 |
+
|
245 |
if curr_position == total_num - 1:
|
246 |
st.session_state.progress[prompt_id] = 'finished'
|
247 |
# st.experimental_rerun()
|