Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -253,7 +253,7 @@ def find_best_songs_for_mood(all_tracks_audio_features, genre_selected_indexes,
|
|
| 253 |
return min_dist_indexes, n_candidates
|
| 254 |
|
| 255 |
@st.cache
|
| 256 |
-
def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris, target_mood):
|
| 257 |
# sample exploration songs
|
| 258 |
if exploration > 0:
|
| 259 |
n_known = int(playlist_length * (1 - exploration))
|
|
@@ -276,19 +276,23 @@ def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_lengt
|
|
| 276 |
dict_args_loose[f'max_{m}'] = min(1, target_mood[i_m] + 0.3)
|
| 277 |
new_songs = []
|
| 278 |
counter_seed = 0
|
|
|
|
| 279 |
while len(new_songs) < n_new:
|
| 280 |
try:
|
| 281 |
print(seed_songs[counter_seed])
|
| 282 |
print(dict_args)
|
| 283 |
-
|
|
|
|
| 284 |
market="from_token", country='from_token', **dict_args)['tracks']
|
| 285 |
if len(reco) == 0:
|
| 286 |
print('Using loose bounds')
|
| 287 |
-
|
|
|
|
| 288 |
market="from_token", country='from_token', **dict_args_loose)['tracks']
|
| 289 |
if len(reco) == 0:
|
| 290 |
print('Using looser bounds')
|
| 291 |
-
|
|
|
|
| 292 |
market="from_token", country='from_token', **dict_args_looser)['tracks']
|
| 293 |
if len(reco) == 0:
|
| 294 |
print('Removing bounds')
|
|
@@ -298,6 +302,7 @@ def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_lengt
|
|
| 298 |
if r['uri'] not in all_tracks_uris and r['uri'] not in new_songs:
|
| 299 |
new_songs.append(r['uri'])
|
| 300 |
break
|
|
|
|
| 301 |
except:
|
| 302 |
pass
|
| 303 |
print(counter_seed, len(new_songs))
|
|
@@ -348,6 +353,7 @@ def run_app():
|
|
| 348 |
if custom_button or 'run_custom' in st.session_state.keys() or debug:
|
| 349 |
st.session_state['run_custom'] = True
|
| 350 |
checkboxes = st.session_state['checkboxes'].copy()
|
|
|
|
| 351 |
init_time = time.time()
|
| 352 |
genre_selected_indexes = filter_songs_by_genre(checkboxes, genres_labels, indexes_by_genre)
|
| 353 |
if len(genre_selected_indexes) < 10:
|
|
@@ -380,7 +386,7 @@ def run_app():
|
|
| 380 |
generation_button = centered_button(st.button, 'Generate playlist', n_columns=5)
|
| 381 |
if generation_button:
|
| 382 |
selected_tracks_uris = run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris,
|
| 383 |
-
target_mood.flatten())
|
| 384 |
print(f'9. run exploration: {time.time() - init_time:.2f}')
|
| 385 |
init_time = time.time()
|
| 386 |
|
|
|
|
| 253 |
return min_dist_indexes, n_candidates
|
| 254 |
|
| 255 |
@st.cache
|
| 256 |
+
def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris, target_mood, selected_genres):
|
| 257 |
# sample exploration songs
|
| 258 |
if exploration > 0:
|
| 259 |
n_known = int(playlist_length * (1 - exploration))
|
|
|
|
| 276 |
dict_args_loose[f'max_{m}'] = min(1, target_mood[i_m] + 0.3)
|
| 277 |
new_songs = []
|
| 278 |
counter_seed = 0
|
| 279 |
+
print(selected_genres)
|
| 280 |
while len(new_songs) < n_new:
|
| 281 |
try:
|
| 282 |
print(seed_songs[counter_seed])
|
| 283 |
print(dict_args)
|
| 284 |
+
np.random.shuffle(selected_genres)
|
| 285 |
+
reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=selected_genres,
|
| 286 |
market="from_token", country='from_token', **dict_args)['tracks']
|
| 287 |
if len(reco) == 0:
|
| 288 |
print('Using loose bounds')
|
| 289 |
+
np.random.shuffle(selected_genres)
|
| 290 |
+
reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=selected_genres,
|
| 291 |
market="from_token", country='from_token', **dict_args_loose)['tracks']
|
| 292 |
if len(reco) == 0:
|
| 293 |
print('Using looser bounds')
|
| 294 |
+
np.random.shuffle(selected_genres)
|
| 295 |
+
reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=selected_genres,
|
| 296 |
market="from_token", country='from_token', **dict_args_looser)['tracks']
|
| 297 |
if len(reco) == 0:
|
| 298 |
print('Removing bounds')
|
|
|
|
| 302 |
if r['uri'] not in all_tracks_uris and r['uri'] not in new_songs:
|
| 303 |
new_songs.append(r['uri'])
|
| 304 |
break
|
| 305 |
+
|
| 306 |
except:
|
| 307 |
pass
|
| 308 |
print(counter_seed, len(new_songs))
|
|
|
|
| 353 |
if custom_button or 'run_custom' in st.session_state.keys() or debug:
|
| 354 |
st.session_state['run_custom'] = True
|
| 355 |
checkboxes = st.session_state['checkboxes'].copy()
|
| 356 |
+
selected_genres = [genres_labels[i] for i in range(len(genres_labels)) if checkboxes[i] and genres_labels[i] != 'unknown']
|
| 357 |
init_time = time.time()
|
| 358 |
genre_selected_indexes = filter_songs_by_genre(checkboxes, genres_labels, indexes_by_genre)
|
| 359 |
if len(genre_selected_indexes) < 10:
|
|
|
|
| 386 |
generation_button = centered_button(st.button, 'Generate playlist', n_columns=5)
|
| 387 |
if generation_button:
|
| 388 |
selected_tracks_uris = run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris,
|
| 389 |
+
target_mood.flatten(), selected_genres)
|
| 390 |
print(f'9. run exploration: {time.time() - init_time:.2f}')
|
| 391 |
init_time = time.time()
|
| 392 |
|