Spaces:
Running
Running
update battle mode
Browse files- pages/Ranking.py +19 -12
- pages/{Results.py β Summary.py} +13 -13
pages/Ranking.py
CHANGED
@@ -6,6 +6,7 @@ import pandas as pd
|
|
6 |
import pymysql.cursors
|
7 |
import streamlit as st
|
8 |
|
|
|
9 |
from streamlit_elements import elements, mui, html, dashboard, nivo
|
10 |
from streamlit_extras.switch_page_button import switch_page
|
11 |
|
@@ -224,6 +225,7 @@ class RankingApp:
|
|
224 |
|
225 |
def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
|
226 |
loser = 'left' if winner == 'right' else 'right'
|
|
|
227 |
|
228 |
curser = RANKING_CONN.cursor()
|
229 |
|
@@ -236,8 +238,8 @@ class RankingApp:
|
|
236 |
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
|
237 |
|
238 |
# insert the battle result into the database
|
239 |
-
query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser) VALUES (%s, %s, %s, %s, %s, %s)"
|
240 |
-
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id))
|
241 |
|
242 |
curser.close()
|
243 |
RANKING_CONN.commit()
|
@@ -282,22 +284,27 @@ class RankingApp:
|
|
282 |
elif st.session_state.progress[prompt_id] == 'finished':
|
283 |
st.write('## You have ranked all models for this tag!')
|
284 |
st.write('Thank you for your participation! Feel free to do the following things:')
|
285 |
-
st.write('* Rank for other tags and prompts.')
|
286 |
-
st.write('* Back to the gallery page to see more images.')
|
287 |
-
st.write('* Rank again for this tag and prompt.')
|
288 |
-
st.write('*
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
switch_page('gallery')
|
293 |
-
|
294 |
-
restart_btn = st.button('ποΈ Rank Again')
|
295 |
if restart_btn:
|
296 |
st.session_state.progress[prompt_id] = 'ranking'
|
297 |
st.session_state.counter[prompt_id] = 0
|
298 |
st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
|
299 |
st.experimental_rerun()
|
300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
def connect_to_db():
|
303 |
conn = pymysql.connect(
|
|
|
6 |
import pymysql.cursors
|
7 |
import streamlit as st
|
8 |
|
9 |
+
from datetime import datetime
|
10 |
from streamlit_elements import elements, mui, html, dashboard, nivo
|
11 |
from streamlit_extras.switch_page_button import switch_page
|
12 |
|
|
|
225 |
|
226 |
def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
|
227 |
loser = 'left' if winner == 'right' else 'right'
|
228 |
+
battletime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
229 |
|
230 |
curser = RANKING_CONN.cursor()
|
231 |
|
|
|
238 |
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], prompt_id, loser_modelVersion_id, winner_modelVersion_id))
|
239 |
|
240 |
# insert the battle result into the database
|
241 |
+
query = "INSERT INTO battle_results (username, timestamp, tag, prompt_id, winner, loser, battletime) VALUES (%s, %s, %s, %s, %s, %s, %s)"
|
242 |
+
curser.execute(query, (st.session_state.user_id[0], st.session_state.user_id[1], self.promptBook[self.promptBook['prompt_id'] == prompt_id]['tag'].values[0], prompt_id, winner_modelVersion_id, loser_modelVersion_id, battletime))
|
243 |
|
244 |
curser.close()
|
245 |
RANKING_CONN.commit()
|
|
|
284 |
elif st.session_state.progress[prompt_id] == 'finished':
|
285 |
st.write('## You have ranked all models for this tag!')
|
286 |
st.write('Thank you for your participation! Feel free to do the following things:')
|
287 |
+
# st.write('* Rank for other tags and prompts.')
|
288 |
+
# st.write('* Back to the gallery page to see more images.')
|
289 |
+
# st.write('* Rank again for this tag and prompt.')
|
290 |
+
# st.write('* Check the summary to see what model you like most.')
|
291 |
+
# st.write('*More functions are coming soon... Please stay tuned*')
|
292 |
+
st.button('π Rank for other tags and prompts')
|
293 |
+
restart_btn = st.button('ποΈ Rank this prompt again')
|
|
|
|
|
|
|
294 |
if restart_btn:
|
295 |
st.session_state.progress[prompt_id] = 'ranking'
|
296 |
st.session_state.counter[prompt_id] = 0
|
297 |
st.session_state.pointer[prompt_id] = {'left': 0, 'right': 1}
|
298 |
st.experimental_rerun()
|
299 |
|
300 |
+
gallery_btn = st.button('πΌοΈ Back to Gallery')
|
301 |
+
if gallery_btn:
|
302 |
+
switch_page('gallery')
|
303 |
+
|
304 |
+
summary_btn = st.button('π See Summary')
|
305 |
+
if summary_btn:
|
306 |
+
switch_page('summary')
|
307 |
+
|
308 |
|
309 |
def connect_to_db():
|
310 |
conn = pymysql.connect(
|
pages/{Results.py β Summary.py}
RENAMED
@@ -49,20 +49,20 @@ class DashboardApp:
|
|
49 |
n = 3
|
50 |
metric_cols = st.columns(n)
|
51 |
image_display = st.empty()
|
52 |
-
|
53 |
for i in range(n):
|
54 |
with metric_cols[i]:
|
55 |
modelVersion_id = modelVersion_standings[i][0]
|
56 |
winning_times = modelVersion_standings[i][1]
|
57 |
|
58 |
-
model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
|
59 |
|
60 |
metric_card = stylable_container(
|
61 |
key="container_with_border",
|
62 |
css_styles="""
|
63 |
{
|
64 |
-
border: 1.5px solid rgba(49, 51, 63, 0.
|
65 |
-
border-left: 0.5rem solid
|
66 |
border-radius: 5px;
|
67 |
padding: calc(1em + 5px);
|
68 |
gap: 0.5em;
|
@@ -74,8 +74,8 @@ class DashboardApp:
|
|
74 |
|
75 |
with metric_card:
|
76 |
icon = 'π₯'if i == 0 else 'π₯' if i == 1 else 'π₯'
|
77 |
-
st.write(
|
78 |
-
st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{modelVersion_id})')
|
79 |
st.write(f'Ranking Score: {winning_times}')
|
80 |
|
81 |
show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
|
@@ -105,20 +105,21 @@ class DashboardApp:
|
|
105 |
|
106 |
|
107 |
def score_calculator(self, results, db_table):
|
|
|
|
|
|
|
108 |
modelVersion_standings = {}
|
109 |
if db_table == 'battle_results':
|
110 |
for record in results:
|
111 |
modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
|
112 |
-
# add the winning time of the loser
|
113 |
-
curser = RANKING_CONN.cursor()
|
114 |
-
curser.execute(f"SELECT COUNT(*) FROM {db_table} WHERE username = '{st.session_state.user_id[0]}' AND timestamp = '{st.session_state.user_id[1]}' AND winner = '{record['loser']}'")
|
115 |
-
modelVersion_standings[record['winner']] += curser.fetchone()['COUNT(*)']
|
116 |
-
curser.close()
|
117 |
|
118 |
# add the loser who never wins
|
119 |
if record['loser'] not in modelVersion_standings:
|
120 |
modelVersion_standings[record['loser']] = 0
|
121 |
|
|
|
|
|
|
|
122 |
elif db_table == 'sort_results':
|
123 |
pts_map = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}
|
124 |
for record in results:
|
@@ -128,11 +129,10 @@ class DashboardApp:
|
|
128 |
return modelVersion_standings
|
129 |
|
130 |
|
131 |
-
|
132 |
def app(self):
|
133 |
st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
|
134 |
|
135 |
-
mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True)
|
136 |
# get tags from database of the current user
|
137 |
db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
|
138 |
|
|
|
49 |
n = 3
|
50 |
metric_cols = st.columns(n)
|
51 |
image_display = st.empty()
|
52 |
+
|
53 |
for i in range(n):
|
54 |
with metric_cols[i]:
|
55 |
modelVersion_id = modelVersion_standings[i][0]
|
56 |
winning_times = modelVersion_standings[i][1]
|
57 |
|
58 |
+
model_id, model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
|
59 |
|
60 |
metric_card = stylable_container(
|
61 |
key="container_with_border",
|
62 |
css_styles="""
|
63 |
{
|
64 |
+
border: 1.5px solid rgba(49, 51, 63, 0.2);
|
65 |
+
border-left: 0.5rem solid gold;
|
66 |
border-radius: 5px;
|
67 |
padding: calc(1em + 5px);
|
68 |
gap: 0.5em;
|
|
|
74 |
|
75 |
with metric_card:
|
76 |
icon = 'π₯'if i == 0 else 'π₯' if i == 1 else 'π₯'
|
77 |
+
# st.write(model_id)
|
78 |
+
st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})')
|
79 |
st.write(f'Ranking Score: {winning_times}')
|
80 |
|
81 |
show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
|
|
|
105 |
|
106 |
|
107 |
def score_calculator(self, results, db_table):
|
108 |
+
# sort results by battle time
|
109 |
+
results = sorted(results, key=lambda x: x['battletime'])
|
110 |
+
|
111 |
modelVersion_standings = {}
|
112 |
if db_table == 'battle_results':
|
113 |
for record in results:
|
114 |
modelVersion_standings[record['winner']] = modelVersion_standings.get(record['winner'], 0) + 1
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
# add the loser who never wins
|
117 |
if record['loser'] not in modelVersion_standings:
|
118 |
modelVersion_standings[record['loser']] = 0
|
119 |
|
120 |
+
# add the winning time of the loser to the winner
|
121 |
+
modelVersion_standings[record['winner']] += modelVersion_standings[record['loser']]
|
122 |
+
|
123 |
elif db_table == 'sort_results':
|
124 |
pts_map = {'position1': 5, 'position2': 3, 'position3': 1, 'position4': 0}
|
125 |
for record in results:
|
|
|
129 |
return modelVersion_standings
|
130 |
|
131 |
|
|
|
132 |
def app(self):
|
133 |
st.title('Your Preferred Models', help="Scores are calculated based on your ranking results.")
|
134 |
|
135 |
+
mode = st.sidebar.radio('Ranking mode', ['Sort', 'Battle'], horizontal=True, index=1)
|
136 |
# get tags from database of the current user
|
137 |
db_table = 'sort_results' if mode == 'Sort' else 'battle_results'
|
138 |
|