Set up arena
Browse files- README.md +3 -3
- app.py +362 -0
- config.py +61 -0
- requirements.txt +7 -0
- static/modelViewer.js +58 -0
- static/plots.js +6 -0
- static/popup.js +24 -0
- static/style.css +135 -0
- utils/__pycache__/data_utils.cpython-310.pyc +0 -0
- utils/__pycache__/emoji_utils.cpython-310.pyc +0 -0
- utils/__pycache__/plot_utils.cpython-310.pyc +0 -0
- utils/__pycache__/s3_utils.cpython-310.pyc +0 -0
- utils/__pycache__/utils.cpython-310.pyc +0 -0
- utils/data_utils.py +49 -0
- utils/emoji_utils.py +14 -0
- utils/graphs.ipynb +0 -0
- utils/plot_utils.py +28 -0
- utils/s3_utils.py +123 -0
- utils/utils.py +176 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
title: 3D Animation Arena
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.27.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
short_description: Arena to rank
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: 3D Animation Arena
|
3 |
+
emoji: 🐠
|
4 |
+
colorFrom: green
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.27.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
short_description: Arena to rank 3D animation models
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from pathlib import Path
|
3 |
+
import uuid
|
4 |
+
import random
|
5 |
+
|
6 |
+
from utils.data_utils import generate_leaderboard
|
7 |
+
from utils.plot_utils import plot_ratings
|
8 |
+
from utils.utils import simulate, submit_rating, generate_matchup
|
9 |
+
from config import MODE, VIDEOS, MODELS, CRITERIA, default_beta
|
10 |
+
|
11 |
+
|
12 |
+
head = f"""
|
13 |
+
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
|
14 |
+
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js"></script>
|
15 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/plotly.js/1.33.1/plotly.min.js"></script>
|
16 |
+
<script>{Path('static/modelViewer.js').read_text()}</script>
|
17 |
+
<script>{Path('static/popup.js').read_text()}</script>
|
18 |
+
<script>{Path('static/plots.js').read_text()}</script>
|
19 |
+
"""
|
20 |
+
|
21 |
+
with gr.Blocks(title='3D Animation Arena', head=head, css_paths='static/style.css') as arena:
|
22 |
+
|
23 |
+
sessionState = gr.State({
|
24 |
+
'video': None,
|
25 |
+
'modelLeft': None,
|
26 |
+
'modelRight': None,
|
27 |
+
'darkMode': False,
|
28 |
+
'videos': VIDEOS,
|
29 |
+
'currentTab': CRITERIA[0],
|
30 |
+
'uuid': None
|
31 |
+
})
|
32 |
+
|
33 |
+
frontState = gr.JSON(sessionState, visible=False)
|
34 |
+
|
35 |
+
with gr.Row():
|
36 |
+
with gr.Column(scale=1):
|
37 |
+
gr.HTML('')
|
38 |
+
with gr.Column(scale=12):
|
39 |
+
gr.HTML("<h1 style='text-align:center; font-size:50px'>3D Animation Arena</h1>")
|
40 |
+
with gr.Column(scale=1):
|
41 |
+
toggle_dark = gr.Button(value="Dark Mode")
|
42 |
+
|
43 |
+
def update_toggle_dark(state):
    """
    Flip the session's dark-mode flag and relabel the toggle button.

    Args:
        state (dict): Session state; its 'darkMode' entry is inverted in place.

    Returns:
        tuple: A button update whose label names the mode the next click
            switches to, and the mutated state.
    """
    dark_enabled = not state['darkMode']
    state['darkMode'] = dark_enabled
    next_label = "Light Mode" if dark_enabled else "Dark Mode"
    return gr.update(value=next_label), state
|
49 |
+
|
50 |
+
toggle_dark.click(
|
51 |
+
inputs=[sessionState],
|
52 |
+
js="""
|
53 |
+
() => {
|
54 |
+
document.body.classList.toggle('dark');
|
55 |
+
}
|
56 |
+
""",
|
57 |
+
fn=update_toggle_dark,
|
58 |
+
outputs=[toggle_dark, sessionState]
|
59 |
+
)
|
60 |
+
|
61 |
+
with gr.Tab(label='Arena'):
|
62 |
+
models = gr.HTML('''
|
63 |
+
<div class="viewer-container">
|
64 |
+
<iframe
|
65 |
+
id="modelViewerLeft"
|
66 |
+
src="https://d39vhmln1nnc4z.cloudfront.net/index.html"
|
67 |
+
width="100%"
|
68 |
+
height="100%"
|
69 |
+
allow="storage-access"
|
70 |
+
></iframe>
|
71 |
+
|
72 |
+
<iframe
|
73 |
+
id="modelViewerRight"
|
74 |
+
src="https://d39vhmln1nnc4z.cloudfront.net/index.html"
|
75 |
+
width="100%"
|
76 |
+
height="100%"
|
77 |
+
allow="storage-access"
|
78 |
+
></iframe>
|
79 |
+
</div>''',
|
80 |
+
render=False)
|
81 |
+
|
82 |
+
with gr.Row():
|
83 |
+
with gr.Column(scale=1):
|
84 |
+
gr.HTML(f"<h1>1. Choose a video below:</h1>")
|
85 |
+
video = gr.Video(
|
86 |
+
label='Input Video',
|
87 |
+
interactive=False,
|
88 |
+
autoplay=True,
|
89 |
+
show_download_button=False,
|
90 |
+
loop=True,
|
91 |
+
elem_id='gradioVideo',
|
92 |
+
)
|
93 |
+
|
94 |
+
triggerButtons = {}
|
95 |
+
for vid in sessionState.value['videos']:
|
96 |
+
triggerButtons[vid] = gr.Button(elem_id=f'triggerBtn_{vid}', visible=False)
|
97 |
+
triggerButtons[vid].click(
|
98 |
+
fn=lambda vid=vid: gr.update(value=f'https://gradio-model-viewer.s3.eu-west-1.amazonaws.com/sample+videos/{vid}.mp4'),
|
99 |
+
outputs=[video]
|
100 |
+
)
|
101 |
+
examples = gr.HTML(visible=False)
|
102 |
+
|
103 |
+
with gr.Column(scale=4):
|
104 |
+
gr.HTML("""
|
105 |
+
<h1>2. Play around with the models:
|
106 |
+
<span class="glyphicon glyphicon-question-sign popup-btn btn btn-info btn-lg" data-popup-id="instructionsPopup">
|
107 |
+
<span class="popup-text" id="instructionsPopup">You can control the playback in both viewers at the same time by using the video, or control both viewers independently by using mouse and GUI!</span>
|
108 |
+
</span>
|
109 |
+
</h1>
|
110 |
+
""")
|
111 |
+
with gr.Row():
|
112 |
+
models.render()
|
113 |
+
|
114 |
+
with gr.Row():
|
115 |
+
gr.HTML(f"<h1>3. Choose your favorite model for each criteria:</h1>")
|
116 |
+
ratingButtons = {}
|
117 |
+
for criteria in CRITERIA:
|
118 |
+
with gr.Row():
|
119 |
+
with gr.Column():
|
120 |
+
with gr.Row():
|
121 |
+
match criteria:
|
122 |
+
case 'Global_Appreciation':
|
123 |
+
instructions = "Your overall appreciation of the models, including general aesthetics and self-contacts if applicable."
|
124 |
+
case 'Ground_Contacts':
|
125 |
+
instructions = "The quality of the models' contacts with the ground, including ground penetration and foot sliding."
|
126 |
+
case 'Fidelity':
|
127 |
+
instructions = "The fidelity of the models compared to the motion of the original video."
|
128 |
+
case 'Fluidity':
|
129 |
+
instructions = "The smoothness and temporal coherence of the models."
|
130 |
+
gr.HTML(f"""
|
131 |
+
<h2 style='text-align:center;'>{criteria.replace('_', ' ')}
|
132 |
+
<span class="glyphicon glyphicon-question-sign popup-btn btn btn-info btn-lg" data-popup-id="{criteria}Popup">
|
133 |
+
<span class="popup-text" id="{criteria}Popup">{instructions}</span>
|
134 |
+
</span></h2>
|
135 |
+
""")
|
136 |
+
with gr.Row():
|
137 |
+
ratingButtons[criteria] = []
|
138 |
+
with gr.Column(scale=2):
|
139 |
+
ratingButtons[criteria].append(gr.Button('Left Model', variant='primary', interactive=False))
|
140 |
+
with gr.Column(scale=1, min_width=2):
|
141 |
+
ratingButtons[criteria].append(gr.Button('Skip', min_width=2, interactive=False))
|
142 |
+
with gr.Column(scale=2):
|
143 |
+
ratingButtons[criteria].append(gr.Button('Right Model', variant='primary', interactive=False))
|
144 |
+
|
145 |
+
|
146 |
+
# Leaderboard per criteria
|
147 |
+
with gr.Tab(label='Leaderboards') as leaderboard_tab:
|
148 |
+
|
149 |
+
if MODE == 'testing':
|
150 |
+
# Simulation controls
|
151 |
+
with gr.Row():
|
152 |
+
simulate_btn = gr.Button('Simulate Matches', variant='primary')
|
153 |
+
add_model_btn = gr.Button('Add Model', variant='secondary')
|
154 |
+
with gr.Row():
|
155 |
+
gr.Markdown('''
|
156 |
+
## Probability of each model to be chosen is updated after each vote following: \
|
157 |
+
$$ p_i = \\frac{e^{-\\frac{Matches_i}{\\beta}}}{\\sum_{j=1}^{N} e^{-\\frac{Matches_j}{\\beta}}} $$
|
158 |
+
''')
|
159 |
+
iterate = gr.Number(label='Number of iterations', value=100, minimum=1, maximum=2000, precision=0, interactive=True)
|
160 |
+
beta = gr.Number(label='Beta', value=default_beta, minimum=1, maximum=1000, precision=0, step=10, interactive=True)
|
161 |
+
else:
|
162 |
+
beta = gr.Number(label='Beta', value=default_beta, render=False)
|
163 |
+
|
164 |
+
leaderboards = {}
|
165 |
+
tabs = {}
|
166 |
+
for criteria in CRITERIA:
|
167 |
+
with gr.Tab(label=criteria.replace('_', ' ')) as tabs[criteria]:
|
168 |
+
with gr.Row():
|
169 |
+
gr.HTML(f"<h2 style='text-align:center;'>{criteria.replace('_', ' ')}</h2>")
|
170 |
+
with gr.Row():
|
171 |
+
leaderboards[criteria] = gr.Dataframe(value=None, row_count=(len(MODELS), 'fixed'), headers=['Model', 'Elo', 'Wins', 'Matches', 'Win Rate'], interactive=False)
|
172 |
+
|
173 |
+
# Plots
|
174 |
+
if MODE == 'testing':
|
175 |
+
with gr.Row():
|
176 |
+
elo_plot = gr.Plot(value=None, label='Elo Ratings', format='plotly', elem_id='plot')
|
177 |
+
with gr.Row():
|
178 |
+
wr_plot = gr.Plot(value=None, label='Win Rates', format='plotly', elem_id='plot')
|
179 |
+
with gr.Row():
|
180 |
+
matches_plot = gr.Plot(value=None, label='Matches played', format='plotly', elem_id='plot')
|
181 |
+
elif MODE == 'production':
|
182 |
+
elo_plot = gr.Plot(value=None, label='Elo Ratings', format='plotly', elem_id='plot', visible=False)
|
183 |
+
wr_plot = gr.Plot(value=None, label='Win Rates', format='plotly', elem_id='plot', visible=False)
|
184 |
+
matches_plot = gr.Plot(value=None, label='Matches played', format='plotly', elem_id='plot', visible=False)
|
185 |
+
|
186 |
+
with gr.Tab(label='About'):
|
187 |
+
gr.Markdown('''
|
188 |
+
## Thank you for using the 3D Animation Arena!
|
189 |
+
|
190 |
+
This app is designed to compare different models based on human preferences, inspired by dylanebert's [3D Arena](https://huggingface.co/spaces/dylanebert/3d-arena) on Hugging Face.
|
191 |
+
Current rankings often use metrics to assess the quality of a model, but these metrics may not always reflect the complexity behind human preferences.
|
192 |
+
|
193 |
+
The current models competing in the arena are:
|
194 |
+
- 4DHumans (https://github.com/shubham-goel/4D-Humans)
|
195 |
+
- CLIFF (https://github.com/haofanwang/CLIFF)
|
196 |
+
- GVHMR (https://github.com/zju3dv/GVHMR)
|
197 |
+
- HybrIK (https://github.com/jeffffffli/HybrIK)
|
198 |
+
- WHAM (https://github.com/yohanshin/WHAM)
|
199 |
+
|
200 |
+
All inferences are precomputed following the code in the associated GitHub repository.
|
201 |
+
Some post-inference modifications have been made to some models in order to make the comparison possible.
|
202 |
+
These modifications include:
|
203 |
+
* Adjusting height to a common ground
|
204 |
+
* Fixing the root depth of certain models, when depth was extremely jittery
|
205 |
+
* Fixing the root position of certain models, when no root position was available
|
206 |
+
|
207 |
+
All models use the SMPL body model to discard the influence of the body model on the comparison.
|
208 |
+
These choices were made without any intention to favor or harm any model.
|
209 |
+
All matchups are generated randomly, don't hesitate to rate the same videos multiple times as the matchups will probably be different!
|
210 |
+
|
211 |
+
---
|
212 |
+
|
213 |
+
If you have comments, complaints or suggestions, please contact me at [email protected].
|
214 |
+
New models and videos will be added over time, feel free to share your ideas! Keep in mind that I will not add raw inferences from other people to keep it fair.
|
215 |
+
''')
|
216 |
+
|
217 |
+
|
218 |
+
# Event handlers
|
219 |
+
def randomize_videos(state):
    """
    Start a session: assign a fresh UUID and build a shuffled video gallery.

    Args:
        state (dict): Session state. Its 'videos' list is shuffled in place
            and its 'uuid' entry is overwritten.

    Returns:
        tuple: The same state object and the gallery HTML, containing one
            thumbnail button per video in the shuffled order.
    """
    state['uuid'] = str(uuid.uuid4())
    random.shuffle(state['videos'])

    # Each thumbnail forwards its click to the hidden Gradio button
    # 'triggerBtn_<vid>'. When the clicked video is not the one already
    # loaded in the player, every thumbnail is disabled before the click is
    # forwarded.
    thumbnails = [
        f"""
        <button class="btn btn-info thumbnail-btn" onclick="(function() {{
            let gradioVideo = document.getElementById('gradioVideo');
            let videoComponent = gradioVideo ? gradioVideo.querySelector('video') : null;
            if (videoComponent && !videoComponent.src.includes('{vid}')) {{
                Array.from(document.getElementsByClassName('thumbnail-btn')).forEach(btn => btn.disabled = true);
            }}
            document.getElementById('triggerBtn_{vid}').click();
        }})()">
            <video class="thumbnail" preload="" loop muted onmouseenter="this.play()" onmouseleave="this.pause()">
                <source src="https://gradio-model-viewer.s3.eu-west-1.amazonaws.com/sample+videos/{vid}.mp4">
            </video>
        </button>
        """
        for vid in state['videos']
    ]

    # Join once instead of growing the string with += on every iteration.
    gallery = "<div class='gallery'>" + "".join(thumbnails) + "</div>"
    return state, gallery
|
240 |
+
|
241 |
+
async def display_leaderboards():
    """
    Load the leaderboard of every criteria.

    Returns:
        list: One leaderboard per entry of CRITERIA, in CRITERIA order.
    """
    boards = []
    for name in CRITERIA:
        # Fetched one after the other, exactly as listed in CRITERIA.
        boards.append(await generate_leaderboard(name))
    return boards
|
243 |
+
|
244 |
+
arena.load(
|
245 |
+
inputs=[sessionState],
|
246 |
+
fn=lambda state: randomize_videos(state),
|
247 |
+
outputs=[sessionState, examples],
|
248 |
+
).then(
|
249 |
+
inputs=[],
|
250 |
+
fn=lambda: gr.update(visible=True),
|
251 |
+
outputs=[examples]
|
252 |
+
).then(
|
253 |
+
inputs=[gr.State(CRITERIA[0])],
|
254 |
+
fn=plot_ratings,
|
255 |
+
outputs=[elo_plot, wr_plot, matches_plot]
|
256 |
+
).then(
|
257 |
+
inputs=[],
|
258 |
+
fn=display_leaderboards,
|
259 |
+
outputs=[leaderboards[criteria] for criteria in CRITERIA]
|
260 |
+
)
|
261 |
+
|
262 |
+
async def update_models(video, state):
    """
    Pick a new matchup for the selected video and record it in the session state.

    Args:
        video (str | None): Path or URL of the video currently shown in the
            player; empty when no video is selected.
        state (dict): Session state; its 'video', 'modelLeft' and
            'modelRight' entries are overwritten.

    Returns:
        tuple: The state twice, once for each of the two outputs this
            handler is wired to.
    """
    if not video:
        # No video selected: there is no file name to derive a matchup from,
        # so leave the state untouched instead of failing on video.split.
        return state, state

    leaderboard = await generate_leaderboard(CRITERIA[0])
    # Keep only the file name without its extension: '.../backflip.mp4' -> 'backflip'.
    video_name = video.split('/')[-1].split('.')[0]
    # NOTE(review): beta.value is read from the component object rather than
    # received as an event input - confirm it reflects edits made in the UI.
    modelLeft, modelRight = generate_matchup(leaderboard=leaderboard, beta=beta.value)

    state['video'] = video_name
    # NOTE(review): assumes generate_matchup returns indices into MODELS - confirm.
    state['modelLeft'] = MODELS[modelLeft]
    state['modelRight'] = MODELS[modelRight]

    return state, state
|
272 |
+
|
273 |
+
video.change(
|
274 |
+
inputs=[video, sessionState],
|
275 |
+
fn=update_models,
|
276 |
+
outputs=[sessionState, frontState]
|
277 |
+
)
|
278 |
+
|
279 |
+
# Weird workaround to run JS function on state change, from https://github.com/gradio-app/gradio/issues/3525#issuecomment-2348596861
|
280 |
+
frontState.change(
|
281 |
+
inputs=[frontState],
|
282 |
+
js='(state) => updateViewers(state)',
|
283 |
+
fn=lambda state: None,
|
284 |
+
).then(
|
285 |
+
inputs=None,
|
286 |
+
fn=lambda: tuple(gr.update(interactive=True) for _ in sum(ratingButtons.values(), [])),
|
287 |
+
outputs= sum(ratingButtons.values(), [])
|
288 |
+
)
|
289 |
+
|
290 |
+
leaderboard_tab.select(
|
291 |
+
inputs=None,
|
292 |
+
js='() => resetPlots()',
|
293 |
+
fn=None,
|
294 |
+
).then(
|
295 |
+
fn=lambda: [gr.update(value=None) for _ in range(3)],
|
296 |
+
outputs=[elo_plot, wr_plot, matches_plot]
|
297 |
+
).then(
|
298 |
+
inputs=[sessionState],
|
299 |
+
fn=lambda state: plot_ratings(state['currentTab']),
|
300 |
+
outputs=[elo_plot, wr_plot, matches_plot]
|
301 |
+
)
|
302 |
+
|
303 |
+
async def process_rating(state, i, criteria):
    """
    Submit one vote and return the refreshed leaderboard for its criteria.

    Args:
        state (dict): Session state holding 'modelLeft', 'modelRight' and 'uuid'.
        i (int): Index of the clicked button: 0 picks the left model,
            2 picks the right model, anything else is a skip.
        criteria (str): The criteria the vote applies to.

    Returns:
        A component update carrying the value returned by submit_rating.
    """
    if i == 0:
        winner, loser = state['modelLeft'], state['modelRight']
    elif i == 2:
        winner, loser = state['modelRight'], state['modelLeft']
    else:
        # Skip: no winner and no loser are reported.
        winner = loser = None

    result = await submit_rating(
        criteria=criteria,
        winner=winner,
        loser=loser,
        uuid=state['uuid'],
    )
    return gr.update(value=result)
|
310 |
+
|
311 |
+
def update_tab(state, criteria):
    """
    Remember which leaderboard sub-tab is selected.

    Args:
        state (dict): Session state; its 'currentTab' entry is overwritten.
        criteria (str): The criteria of the selected tab.

    Returns:
        dict: The same state object.
    """
    state.update(currentTab=criteria)
    return state
|
314 |
+
|
315 |
+
for criteria in CRITERIA:
|
316 |
+
for i, button in enumerate(ratingButtons[criteria]):
|
317 |
+
button.click(
|
318 |
+
# fn=lambda i=i, criteria=criteria: gr.Info(f'{"You chose Left Model for " if i == 0 else "You chose Right Model for " if i == 2 else "You skipped "} {criteria.replace("_", " ")}!'),
|
319 |
+
# ).then(
|
320 |
+
fn=lambda: tuple(gr.update(interactive=False) for _ in range(len(ratingButtons[criteria]))),
|
321 |
+
outputs=ratingButtons[criteria]
|
322 |
+
).then(
|
323 |
+
inputs=[sessionState, gr.State(i), gr.State(criteria)],
|
324 |
+
fn=process_rating,
|
325 |
+
outputs=[leaderboards[criteria]],
|
326 |
+
)
|
327 |
+
|
328 |
+
tabs[criteria].select(
|
329 |
+
fn=lambda: [gr.update(value=None) for _ in range(3)],
|
330 |
+
outputs=[elo_plot, wr_plot, matches_plot]
|
331 |
+
).then(
|
332 |
+
inputs=[gr.State(criteria)],
|
333 |
+
fn=plot_ratings,
|
334 |
+
outputs=[elo_plot, wr_plot, matches_plot]
|
335 |
+
).then(
|
336 |
+
inputs=[sessionState, gr.State(criteria)],
|
337 |
+
fn=update_tab,
|
338 |
+
outputs=[sessionState]
|
339 |
+
)
|
340 |
+
|
341 |
+
|
342 |
+
if MODE == 'testing':
|
343 |
+
for criteria in CRITERIA:
|
344 |
+
simulate_btn.click(
|
345 |
+
inputs=[iterate, beta, gr.State(criteria)],
|
346 |
+
fn=simulate,
|
347 |
+
outputs=[leaderboards[criteria]],
|
348 |
+
).then(fn=lambda: [gr.update(value=None) for _ in range(3)],
|
349 |
+
outputs=[elo_plot, wr_plot, matches_plot]
|
350 |
+
).then(
|
351 |
+
inputs=[gr.State(criteria)],
|
352 |
+
fn=plot_ratings,
|
353 |
+
outputs=[elo_plot, wr_plot, matches_plot]
|
354 |
+
)
|
355 |
+
|
356 |
+
add_model_btn.click(
|
357 |
+
fn=lambda: MODELS.append(f'model_{len(MODELS)}'),
|
358 |
+
)
|
359 |
+
|
360 |
+
if __name__ == '__main__':
|
361 |
+
gr.set_static_paths(['static'])
|
362 |
+
arena.queue(default_concurrency_limit=50).launch(inbrowser=True, allowed_paths=['static/'])
|
config.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Deployment mode: 'testing' or 'production'.
# app.py only creates its plot components in those two branches, so an unset
# or unrecognised ARENA_MODE falls back to 'production' instead of leaving
# the components undefined.
MODE = os.getenv('ARENA_MODE', 'production')
if MODE not in ('testing', 'production'):
    MODE = 'production'
|
4 |
+
|
5 |
+
default_beta = 500
|
6 |
+
|
7 |
+
MODELS = [
|
8 |
+
'4DHumans',
|
9 |
+
'CLIFF',
|
10 |
+
'GVHMR',
|
11 |
+
'HybrIK',
|
12 |
+
'WHAM',
|
13 |
+
'TokenHMR',
|
14 |
+
'STAF',
|
15 |
+
'CameraHMR'
|
16 |
+
]
|
17 |
+
|
18 |
+
VIDEOS = [
|
19 |
+
'backflip',
|
20 |
+
'ballet_dance',
|
21 |
+
'ballet_jump',
|
22 |
+
'basketball_dunk',
|
23 |
+
'boxing',
|
24 |
+
'breakdance',
|
25 |
+
'bridge',
|
26 |
+
'capoeira',
|
27 |
+
'contorsionist',
|
28 |
+
'dance_feathers',
|
29 |
+
'dance_modern1',
|
30 |
+
'dance_modern2',
|
31 |
+
'dance_modern3',
|
32 |
+
'dance_road',
|
33 |
+
'dance_tiktok',
|
34 |
+
'highkick',
|
35 |
+
'hiphop',
|
36 |
+
'parkour',
|
37 |
+
'pillars',
|
38 |
+
'skateboard',
|
39 |
+
'spinkick',
|
40 |
+
'trampoline',
|
41 |
+
'wall_jump',
|
42 |
+
'yoga',
|
43 |
+
'cartoon_fall',
|
44 |
+
'cutting_tree',
|
45 |
+
'fencing',
|
46 |
+
'ferocity',
|
47 |
+
'ice_skating',
|
48 |
+
'moonwalk',
|
49 |
+
'npc',
|
50 |
+
'running',
|
51 |
+
'sitting',
|
52 |
+
'stretching',
|
53 |
+
'tennis',
|
54 |
+
]
|
55 |
+
|
56 |
+
CRITERIA = [
|
57 |
+
'Global_Appreciation',
|
58 |
+
'Ground_Contacts',
|
59 |
+
'Fidelity',
|
60 |
+
'Fluidity',
|
61 |
+
]
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas
|
2 |
+
plotly
|
3 |
+
numpy
|
4 |
+
gradio==5.16.0
|
5 |
+
huggingface_hub==0.28.1
|
6 |
+
aioboto3==14.1.0
|
7 |
+
pydantic==2.10.3
|
static/modelViewer.js
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
 * Load the two models of the current matchup into the viewer iframes and,
 * once both report "modelLoaded", hook the video events up to the viewers.
 *
 * @param {{video: string, modelLeft: string, modelRight: string}} state
 *        Session state holding the video name and the two model names.
 */
function updateViewers(state) {
    const { video, modelLeft, modelRight } = state;
    const modelsBaseUrl = "https://gradio-model-viewer.s3.eu-west-1.amazonaws.com/models";

    const gradioVideo = document.getElementById("gradioVideo");
    const videoComponent = gradioVideo ? gradioVideo.querySelector("video") : null;
    const viewerLeft = document.getElementById("modelViewerLeft");
    const viewerRight = document.getElementById("modelViewerRight");

    if (!videoComponent || !viewerLeft || !viewerRight) return;

    // Detach the listener left behind by a previous call so that "message"
    // listeners do not pile up on window, one per matchup.
    if (updateViewers.messageHandler) {
        window.removeEventListener("message", updateViewers.messageHandler);
        updateViewers.messageHandler = null;
    }

    // The `muted` content attribute only sets the default state; the
    // property is what mutes an element that already exists.
    videoComponent.muted = true;

    const enableThumbnails = () => {
        Array.from(document.getElementsByClassName('thumbnail-btn')).forEach(btn => btn.disabled = false);
    };

    let loadedCount = 0;

    const onMessage = (event) => {
        // Messages without a payload cannot carry a status.
        const status = event.data ? event.data.status : undefined;

        if (status === "modelLoaded") {
            loadedCount++;
            if (loadedCount === 2) {
                videoComponent.addEventListener("play", syncModelViewers);
                videoComponent.addEventListener("pause", syncModelViewers);
                videoComponent.addEventListener("timeupdate", syncModelViewers);

                enableThumbnails();

                // Both viewers are ready: this listener has done its job.
                window.removeEventListener("message", onMessage);
                if (updateViewers.messageHandler === onMessage) {
                    updateViewers.messageHandler = null;
                }
            }
        }
        else if (status === "modelLoadError") {
            enableThumbnails();
        }
    };

    // Register before posting so no reply can be missed.
    updateViewers.messageHandler = onMessage;
    window.addEventListener("message", onMessage);

    viewerLeft.contentWindow.postMessage({ action: "loadModel", modelUrl: `${modelsBaseUrl}/${modelLeft}/${video}.glb` }, "*");
    viewerRight.contentWindow.postMessage({ action: "loadModel", modelUrl: `${modelsBaseUrl}/${modelRight}/${video}.glb` }, "*");
}
|
35 |
+
|
36 |
+
/**
 * Mirror a video event (play, pause or timeupdate) onto both viewer iframes.
 *
 * @param {Event} event Event fired by the video element; any other event
 *        type is ignored.
 */
function syncModelViewers(event) {
    const left = document.getElementById("modelViewerLeft");
    const right = document.getElementById("modelViewerRight");

    if (!(left && right)) return;

    let message;
    if (event.type === "play") {
        message = { action: "playAnimation" };
    } else if (event.type === "pause") {
        message = { action: "pauseAnimation" };
    } else if (event.type === "timeupdate") {
        // The viewers are told the video's current playback position.
        message = { action: "setAnimationTime", currentTime: event.target.currentTime };
    } else {
        return;
    }

    left.contentWindow.postMessage(message, "*");
    right.contentWindow.postMessage(message, "*");
}
|
static/plots.js
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Fix plot display issues with gradio Tabs
|
2 |
+
/**
 * Ask Plotly to re-run autosizing on every plot currently in the document.
 */
function resetPlots() {
    for (const plot of document.querySelectorAll(".js-plotly-plot")) {
        Plotly.relayout(plot, { autosize: true });
    }
}
|
static/popup.js
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
 * Show or hide the popup attached to a popup button.
 *
 * @param {Event} event Event whose target is the popup button; the button's
 *        `data-popup-id` attribute names the popup element.
 * @param {boolean} [show] true to show, false to hide. When omitted the
 *        popup's visibility is flipped.
 */
function togglePopup(event, show) {
    const popupBtn = event.target;
    const popup = document.getElementById(popupBtn.getAttribute('data-popup-id'));

    // No popup is attached to this button: nothing to show or hide.
    if (!popup) return;

    // Let the popup overflow its fifth ancestor (presumably a wrapper that
    // clips its content - confirm against the rendered DOM). Stop early if
    // the chain is shorter than expected instead of throwing.
    let container = popup;
    for (let depth = 0; depth < 5 && container; depth++) {
        container = container.parentElement;
    }
    if (container) {
        container.style.overflow = 'visible';
    }

    // Honour the requested state rather than blindly flipping it, so the
    // popup cannot get out of step with the hover state.
    popup.classList.toggle("show", show);
    if (popupBtn.parentElement) {
        popupBtn.parentElement.style.overflow = 'visible';
    }

    event.stopPropagation();
}
|
13 |
+
|
14 |
+
// Delegate hover handling to the document: "mouseover" on a popup button
// calls togglePopup with show=true, "mouseout" calls it with show=false.
[["mouseover", true], ["mouseout", false]].forEach(([eventName, show]) => {
    document.addEventListener(eventName, (event) => {
        if (event.target.classList.contains('popup-btn')) {
            togglePopup(event, show);
        }
    });
});
|
static/style.css
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.viewer-container {
|
2 |
+
display: flex;
|
3 |
+
gap: 20px;
|
4 |
+
justify-content: center;
|
5 |
+
}
|
6 |
+
model-viewer {
|
7 |
+
width: 50%;
|
8 |
+
height: 500px;
|
9 |
+
background-color: #eeeeee;
|
10 |
+
}
|
11 |
+
iframe {
|
12 |
+
width: 50%;
|
13 |
+
height: 500px;
|
14 |
+
background-color: #eeeeee;
|
15 |
+
border: none;
|
16 |
+
border-radius: 5px;
|
17 |
+
}
|
18 |
+
|
19 |
+
.js-plotly-plot {
|
20 |
+
width: 100%;
|
21 |
+
display: flex;
|
22 |
+
justify-content: center;
|
23 |
+
}
|
24 |
+
|
25 |
+
@font-face {
|
26 |
+
font-family: 'Glyphicons Halflings';
|
27 |
+
|
28 |
+
src: url('https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/fonts/glyphicons-halflings-regular.eot');
|
29 |
+
src: url('https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/fonts/glyphicons-halflings-regular.eot?#iefix') format('embedded-opentype'), url('https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/fonts/glyphicons-halflings-regular.woff2') format('woff2'), url('https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/fonts/glyphicons-halflings-regular.woff') format('woff'), url('https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/fonts/glyphicons-halflings-regular.ttf') format('truetype'), url('https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/fonts/glyphicons-halflings-regular.svg#glyphicons_halflingsregular') format('svg');
|
30 |
+
}
|
31 |
+
.glyphicon {
|
32 |
+
position: relative;
|
33 |
+
top: 2px;
|
34 |
+
display: inline-block;
|
35 |
+
font-family: 'Glyphicons Halflings';
|
36 |
+
font-style: normal;
|
37 |
+
font-weight: normal;
|
38 |
+
line-height: 1;
|
39 |
+
|
40 |
+
-webkit-font-smoothing: antialiased;
|
41 |
+
-moz-osx-font-smoothing: grayscale;
|
42 |
+
}
|
43 |
+
.glyphicon-question-sign:before {
|
44 |
+
content: "\e085";
|
45 |
+
}
|
46 |
+
|
47 |
+
.popup-btn {
|
48 |
+
position: relative;
|
49 |
+
display: inline-block;
|
50 |
+
cursor: pointer;
|
51 |
+
-webkit-user-select: none;
|
52 |
+
-moz-user-select: none;
|
53 |
+
-ms-user-select: none;
|
54 |
+
user-select: none;
|
55 |
+
overflow: visible !important;
|
56 |
+
font-size: 20px;
|
57 |
+
border: none;
|
58 |
+
text-align: center;
|
59 |
+
text-decoration: none;
|
60 |
+
padding-left: 10px;
|
61 |
+
}
|
62 |
+
|
63 |
+
.popup-text {
|
64 |
+
visibility: hidden;
|
65 |
+
width: 300px;
|
66 |
+
background-color: #e4e4e7;
|
67 |
+
color: #000;
|
68 |
+
text-align: center;
|
69 |
+
border-radius: 10px;
|
70 |
+
padding: 20px 10px;
|
71 |
+
position: absolute;
|
72 |
+
z-index: 1;
|
73 |
+
bottom: 150%;
|
74 |
+
left: 50%;
|
75 |
+
margin-left: -145px;
|
76 |
+
overflow: visible !important;
|
77 |
+
font-size: 17px;
|
78 |
+
font-family: Arial, sans-serif;
|
79 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
80 |
+
transition: visibility 0.3s, opacity 0.3s ease-in-out;
|
81 |
+
opacity: 0;
|
82 |
+
}
|
83 |
+
|
84 |
+
.popup-text.dark {
|
85 |
+
background-color: #52525b;
|
86 |
+
color: #fff;
|
87 |
+
}
|
88 |
+
|
89 |
+
/* Popup arrow */
|
90 |
+
.popup-text::after {
|
91 |
+
content: "";
|
92 |
+
position: absolute;
|
93 |
+
top: 100%;
|
94 |
+
left: 50%;
|
95 |
+
margin-left: -5px;
|
96 |
+
border-width: 5px;
|
97 |
+
border-style: solid;
|
98 |
+
border-color: #e4e4e7 transparent transparent transparent;
|
99 |
+
}
|
100 |
+
|
101 |
+
.popup-text.dark::after {
|
102 |
+
border-color: #52525b transparent transparent transparent;
|
103 |
+
}
|
104 |
+
|
105 |
+
.popup-text.show {
|
106 |
+
visibility: visible;
|
107 |
+
opacity: 1;
|
108 |
+
}
|
109 |
+
|
110 |
+
.thumbnail-btn {
|
111 |
+
height:75px;
|
112 |
+
width: auto;
|
113 |
+
overflow: hidden;
|
114 |
+
display: flex;
|
115 |
+
}
|
116 |
+
.thumbnail-btn:hover {
|
117 |
+
cursor: pointer;
|
118 |
+
transform: scale(1.05);
|
119 |
+
transition: transform 0.5s;
|
120 |
+
}
|
121 |
+
|
122 |
+
.thumbnail {
|
123 |
+
width: auto;
|
124 |
+
height: 100%;
|
125 |
+
object-fit: cover;
|
126 |
+
border-radius: 5px;
|
127 |
+
}
|
128 |
+
|
129 |
+
.gallery {
|
130 |
+
display: flex;
|
131 |
+
flex-wrap: wrap;
|
132 |
+
gap: 10px;
|
133 |
+
justify-content: flex-start;
|
134 |
+
align-items: flex-start;
|
135 |
+
}
|
utils/__pycache__/data_utils.cpython-310.pyc
ADDED
Binary file (1.68 kB). View file
|
|
utils/__pycache__/emoji_utils.cpython-310.pyc
ADDED
Binary file (495 Bytes). View file
|
|
utils/__pycache__/plot_utils.cpython-310.pyc
ADDED
Binary file (1.49 kB). View file
|
|
utils/__pycache__/s3_utils.cpython-310.pyc
ADDED
Binary file (4.37 kB). View file
|
|
utils/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.69 kB). View file
|
|
utils/data_utils.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .s3_utils import read_from_s3
|
2 |
+
from config import MODELS
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
async def generate_leaderboard(criteria : str) -> pd.DataFrame:
    """
    Generate the leaderboard from saved data.

    The saved leaderboard is read from 'leaderboard_<criteria>.csv'. When it
    cannot be read, an empty leaderboard is used instead. Every model of
    MODELS missing from it is then added with default statistics.

    Args:
        criteria (str): The criteria corresponding to the leaderboard.

    Returns:
        pd.DataFrame: The leaderboard, sorted by decreasing Elo, with the
            win rate rounded to two decimals.
    """
    try:
        leaderboard = await read_from_s3(f'leaderboard_{criteria}.csv')
    except Exception:
        # Best effort: an unreadable leaderboard is treated as a missing one.
        # Only Exception is caught so task cancellation and interpreter
        # shutdown still propagate.
        leaderboard = None

    if leaderboard is None:
        leaderboard = pd.DataFrame({
            'Model': pd.Series(dtype='str'),
            'Elo': pd.Series(dtype='int'),
            'Wins': pd.Series(dtype='int'),
            'Matches': pd.Series(dtype='int'),
            'Win Rate': pd.Series(dtype='float')
        })

    # Collect the rows of all missing models and append them in one concat
    # instead of rebuilding the frame once per model.
    known = set(leaderboard['Model'].values)
    new_rows = []
    for model in MODELS:
        if model not in known:
            known.add(model)
            new_rows.append({'Model': model, 'Elo': 1500, 'Wins': 0, 'Matches': 0, 'Win Rate': 0.0})
    if new_rows:
        leaderboard = pd.concat([leaderboard, pd.DataFrame(new_rows)], ignore_index=True)

    leaderboard = leaderboard.sort_values('Elo', ascending=False).reset_index(drop=True)
    leaderboard['Win Rate'] = leaderboard['Win Rate'].apply(lambda x: round(x, 2))

    return leaderboard
|
35 |
+
|
36 |
+
async def generate_data() -> pd.DataFrame:
    """
    Generate the data for the matches.

    The saved data is read from 'data.csv'. When it cannot be read, an empty
    frame with the expected columns is returned instead.

    Returns:
        pd.DataFrame: The data for the matches.
    """
    try:
        data = await read_from_s3('data.csv')
    except Exception:
        # Best effort: unreadable data is treated as missing data. Only
        # Exception is caught so task cancellation and interpreter shutdown
        # still propagate.
        data = None

    if data is None:
        data = pd.DataFrame(columns=['Criteria', 'Model', 'Opponent', 'Won', 'Elo', 'Win Rate', 'Matches', 'Timestamp', 'UUID'])
    return data
|
utils/emoji_utils.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def add_emoji(model : str) -> str:
    """
    Prefix a known model name with its emoji.

    Args:
        model (str): The name of the model.

    Returns:
        str: The emoji-decorated name, or the name unchanged if the model is unknown.
    """
    decorated_names = (
        ('4DHumans', '\U0001F3A5 4DHumans'),
        ('CLIFF', '\U0001F3D4 CLIFF'),
        ('GVHMR', '\U0001F4CD GVHMR'),
        ('HybrIK', '\U0001F9BE HybrIK'),
        ('WHAM', '\U0001F3AF WHAM'),
    )
    for known_name, decorated in decorated_names:
        if model == known_name:
            return decorated
    return model
|
utils/graphs.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils/plot_utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.express as px
|
2 |
+
from .data_utils import generate_data
|
3 |
+
|
4 |
+
async def plot_ratings(criteria : str):
    """
    Plot different ratings of the models for a given criteria on line graphs.

    Args:
        criteria (str): The criteria to plot the ratings for.

    Returns:
        tuple: The Elo figure, the win rate figure and the matches figure, in that order.
    """
    data = await generate_data()

    # The three figures share the same rows and x axis, so the filtered frame is
    # built once instead of once per figure.
    votes = data[data['Criteria'] == criteria].reset_index(drop=True)
    votes.reset_index(inplace=True)
    # Two consecutive rows share one x value, so the x axis counts pairs of rows.
    votes['index'] = votes['index'] // 2

    elo_fig = px.line(votes, x='index', y='Elo', color='Model', title='Elo Ratings Over Total Votes', labels={'index': 'Total Votes', 'Elo': 'Elo Rating'})
    wr_fig = px.line(votes, x='index', y='Win Rate', color='Model', title='Win Rates Over Total Votes', labels={'index': 'Total Votes', 'Win Rate': 'Win Rate'})
    matches_fig = px.line(votes, x='index', y='Matches', color='Model', title='Matches Played Over Total Votes', labels={'index': 'Total Votes', 'Matches': 'Matches Played'})

    return elo_fig, wr_fig, matches_fig
|
utils/s3_utils.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import aioboto3
|
2 |
+
import pandas as pd
|
3 |
+
from io import StringIO
|
4 |
+
from typing import Optional, Union
|
5 |
+
import os
|
6 |
+
|
7 |
+
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
8 |
+
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
|
9 |
+
S3_BUCKET = os.getenv('S3_BUCKET')
|
10 |
+
S3_VIDEO_PATH = 'sample videos'
|
11 |
+
S3_MODEL_PATH = 'models'
|
12 |
+
S3_DATA_PATH = '3d_animation_arena/results'
|
13 |
+
|
14 |
+
async def download_from_s3(file_key : str, target_dir: str, bucket : str = S3_BUCKET) -> Optional[str]:
|
15 |
+
"""
|
16 |
+
Downloads a file from an S3 bucket.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
file_key (str): The key of the file in the S3 bucket, including extension.
|
20 |
+
target_dir (str): The path to the directory to save the downloaded file.
|
21 |
+
bucket (str, optional): The name of the S3 bucket.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
Optional[str]: The path to the file or None if the download fails.
|
25 |
+
"""
|
26 |
+
session = aioboto3.Session()
|
27 |
+
target_path = os.path.join(target_dir, file_key)
|
28 |
+
|
29 |
+
async with session.client(
|
30 |
+
's3',
|
31 |
+
aws_access_key_id=AWS_ACCESS_KEY_ID,
|
32 |
+
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
|
33 |
+
) as s3_client:
|
34 |
+
try:
|
35 |
+
os.makedirs(target_dir, exist_ok=True)
|
36 |
+
if os.path.exists(target_path):
|
37 |
+
print(f'{file_key} already exists in {target_dir}')
|
38 |
+
return target_path
|
39 |
+
with open(target_path, 'wb') as f:
|
40 |
+
match file_key.split('.')[-1]:
|
41 |
+
case 'mp4':
|
42 |
+
await s3_client.download_fileobj(bucket, os.path.join(S3_VIDEO_PATH, file_key), f)
|
43 |
+
case 'glb'|'obj'|'stl'|'gltf'|'splat'|'ply':
|
44 |
+
await s3_client.download_fileobj(bucket, os.path.join(S3_MODEL_PATH, file_key), f)
|
45 |
+
case _:
|
46 |
+
print(f"Unsupported file type: {file_key}")
|
47 |
+
raise ValueError(f"Unsupported file type: {file_key}")
|
48 |
+
return target_path
|
49 |
+
except Exception as e:
|
50 |
+
print(f'Error downloading {file_key} from bucket {bucket}: {e}')
|
51 |
+
raise e
|
52 |
+
|
53 |
+
|
54 |
+
async def read_from_s3(file_key : str, bucket : str = S3_BUCKET) -> Optional[Union[pd.DataFrame, str]]:
|
55 |
+
"""
|
56 |
+
Reads a file from an S3 bucket based on its file extension and returns the appropriate data type.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
file_key (str): The key of the file in the S3 bucket.
|
60 |
+
bucket (str, optional): The name of the S3 bucket.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Optional[Union[pd.DataFrame, str]]:
|
64 |
+
- A pandas DataFrame if the file is a CSV.
|
65 |
+
- A temporary file path (str) if the file is a GLB.
|
66 |
+
- A presigned URL (str) if the file is an MP4.
|
67 |
+
- None if the file type is unsupported.
|
68 |
+
"""
|
69 |
+
session = aioboto3.Session()
|
70 |
+
async with session.client(
|
71 |
+
's3',
|
72 |
+
aws_access_key_id=AWS_ACCESS_KEY_ID,
|
73 |
+
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
|
74 |
+
) as s3_client:
|
75 |
+
try:
|
76 |
+
match file_key.split('.')[-1]:
|
77 |
+
case 'csv':
|
78 |
+
response = await s3_client.get_object(Bucket=bucket, Key=os.path.join(S3_DATA_PATH, file_key))
|
79 |
+
content = await response['Body'].read()
|
80 |
+
result = pd.read_csv(StringIO(content.decode("utf-8")))
|
81 |
+
return result
|
82 |
+
case _:
|
83 |
+
print(f"Unsupported file type for reading: {file_key}")
|
84 |
+
raise ValueError(f"Unsupported file type for reading: {file_key}")
|
85 |
+
except Exception as e:
|
86 |
+
print(f'Error reading {file_key} from bucket {bucket}: {e}')
|
87 |
+
raise e
|
88 |
+
|
89 |
+
|
90 |
+
async def write_to_s3(file_key : str, dataframe: pd.DataFrame, bucket : str = S3_BUCKET) -> None:
|
91 |
+
"""
|
92 |
+
Writes a pandas DataFrame to an S3 bucket as a CSV file.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
file_key (str): The key (file name) under which the file will be stored in the S3 bucket.
|
96 |
+
dataframe (pd.DataFrame): The pandas DataFrame to write to the S3 bucket.
|
97 |
+
bucket (str, optional): The name of the S3 bucket.
|
98 |
+
|
99 |
+
Raises:
|
100 |
+
Exception: Reraises any exception encountered during the write process.
|
101 |
+
"""
|
102 |
+
session = aioboto3.Session()
|
103 |
+
async with session.client(
|
104 |
+
's3',
|
105 |
+
aws_access_key_id=AWS_ACCESS_KEY_ID,
|
106 |
+
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
|
107 |
+
) as s3_client:
|
108 |
+
try:
|
109 |
+
match file_key.split('.')[-1]:
|
110 |
+
case 'csv':
|
111 |
+
csv_buffer = StringIO()
|
112 |
+
dataframe.to_csv(csv_buffer, index=False)
|
113 |
+
await s3_client.put_object(
|
114 |
+
Bucket=bucket,
|
115 |
+
Key=os.path.join(S3_DATA_PATH, file_key),
|
116 |
+
Body=csv_buffer.getvalue()
|
117 |
+
)
|
118 |
+
case _:
|
119 |
+
print(f"Unsupported file type for writing: {file_key}")
|
120 |
+
raise ValueError(f"Unsupported file type for writing: {file_key}")
|
121 |
+
except Exception as e:
|
122 |
+
print(f'Error writing {file_key} to bucket {bucket}: {e}')
|
123 |
+
raise e
|
utils/utils.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
import asyncio
|
6 |
+
from utils.s3_utils import write_to_s3
|
7 |
+
from utils.data_utils import generate_leaderboard, generate_data
|
8 |
+
|
9 |
+
# Held by submit_rating for its whole read-update-write cycle, so that concurrent
# submissions do not interleave their updates of the leaderboard and match data.
submit_lock = asyncio.Lock()
|
10 |
+
|
11 |
+
def update_ratings(R_win : int, R_lose : int, k : int = 32) -> Tuple[int, int]:
    """
    Compute the new ratings of two players once their match is decided.

    Args:
        R_win (int): The rating of the winning player.
        R_lose (int): The rating of the losing player.
        k (int, optional): The k-factor. Defaults to 32.

    Returns:
        Tuple[int, int]: The updated ratings of the winning and losing players,
            truncated to integers.
    """
    def expected_score(own, other):
        # Logistic expectation on a 480-point scale.
        return 1 / (1 + 10 ** ((other - own) / 480))

    winner_delta = k * (1 - expected_score(R_win, R_lose))
    loser_delta = k * (0 - expected_score(R_lose, R_win))
    return int(R_win + winner_delta), int(R_lose + loser_delta)
|
26 |
+
|
27 |
+
def generate_matchup(leaderboard : pd.DataFrame, beta : int) -> tuple[str, str]:
    """
    Generate a pseudo-random matchup between two models.

    Models are drawn without replacement with a probability proportional to
    exp(-Matches / beta), so models with fewer matches are more likely to be
    picked. When no match has been played yet, the draw is uniform.

    Args:
        leaderboard (pd.DataFrame): The leaderboard of models, indexed by model,
            with a 'Matches' column.
        beta (int): The scale of the sampling weights; the smaller it is, the
            more strongly models with few matches are favoured.

    Returns:
        model1 (str): The first model.
        model2 (str): The second model.
    """
    if leaderboard['Matches'].sum() == 0:
        # Unpack so that both branches return a tuple, as annotated.
        model1, model2 = np.random.choice(leaderboard.index, 2, replace=False)
        return model1, model2

    matches = leaderboard['Matches'].to_numpy(dtype=float)
    # Subtracting the minimum leaves the normalized weights mathematically unchanged,
    # but keeps the largest weight at 1: without it, exp() underflows to 0 for every
    # model once match counts get large, and the normalization yields NaN.
    weights = np.exp(-(matches - matches.min()) / beta)
    weights = weights / weights.sum()  # Normalize weights
    selected = np.random.choice(leaderboard.index, 2, replace=False, p=weights)
    np.random.shuffle(selected)
    model1, model2 = selected
    return model1, model2
|
47 |
+
|
48 |
+
async def simulate(iter : int, beta : int, criteria : str) -> pd.DataFrame:
    """
    Simulate matches between random models and store the results.

    For every simulated match, the first model of the generated matchup is
    recorded as the winner. The updated leaderboard and match data are written
    back once all matches have been simulated.

    Args:
        iter (int): The number of matches to simulate.
        beta (int): The parameter forwarded to generate_matchup.
        criteria (str): The criteria for the rating.

    Returns:
        leaderboard (pd.DataFrame): Updated leaderboard after simulation
    """
    data = await generate_data()

    leaderboard = await generate_leaderboard(criteria)
    leaderboard.set_index('Model', inplace=True)

    def match_row(model, opponent, won, played_at):
        # Snapshot of the current statistics of `model`, taken after the match.
        return {
            'Criteria': criteria,
            'Model': model,
            'Opponent': opponent,
            'Won': won,
            'Elo': leaderboard.at[model, 'Elo'],
            'Win Rate': leaderboard.at[model, 'Win Rate'],
            'Matches': leaderboard.at[model, 'Matches'],
            'Timestamp': played_at,
            'UUID': None
        }

    for _ in range(iter):
        # Generate random matchups
        timestamp = time.time()
        victor, defeated = generate_matchup(leaderboard, beta)
        victor_elo, defeated_elo = update_ratings(
            leaderboard.at[victor, 'Elo'],
            leaderboard.at[defeated, 'Elo']
        )

        # Update leaderboard
        leaderboard.at[victor, 'Elo'] = victor_elo
        leaderboard.at[defeated, 'Elo'] = defeated_elo
        leaderboard.at[victor, 'Wins'] += 1
        for contender in (victor, defeated):
            leaderboard.at[contender, 'Matches'] += 1
        for contender in (victor, defeated):
            leaderboard.at[contender, 'Win Rate'] = np.round(
                leaderboard.at[contender, 'Wins'] / leaderboard.at[contender, 'Matches'], 2
            )

        # Save match data: one row per participant
        data.loc[len(data)] = match_row(victor, defeated, True, timestamp)
        data.loc[len(data)] = match_row(defeated, victor, False, timestamp)

    leaderboard = leaderboard.sort_values('Elo', ascending=False).reset_index(drop=False)

    await asyncio.gather(
        write_to_s3(f'leaderboard_{criteria}.csv', leaderboard),
        write_to_s3('data.csv', data)
    )

    return leaderboard
|
113 |
+
|
114 |
+
|
115 |
+
async def submit_rating(criteria : str, winner : str, loser : str, uuid : str) -> pd.DataFrame:
    """
    Submit a rating for a match.

    The leaderboard and the match data are updated and written back while holding
    submit_lock. If either model is None, nothing is recorded and the current
    leaderboard is returned.

    Args:
        criteria (str): The criteria for the rating.
        winner (str): The winning model.
        loser (str): The losing model.
        uuid (str): The UUID of the session.

    Returns:
        pd.DataFrame: The leaderboard sorted by descending Elo, with 'Model' as a column.
    """
    async with submit_lock:
        leaderboard = await generate_leaderboard(criteria)

        if winner is None or loser is None:
            # Nothing to record. Return before re-indexing by 'Model', so this path
            # returns the same shape as the normal path below, and skip loading the
            # match history, which is not needed here.
            return leaderboard

        data = await generate_data()
        leaderboard.set_index('Model', inplace=True)

        timestamp = time.time()
        R_win, R_lose = leaderboard.at[winner, 'Elo'], leaderboard.at[loser, 'Elo']
        R_win_new, R_lose_new = update_ratings(R_win, R_lose)

        # Update leaderboard
        leaderboard.loc[[winner, loser], 'Elo'] = [R_win_new, R_lose_new]
        leaderboard.at[winner, 'Wins'] += 1
        leaderboard.loc[[winner, loser], 'Matches'] += 1
        leaderboard.loc[[winner, loser], 'Win Rate'] = (
            leaderboard.loc[[winner, loser], 'Wins'] / leaderboard.loc[[winner, loser], 'Matches']
        ).apply(lambda x: round(x, 2))

        # Save match data: one row per participant, taken after the update
        for model, opponent, won in ((winner, loser, True), (loser, winner, False)):
            data.loc[len(data)] = {
                'Criteria': criteria,
                'Model': model,
                'Opponent': opponent,
                'Won': won,
                'Elo': leaderboard.at[model, 'Elo'],
                'Win Rate': leaderboard.at[model, 'Win Rate'],
                'Matches': leaderboard.at[model, 'Matches'],
                'Timestamp': timestamp,
                'UUID': uuid
            }

        leaderboard = leaderboard.sort_values('Elo', ascending=False).reset_index(drop=False)
        await asyncio.gather(
            write_to_s3(f'leaderboard_{criteria}.csv', leaderboard),
            write_to_s3('data.csv', data)
        )
        return leaderboard
|