Spaces:
Paused
Paused
import os | |
import re | |
from pathlib import Path | |
from typing import Optional, Tuple, Union | |
from torch import Tensor | |
from torch.utils.data import Dataset | |
from torchaudio.datasets.utils import _load_waveform | |
# Sample rate (in Hz) reported for every utterance returned by the dataset.
_SAMPLE_RATE = 16000
def _get_wavs_paths(data_dir):
    """Collect the utterance wav files of one session as relative paths.

    Args:
        data_dir (Path): Session directory (e.g. ``<root>/IEMOCAP/Session1``) that
            holds the ``sentences/wav/<dialog>/<utterance>.wav`` files.

    Returns:
        List[str]: One path per wav file, relative to the parent of ``data_dir``
        (so each entry starts with the session directory name), ordered by the
        string form of the full path.
    """
    wav_dir = data_dir / "sentences" / "wav"
    # Anchor on the session directory's parent instead of searching the path
    # string for "Session": an ancestor directory whose name contains that word
    # would otherwise produce a wrong, too-long "relative" path.
    base_dir = data_dir.parent
    # Sort on the string form to keep the same ordering as sorting str paths.
    return [str(wav_file.relative_to(base_dir)) for wav_file in sorted(wav_dir.glob("*/*.wav"), key=str)]
class IEMOCAP(Dataset):
    """*IEMOCAP* :cite:`iemocap` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        sessions (Tuple[int]): Tuple of sessions (1-5) to use. (Default: ``(1, 2, 3, 4, 5)``)
        utterance_type (str or None, optional): Which type(s) of utterances to include in the dataset.
            Options: ("scripted", "improvised", ``None``). If ``None``, both scripted and improvised
            data are used.

    Raises:
        RuntimeError: If ``<root>/IEMOCAP`` is not a directory.
        ValueError: If ``utterance_type`` is not one of the supported options.
    """

    # Emotion labels kept by the dataset; utterances annotated with anything else are skipped.
    _LABELS = frozenset(("neu", "hap", "ang", "sad", "exc", "fru"))

    def __init__(
        self,
        root: Union[str, Path],
        sessions: Tuple[int, ...] = (1, 2, 3, 4, 5),
        utterance_type: Optional[str] = None,
    ):
        root = Path(root)
        self._path = root / "IEMOCAP"

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        if utterance_type not in ["scripted", "improvised", None]:
            raise ValueError("utterance_type must be one of ['scripted', 'improvised', or None]")

        # Stems of every wav file seen so far; a set keeps the per-line
        # membership test below constant-time.
        known_stems = set()
        # Ordered list of utterance stems that have both a wav file and a kept label.
        self.data = []
        # stem -> {"label": <emotion>, "path": <wav path relative to self._path>}
        self.mapping = {}

        for session in sessions:
            session_name = f"Session{session}"
            session_dir = self._path / session_name

            # get wav paths
            wav_paths = _get_wavs_paths(session_dir)
            known_stems.update(Path(wav_path).stem for wav_path in wav_paths)

            # add labels
            label_dir = session_dir / "dialog" / "EmoEvaluation"
            query = "*.txt"
            if utterance_type == "scripted":
                query = "*script*.txt"
            elif utterance_type == "improvised":
                query = "*impro*.txt"

            for label_path in label_dir.glob(query):
                with open(label_path, "r") as f:
                    for line in f:
                        parsed = self._parse_label_line(line)
                        if parsed is None:
                            continue
                        wav_stem, label = parsed
                        if wav_stem not in known_stems:
                            continue
                        self.mapping[wav_stem] = {"label": label}

            # Register only this session's utterances that received a label, in
            # wav-path order, and remember where their audio lives.
            for wav_path in wav_paths:
                wav_stem = str(Path(wav_path).stem)
                if wav_stem in self.mapping:
                    self.data.append(wav_stem)
                    self.mapping[wav_stem]["path"] = wav_path

    @staticmethod
    def _parse_label_line(line: str) -> Optional[Tuple[str, str]]:
        """Extract ``(file name, label)`` from one line of an evaluation file.

        Args:
            line (str): A raw line read from an ``EmoEvaluation`` text file.

        Returns:
            Tuple of file name and label, or ``None`` when the line does not start
            with ``[``, has fewer than three tab/newline-separated fields, or
            carries a label that is not one of the kept emotion labels.
        """
        if not line.startswith("["):
            return None
        fields = re.split("[\t\n]", line)
        # Guard against truncated lines instead of failing with IndexError.
        if len(fields) < 3:
            return None
        wav_stem, label = fields[1], fields[2]
        if label not in IEMOCAP._LABELS:
            return None
        return wav_stem, label

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:meth:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio
            int:
                Sample rate
            str:
                File name
            str:
                Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``)
            str:
                Speaker
        """
        wav_stem = self.data[n]
        entry = self.mapping[wav_stem]
        # The speaker is the part of the file name before the first underscore.
        speaker = wav_stem.split("_")[0]
        return (entry["path"], _SAMPLE_RATE, wav_stem, entry["label"], speaker)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            str:
                Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``)
            str:
                Speaker
        """
        metadata = self.get_metadata(n)
        # metadata[0] is the wav path relative to self._path; metadata[1] is the sample rate.
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]

    def __len__(self) -> int:
        """Return the number of labelled utterances in the dataset."""
        return len(self.data)