"""Gradio app for re-labelling the emotion annotations of a speech dataset.

Items are streamed (shuffled) from the Hugging Face dataset
``litagin/Galgame_Speech_SER_16kHz``. For each item the annotator either
confirms the displayed label or picks a different one; the decision is sent
as ``{"key": ..., "cls": ...}`` to the ``/put_data`` endpoint of the
``litagin/ser_record`` Space via ``gradio_client``.
"""

import json
import os
import random
import warnings

import gradio as gr
import librosa
import numpy as np
from datasets import IterableDatasetDict, load_dataset
from gradio_client import Client
from loguru import logger

warnings.filterwarnings("ignore")

NUM_TAR_FILES = 115
NUM_SAMPLES = 3746131
HF_PATH_TO_DATASET = "litagin/Galgame_Speech_SER_16kHz"

hf_token = os.getenv("HF_TOKEN")
client = Client("litagin/ser_record", hf_token=hf_token)

id2label = {
    0: "Angry",
    1: "Disgusted",
    2: "Embarrassed",
    3: "Fearful",
    4: "Happy",
    5: "Sad",
    6: "Surprised",
    7: "Neutral",
    8: "Sexual1",
    9: "Sexual2",
}

id2rich_label = {
    0: "😠 怒り (0)",
    1: "😒 嫌悪 (1)",
    2: "😳 恥ずかしさ・戸惑い (2)",
    3: "😨 恐怖 (3)",
    4: "😊 幸せ (4)",
    5: "😢 悲しみ (5)",
    6: "😲 驚き (6)",
    7: "😐 中立 (7)",
    8: "🥰 NSFW1 (8)",
    9: "🍭 NSFW2 (9)",
}

# Item currently displayed. NOTE: this is module-level state, so it (like
# `counter` and `ds_iter` below) is shared by every browser session served
# by this process.
current_item: dict | None = None


def _load_dataset(
    *,
    streaming: bool = True,
    use_local_dataset: bool = False,
    local_dataset_path: str | None = None,
    data_dir: str = "data",
) -> IterableDatasetDict:
    """Load the train split of the dataset from the Hub or a local copy.

    Args:
        streaming: Passed through to ``datasets.load_dataset``. With
            ``streaming=False`` the returned object is a regular
            ``DatasetDict`` rather than an ``IterableDatasetDict``.
        use_local_dataset: Load from ``local_dataset_path`` instead of the Hub.
        local_dataset_path: Local dataset root; required when
            ``use_local_dataset`` is true.
        data_dir: Sub-directory containing the ``.tar`` shards.

    Returns:
        The dataset with the ``__url__`` column removed and the ``ogg``
        column renamed to ``audio``.

    Raises:
        ValueError: If ``use_local_dataset`` is true but no path is given.
    """
    data_files = {
        "train": [
            f"galgame-speech-ser-16kHz-train-000{index:03d}.tar"
            for index in range(NUM_TAR_FILES)
        ],
    }
    if use_local_dataset:
        # Explicit check instead of `assert`, which is stripped under -O.
        if local_dataset_path is None:
            raise ValueError(
                "local_dataset_path is required when use_local_dataset=True"
            )
        path = local_dataset_path
    else:
        path = HF_PATH_TO_DATASET
    dataset: IterableDatasetDict = load_dataset(
        path=path, data_dir=data_dir, data_files=data_files, streaming=streaming
    )  # type: ignore
    dataset = dataset.remove_columns(["__url__"])
    dataset = dataset.rename_column("ogg", "audio")
    return dataset


logger.info("Start loading dataset")
ds = _load_dataset(streaming=True, use_local_dataset=False)
logger.info("Dataset loaded")

seed = random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}")
ds_iter = iter(ds["train"].shuffle(seed=seed))
# Number of items drawn from the current shuffled pass (process-wide).
counter = 0

# Keyboard shortcuts: "a"/"A" confirms the current label, "0"-"9" pick a new
# label. Keystrokes typed into form fields (e.g. the Key textbox or the speed
# slider's number box) are ignored, as are auto-repeat and modifier combos, so
# typing or holding a key cannot submit labels by accident.
shortcut_js = """
<script>
function shortcuts(e) {
    if (e.repeat || e.ctrlKey || e.metaKey || e.altKey) {
        return;
    }
    const target = e.target;
    if (target && (target.isContentEditable ||
            ["INPUT", "TEXTAREA", "SELECT"].includes(target.tagName))) {
        return;
    }
    const pressed = (e.key || "").toLowerCase();
    let buttonId = null;
    if (pressed === "a") {
        buttonId = "btn_skip";
    } else if (/^[0-9]$/.test(pressed)) {
        buttonId = "btn_" + pressed;
    }
    if (buttonId === null) {
        return;
    }
    const button = document.getElementById(buttonId);
    if (button) {
        button.click();
    }
}
document.addEventListener('keydown', shortcuts, false);
</script>
"""


def modify_speed(
    data: tuple[int, np.ndarray], speed: float = 1.0
) -> tuple[int, np.ndarray]:
    """Time-stretch audio without changing its sampling rate.

    Args:
        data: ``(sampling_rate, samples)`` tuple as used by ``gr.Audio``.
        speed: Playback-rate factor passed to ``librosa.effects.time_stretch``
            (> 1 plays faster, < 1 slower). ``1.0`` returns ``data`` unchanged.

    Returns:
        ``(sampling_rate, stretched_samples)``.
    """
    if speed == 1.0:
        return data
    sr, array = data
    # assumes `array` is a floating-point waveform as decoded by `datasets`
    # (librosa expects float input) — TODO confirm for all shards.
    return sr, librosa.effects.time_stretch(array, rate=speed)


def parse_item(item) -> dict:
    """Convert a raw dataset row into the dict the UI consumes.

    The ``counter`` field is the current value of the module-level draw
    counter at the time of the call.
    """
    label_id = item["cls"]
    sampling_rate = item["audio"]["sampling_rate"]
    array = item["audio"]["array"]
    return {
        "key": item["__key__"],
        "audio": (sampling_rate, array),
        "text": item["txt"],
        "label": id2rich_label[label_id],
        "label_id": label_id,
        "counter": counter,
    }


def get_next_parsed_item() -> dict:
    """Draw the next item, starting a freshly shuffled pass when exhausted."""
    global counter, ds_iter
    logger.info("Getting next item")
    try:
        next_item = next(ds_iter)
        counter += 1
    except StopIteration:
        logger.info("StopIteration, re-initializing using new seed")
        seed = random.randint(0, 2**32 - 1)
        logger.info(f"New Seed: {seed}")
        ds_iter = iter(ds["train"].shuffle(seed=seed))
        next_item = next(ds_iter)
        counter = 1
    parsed = parse_item(next_item)
    logger.info(
        f"Next item:\nkey={parsed['key']}\ntext={parsed['text']}\nlabel={parsed['label']}"
    )
    return parsed


def _send_label(item_key: str, cls: int) -> None:
    """Submit ``{"key": item_key, "cls": cls}`` to the recording Space."""
    _ = client.predict(
        new_data=json.dumps(
            {
                "key": item_key,
                "cls": cls,
            }
        ),
        api_name="/put_data",
    )


md = """
# 説明

- **性的な音声が含まれるため、18歳未満の方はご利用をお控えください**
- このアプリは [このゲームのセリフ音声データセット](https://huggingface.co/datasets/litagin/Galgame_Speech_SER_16kHz) の感情ラベルを修正して、大規模で高品質な感情音声データセットを作成するためのものです
- 「**何を言っているか**」ではなく「**どのように言っているか**」に注目して、感情ラベルを付与してください(例: 悲しそうに「とっても楽しいです…」と言っていたら、 `😊 幸せ` ではなく `😢 悲しみ` とする)
- 既存のラベルが適切であれば、そのまま「現在の感情ラベルで適切」ボタンを押してください(ショートカットキー: `A`)
- ラベルを修正する場合は、適切なボタンを押してください(ショートカットキー: `0` 〜 `9`)

# ラベル補足

- `🥰 NSFW1` は女性の性的行為中の音声(喘ぎ声等)
- `🍭 NSFW2` はキスシーンでのリップ音やフェラシーンでのしゃぶる音(チュパ音)が多く含まれている音声(セリフ+チュパ音の場合も含む)(フェラシーン中のセリフだと思われる場合はこれ)
- 感情が音声からは特に読み取れない場合(普通のテンションの声で「今日はラーメンを食べます」等)は `😐 中立` を選択してください
- 複数の感情が含まれている場合は、最も多く含まれている感情を選択してください
"""

with gr.Blocks(head=shortcut_js) as app:
    gr.Markdown(md)
    with gr.Row():
        with gr.Column():
            btn_init = gr.Button("初期化・再読み込み")
            speed_slider = gr.Slider(
                minimum=0.5, maximum=5.0, step=0.1, value=1.0, label="再生速度"
            )
            counter_info = gr.Textbox(label="進捗状況")
        with gr.Column(variant="panel"):
            key = gr.Textbox(label="Key")
            audio = gr.Audio(
                show_download_button=False,
                show_share_button=False,
                interactive=False,
            )
            text = gr.Textbox(label="Text")
            label = gr.Textbox(label="感情ラベル")
            label_id = gr.Textbox(visible=False)
            btn_skip = gr.Button("現在の感情ラベルで適切 (A)", elem_id="btn_skip")
        with gr.Column():
            gr.Markdown("# 感情ラベルを修正する場合")
            btn_list = [
                gr.Button(id2rich_label[_id], elem_id=f"btn_{_id}")
                for _id in range(10)
            ]

    # Every handler refreshes the same set of components.
    item_outputs = [key, audio, text, label, label_id, counter_info]

    def update_current_item(data: dict) -> dict:
        """Render the current item (loading one first if none is loaded yet).

        ``data`` maps input components to their values; only the speed
        slider is read here.
        """
        global current_item
        if current_item is None:
            current_item = get_next_parsed_item()
        modified_audio = modify_speed(current_item["audio"], speed=data[speed_slider])
        progress = current_item["counter"]
        counter_str = f"{progress}/{NUM_SAMPLES}: {progress / NUM_SAMPLES * 100:.2f}%"
        return {
            key: current_item["key"],
            audio: gr.Audio(modified_audio, autoplay=True),
            text: current_item["text"],
            label: current_item["label"],
            label_id: current_item["label_id"],
            counter_info: counter_str,
        }

    def set_next_item(data: dict) -> dict:
        """Advance to a new item and render it."""
        global current_item
        current_item = get_next_parsed_item()
        return update_current_item(data)

    def put_unmodified(data: dict) -> dict:
        """Confirm the displayed label as correct, then show the next item."""
        logger.info("Putting unmodified")
        current_key = data[key]
        current_label_id = data[label_id]
        if not current_key or current_label_id in (None, ""):
            # The page has not been initialised yet: there is nothing to
            # confirm, so load an item instead of sending an invalid record
            # (or crashing on int("")).
            logger.warning("No item loaded in this page; nothing submitted")
            return update_current_item(data)
        _send_label(current_key, int(current_label_id))
        logger.info("Unmodified sent")
        return set_next_item(data)

    def _make_put_label(cls_id: int):
        """Build a click handler that relabels the displayed item as ``cls_id``."""

        def put_label(data: dict) -> dict:
            logger.info(f"Putting label: {id2rich_label[cls_id]}")
            current_key = data[key]
            if not current_key:
                # No item loaded in this page: sending would record an empty key.
                logger.warning("No item loaded in this page; nothing submitted")
                return update_current_item(data)
            _send_label(current_key, cls_id)
            logger.info("Modified sent")
            return set_next_item(data)

        return put_label

    btn_init.click(
        update_current_item,
        inputs={speed_slider},
        outputs=item_outputs,
    )
    btn_skip.click(
        put_unmodified,
        inputs={key, label_id, speed_slider},
        outputs=item_outputs,
    )

    functions_list = [_make_put_label(_id) for _id in range(10)]
    for button, handler in zip(btn_list, functions_list):
        button.click(
            handler,
            inputs={key, speed_slider},
            outputs=item_outputs,
        )

app.launch()