Spaces:
Build error
Build error
| import gradio as gr | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| from uuid import uuid4 | |
| import json | |
| import time | |
| import os | |
| from huggingface_hub import CommitScheduler | |
| from functools import partial | |
| import pandas as pd | |
| import numpy as np | |
| from huggingface_hub import snapshot_download | |
| import librosa | |
| import random | |
def enable_buttons_side_by_side():
    """Return six ``gr.update`` objects that each show and enable a component."""
    updates = []
    for _ in range(6):
        updates.append(gr.update(visible=True, interactive=True))
    return tuple(updates)
def disable_buttons_side_by_side():
    """Return six non-interactive ``gr.update`` objects.

    The first four updates hide their component; the last two (positions
    4 and 5) keep it visible.
    """
    updates = []
    for position in range(6):
        stays_visible = position >= 4
        updates.append(gr.update(visible=stays_visible, interactive=False))
    return tuple(updates)
# Local folder for the vote/flag logs. It is created up front because the
# logging helpers below open files inside it in append mode.
os.makedirs('data', exist_ok = True)
# Log file names are computed once, at import time. The two names come from
# separate datetime.now() calls, so their timestamps can differ slightly.
LOG_FILENAME = os.path.join('data', f'log_{datetime.now().isoformat()}.json')
FLAG_FILENAME = os.path.join('data', f'flagged_{datetime.now().isoformat()}.json')
# Reusable gr.update presets for toggling component state from callbacks.
enable_btn = gr.update(interactive=True, visible=True)
disable_btn = gr.update(interactive=False)
invisible_btn = gr.update(interactive=False, visible=False)
no_change_btn = gr.update(value="No Change", interactive=True, visible=True)
# Configuration comes from environment variables; os.getenv yields None for
# any variable that is not set.
DS_ID = os.getenv('DS_ID')  # dataset repo that receives the logs
TOKEN = os.getenv('TOKEN')  # token passed to both hub calls below
SONG_SOURCE = os.getenv("SONG_SOURCE")  # dataset repo that is downloaded below
LOCAL_DIR = './'
# Import-time side effect: download the SONG_SOURCE dataset repo into LOCAL_DIR.
snapshot_download(repo_id=SONG_SOURCE, repo_type="dataset", token = TOKEN, local_dir = LOCAL_DIR)
# Import-time side effect: create the scheduler that commits the local 'data'
# folder (the directory of LOG_FILENAME) to 'data' in the DS_ID dataset repo.
# NOTE(review): the unit of `every` is defined by CommitScheduler (minutes,
# per the huggingface_hub documentation) - confirm the intended interval.
scheduler = CommitScheduler(
    repo_id= DS_ID,
    repo_type="dataset",
    folder_path= os.path.dirname(LOG_FILENAME),
    path_in_repo="data",
    token = TOKEN,
    every = 10,
)
# Song metadata table, read from data.csv inside LOCAL_DIR.
df = pd.read_csv(os.path.join(LOCAL_DIR,'data.csv'))
# One path per row, in row order: <LOCAL_DIR>/songs/<filename>.mp3
filenames = list(os.path.join(LOCAL_DIR, 'songs') + '/' + df.filename + '.mp3')
# `indices` is the working pool that songs are drawn from; `main_indices`
# keeps a separate copy of the full set so the pool can be refilled.
indices = list(df.index)
main_indices = indices.copy()
def init_indices():
    """Refill the global working pool ``indices`` from ``main_indices``.

    The pool receives its own copy of the master list. Binding both names to
    the same list object would let the in-place shuffle done when picking a
    song (or any future in-place removal) reorder or shrink the master list.
    """
    global indices
    # Copy rather than alias, so ``main_indices`` stays untouched.
    indices = list(main_indices)
def pick_and_remove_one():
    """Draw one index at random from the global pool and remove it.

    When the pool is empty it is first refilled through ``init_indices``.
    The pool is shuffled in place and its first element is taken; the
    remainder becomes the new pool (as a new list object).

    Returns:
        The selected index.
    """
    global indices
    if not len(indices):
        init_indices()
    np.random.shuffle(indices)
    chosen = indices[0]
    # Slicing builds a new list, so the list that was shuffled keeps its length.
    remaining = indices[1:]
    indices = remaining
    print("Indices : ", chosen)
    return chosen
def vote_last_response(state, vote_type, request: gr.Request):
    """Append one vote record as a JSON line to ``LOG_FILENAME``.

    The file is opened and written while ``scheduler.lock`` is held.

    Args:
        state: Object whose ``dict()`` method supplies the logged state.
        vote_type: Value stored under the ``"type"`` key.
        request: Request passed to ``get_ip`` to fill the ``"ip"`` key.
    """
    with scheduler.lock, open(LOG_FILENAME, "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        line = json.dumps(record)
        log_file.write(line + "\n")
def flag_last_response(state, vote_type, request: gr.Request):
    """Append one flag record as a JSON line to ``FLAG_FILENAME``.

    The file is opened and written while ``scheduler.lock`` is held.

    Args:
        state: Object whose ``dict()`` method supplies the logged state.
        vote_type: Value stored under the ``"type"`` key.
        request: Request passed to ``get_ip`` to fill the ``"ip"`` key.
    """
    with scheduler.lock, open(FLAG_FILENAME, "a") as flag_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        line = json.dumps(record)
        flag_file.write(line + "\n")
class AudioStateIG:
    """Per-song state: a metadata row plus a fresh random conversation id."""

    def __init__(self, row):
        """Store ``row`` and generate a new hex ``conv_id``."""
        self.conv_id = uuid4().hex
        self.row = row
        # Stays None until update_duration() overrides the row's duration.
        self.new_duration = None

    def dict(self):
        """Return the state as a plain dict, in a fixed key order.

        ``"duration"`` is the row's duration unless ``update_duration`` has
        stored a replacement; ``"song_id"`` is the row's ``id`` as a string.
        """
        row = self.row
        summary = {
            "conv_id": self.conv_id,
            "label": row.label,
            "filename": row.filename,
        }
        if self.new_duration is None:
            summary["duration"] = row.duration
        else:
            summary["duration"] = self.new_duration
        summary["song_id"] = str(row.id)
        summary["source"] = row.source
        summary["algorithm"] = row.algorithm
        return summary

    def update_duration(self, duration):
        """Record ``duration`` as the value reported by ``dict()``."""
        self.new_duration = duration
def get_ip(request: gr.Request):
    """Return the client address for ``request``, or ``None`` without a request.

    A non-empty ``cf-connecting-ip`` header takes precedence; otherwise the
    address is ``request.client.host``.
    """
    if not request:
        return None
    headers = request.headers
    if "cf-connecting-ip" not in headers:
        return request.client.host
    # An empty header value falls back to the client host.
    return headers["cf-connecting-ip"] or request.client.host
def get_song(idx, df = df, filenames = filenames):
    """Look up the song at ``idx`` and wrap its row in a new ``AudioStateIG``.

    Also prints the label counts of the rows still in the global ``indices``
    pool.

    Args:
        idx: Row label for ``df.loc``, also used as a position in ``filenames``.
        df: Metadata table (defaults to the module-level frame).
        filenames: Audio paths (defaults to the module-level list).

    Returns:
        ``(state, audio_path)``.
    """
    song_row = df.loc[idx]
    song_path = filenames[idx]
    song_state = AudioStateIG(song_row)
    remaining_labels = df.loc[indices].label
    print(remaining_labels.value_counts())
    return song_state, song_path
def random_cut_length(audio_data, max_length, sample_rate, options=(125, 55, 25)):
    """Cut a randomly placed excerpt whose length is drawn from ``options``.

    Only lengths strictly shorter than ``max_length`` are eligible. If none
    is, the input is returned unchanged.

    Args:
        audio_data: Sliceable sequence of samples.
        max_length: Duration of ``audio_data`` in seconds (int or float).
        sample_rate: Samples per second, used to turn seconds into offsets.
        options: Candidate excerpt lengths in seconds.

    Returns:
        ``(excerpt, excerpt_length)``, or ``(audio_data, max_length)`` when
        no candidate length fits.
    """
    candidates = [length for length in options if max_length > length]
    if not candidates:
        return audio_data, max_length
    length_picked = random.choice(candidates)
    # np.random.randint truncates a fractional bound and excludes the upper
    # bound. Floor the duration ourselves and add 1 so that (a) a duration
    # such as 25.4 no longer yields an empty range (ValueError: low >= high)
    # and (b) the final window, ending at the last whole second, is reachable.
    latest_start = int(max_length) - length_picked
    start_point = np.random.randint(0, latest_start + 1)
    end_point = start_point + length_picked
    audio_data_cut = audio_data[start_point * sample_rate:end_point * sample_rate]
    return audio_data_cut, length_picked
def constant_cut_length(audio_data, max_length, sample_rate, length_picked = 25):
    """Cut a randomly placed excerpt of ``length_picked`` seconds.

    Args:
        audio_data: Sliceable sequence of samples.
        max_length: Duration of ``audio_data`` in seconds (int or float).
        sample_rate: Samples per second, used to turn seconds into offsets.
        length_picked: Excerpt length in seconds.

    Returns:
        ``(excerpt, length_picked)``, or ``(audio_data, max_length)``
        unchanged when the audio is not longer than ``length_picked``.
    """
    if max_length <= length_picked:
        return audio_data, max_length
    # np.random.randint truncates a fractional bound and excludes the upper
    # bound. Floor the duration ourselves and add 1 so that (a) a duration
    # such as 25.4 no longer yields an empty range (ValueError: low >= high)
    # and (b) the final window, ending at the last whole second, is reachable.
    latest_start = int(max_length) - length_picked
    start_point = np.random.randint(0, latest_start + 1)
    end_point = start_point + length_picked
    audio_data_cut = audio_data[start_point * sample_rate:end_point * sample_rate]
    return audio_data_cut, length_picked
def generate_songs(state, song_cut_function = constant_cut_length):
    """Pick the next song and return it with a placeholder label text.

    The incoming ``state`` is not read; it is replaced by the state that
    ``get_song`` builds for the newly picked index.

    Args:
        state: Previous state (ignored).
        song_cut_function: Callable ``(audio_data, duration, sample_rate)``
            returning ``(excerpt, length)``, or ``None`` to skip cutting.

    Returns:
        ``(state, audio, "Vote to Reveal Label")``. ``audio`` is the file
        path when ``song_cut_function`` is ``None``, otherwise a
        ``(sample_rate, excerpt)`` pair.
    """
    state, audio = get_song(pick_and_remove_one())
    if song_cut_function is None:
        return state, audio, "Vote to Reveal Label"
    # sr=None is passed through to librosa.load unchanged.
    waveform, native_rate = librosa.load(audio, sr=None)
    excerpt, excerpt_length = song_cut_function(waveform, state.row.duration, native_rate)
    state.update_duration(excerpt_length)
    return state, (native_rate, excerpt), "Vote to Reveal Label"
def fake_last_response(state, request: gr.Request):
    """Log a "fake" vote and return the component updates that follow it.

    Returns:
        A 4-tuple: ``disable_btn`` three times, then a visible
        ``gr.Markdown`` showing the row's label and algorithm.
    """
    vote_last_response(state, "fake", request)
    row = state.row
    reveal = gr.Markdown(f"### {row.label} \nModel : {row.algorithm}", visible=True)
    return (disable_btn, disable_btn, disable_btn, reveal)
def real_last_response(state, request: gr.Request):
    """Log a "real" vote and return the component updates that follow it.

    Returns:
        A 4-tuple: ``disable_btn`` three times, then a visible
        ``gr.Markdown`` showing the row's label.
    """
    vote_last_response(state, "real", request)
    reveal = gr.Markdown(f"### {state.row.label}", visible=True)
    return (disable_btn, disable_btn, disable_btn, reveal)