tts_labeling/scripts/export_approved_datasets.py
#!/usr/bin/env python3
"""
Optimized TTS Data Export to Hugging Face
This script exports approved TTS annotations directly from the database to Hugging Face.
Features:
- Local caching for audio files to avoid re-downloading
- Batch processing to handle large datasets without memory issues
- Resume capability for interrupted uploads
- Better error handling and retry mechanisms
- HuggingFace best practices for large dataset uploads
"""
import os
import sys
import json
import hashlib
import time
import shutil
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional, Tuple
import pymysql
import requests
import pandas as pd
from huggingface_hub import HfApi, login
from datasets import Dataset, Audio, Features, Value
import librosa
import numpy as np
from tqdm import tqdm
# Configuration
TARGET_REPO = "navidved/approved-tts-dataset"
SPEAKER_NAME = "ali_bandari"
BATCH_SIZE = 100 # Process annotations in batches
CACHE_DIR = "./audio_cache" # Local cache directory
TEMP_DIR = "./temp_dataset" # Temporary directory for dataset preparation
MAX_WORKERS = 4 # Concurrent downloads
MAX_RETRIES = 3 # Max retries for failed downloads
# Memory optimization settings
OPTIMIZE_MEMORY = True # Enable memory optimizations
TARGET_SAMPLE_RATE = 22050 # Reduce sample rate to save memory (None to keep original)
AUDIO_DTYPE = 'int16' # Use int16 instead of float32 to halve memory usage
USE_GENERATOR = True # Use generator-based dataset creation (recommended for large datasets)
# Database configuration (edit these if needed)
DB_CONFIG = {
'host': 'annotation-db.apps.teh2.abrhapaas.com',
'port': 32107,
'user': os.getenv('DB_USER', 'navid'),
    'password': os.getenv('DB_PASSWORD', ''),  # supply via environment; do not hardcode credentials
'database': os.getenv('DB_NAME', 'tts'),
'charset': 'utf8mb4'
}
# Audio server base URL
AUDIO_BASE_URL = "http://hubbit.ir/hf_dataset/tts"
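# Audio files are fetched as f"{AUDIO_BASE_URL}/<filename>", where <filename> comes from the
# tts_data table (see get_approved_annotations below).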
class CacheManager:
"""Handles local caching of audio files"""
def __init__(self, cache_dir: str):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
self.index_file = self.cache_dir / "cache_index.json"
self.index = self._load_index()
def _load_index(self) -> Dict:
"""Load cache index from disk"""
if self.index_file.exists():
try:
with open(self.index_file, 'r') as f:
return json.load(f)
            except (OSError, ValueError):
                # Corrupted or unreadable cache index; start over with an empty index
                return {}
return {}
def _save_index(self):
"""Save cache index to disk"""
with open(self.index_file, 'w') as f:
json.dump(self.index, f)
def _get_cache_key(self, filename: str) -> str:
"""Generate cache key for filename"""
return hashlib.md5(filename.encode()).hexdigest()
def get_cached_file(self, filename: str) -> Optional[Path]:
"""Get cached file path if exists and valid"""
cache_key = self._get_cache_key(filename)
if cache_key in self.index:
cached_path = Path(self.index[cache_key])
if cached_path.exists():
return cached_path
else:
# Remove invalid entry
del self.index[cache_key]
self._save_index()
return None
def cache_file(self, filename: str, file_data: bytes) -> Path:
"""Cache file data and return path"""
cache_key = self._get_cache_key(filename)
# Use original extension if available
ext = Path(filename).suffix or '.mp3'
cached_path = self.cache_dir / f"{cache_key}{ext}"
with open(cached_path, 'wb') as f:
f.write(file_data)
self.index[cache_key] = str(cached_path)
self._save_index()
return cached_path
class AudioDownloader:
"""Handles audio downloading with retry logic"""
def __init__(self, base_url: str, cache_manager: CacheManager, max_retries: int = 3):
self.base_url = base_url
self.cache_manager = cache_manager
self.max_retries = max_retries
def download_audio(self, filename: str) -> Optional[Tuple[Path, Dict]]:
"""Download and process audio file, return (path, audio_info)"""
# Check cache first
cached_path = self.cache_manager.get_cached_file(filename)
if cached_path:
return self._load_audio_info(cached_path, filename)
# Download file
url = f"{self.base_url}/{filename}"
for attempt in range(self.max_retries):
try:
response = requests.get(url, timeout=30)
response.raise_for_status()
# Cache the file
cached_path = self.cache_manager.cache_file(filename, response.content)
return self._load_audio_info(cached_path, filename)
except Exception as e:
if attempt < self.max_retries - 1:
time.sleep(2 ** attempt) # Exponential backoff
continue
else:
print(f" ❌ Failed to download {filename} after {self.max_retries} attempts: {e}")
return None
    def _load_audio_info(self, file_path: Path, filename: str) -> Optional[Tuple[Path, Dict]]:
"""Load audio information and audio data with memory optimization"""
try:
# Load audio data with librosa
sr = TARGET_SAMPLE_RATE if OPTIMIZE_MEMORY else None
audio_data, sample_rate = librosa.load(str(file_path), sr=sr, mono=True)
# Optimize audio data type for memory efficiency
if OPTIMIZE_MEMORY and AUDIO_DTYPE == 'int16':
# Convert float32 to int16 to halve memory usage
audio_data = (audio_data * 32767).astype(np.int16)
return file_path, {
'filename': filename,
'path': str(file_path),
'audio_array': audio_data, # Optimized audio array
'duration': len(audio_data) / sample_rate,
'sample_rate': sample_rate,
'channels': 1,
'dtype': str(audio_data.dtype)
}
except Exception as e:
# Try with soundfile as fallback
try:
import soundfile as sf
audio_data, sample_rate = sf.read(str(file_path))
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1) # Convert to mono
# Apply sample rate optimization
if OPTIMIZE_MEMORY and TARGET_SAMPLE_RATE and sample_rate != TARGET_SAMPLE_RATE:
import scipy.signal
num_samples = int(len(audio_data) * TARGET_SAMPLE_RATE / sample_rate)
audio_data = scipy.signal.resample(audio_data, num_samples)
sample_rate = TARGET_SAMPLE_RATE
# Optimize data type
if OPTIMIZE_MEMORY and AUDIO_DTYPE == 'int16':
audio_data = (audio_data * 32767).astype(np.int16)
return file_path, {
'filename': filename,
'path': str(file_path),
'audio_array': audio_data,
'duration': len(audio_data) / sample_rate,
'sample_rate': sample_rate,
'channels': 1,
'dtype': str(audio_data.dtype)
}
            except Exception as fallback_error:
                print(f" ❌ Error loading audio {filename}: {e} (soundfile fallback failed: {fallback_error})")
                return None
class BatchProcessor:
"""Processes annotations in batches to avoid memory issues"""
def __init__(self, downloader: AudioDownloader, temp_dir: str, batch_size: int = 100):
self.downloader = downloader
self.temp_dir = Path(temp_dir)
self.temp_dir.mkdir(exist_ok=True)
self.batch_size = batch_size
def process_batch(self, annotations: List[Dict], batch_id: int) -> Optional[Path]:
"""Process a batch of annotations and save to parquet"""
print(f"\nπŸ“¦ Processing batch {batch_id} with {len(annotations)} annotations...")
batch_data = []
# Use ThreadPoolExecutor for concurrent downloads
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
# Submit all download tasks
future_to_annotation = {
executor.submit(self.downloader.download_audio, ann['audio_file_name']): ann
for ann in annotations
}
# Process completed downloads
for future in tqdm(as_completed(future_to_annotation),
total=len(annotations),
desc=f"Batch {batch_id}"):
annotation = future_to_annotation[future]
try:
result = future.result()
if result:
file_path, audio_info = result
# Structure audio data for HuggingFace compatibility
audio_array = audio_info['audio_array']
# Convert to list for serialization, handling different dtypes
if audio_info.get('dtype') == 'int16':
# For int16, convert to float32 for better compatibility with HF Audio
array_list = (audio_array.astype(np.float32) / 32767.0).tolist()
else:
array_list = audio_array.astype(np.float32).tolist()
audio_data = {
'array': array_list,
'sampling_rate': int(audio_info['sample_rate']),
'path': f"audio/{annotation['audio_file_name']}"
}
batch_data.append({
'audio': audio_data, # HuggingFace standard audio column
'file_name': f"audio/{annotation['audio_file_name']}", # Keep for compatibility
'sentence': annotation['sentence'],
'speaker': SPEAKER_NAME,
'duration': audio_info['duration'],
'sample_rate': audio_info['sample_rate']
})
except Exception as e:
print(f" ⚠️ Error processing {annotation['audio_file_name']}: {e}")
if not batch_data:
print(f" ❌ No valid audio files in batch {batch_id}")
return None
# Save batch to parquet
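        # Note: the audio arrays are embedded in the parquet rows as Python lists, so batch files
        # can be large; they are removed again once the HuggingFace dataset has been built from them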
batch_file = self.temp_dir / f"batch_{batch_id:04d}.parquet"
df = pd.DataFrame(batch_data)
df.to_parquet(batch_file, index=False)
print(f" βœ… Saved {len(batch_data)} files to {batch_file}")
return batch_file
class DatasetUploader:
"""Handles HuggingFace dataset upload using best practices"""
def __init__(self, temp_dir: str, target_repo: str):
self.temp_dir = Path(temp_dir)
self.target_repo = target_repo
self.api = HfApi()
def prepare_dataset_structure(self) -> Path:
"""Prepare dataset structure for upload"""
dataset_dir = self.temp_dir / "dataset"
dataset_dir.mkdir(exist_ok=True)
# Create audio directory
audio_dir = dataset_dir / "audio"
audio_dir.mkdir(exist_ok=True)
batch_files = list(self.temp_dir.glob("batch_*.parquet"))
print(f"\nπŸ“ Preparing dataset structure from {len(batch_files)} batch files...")
if USE_GENERATOR:
# Memory-efficient generator-based approach
print("🧠 Using memory-efficient generator approach...")
def audio_sample_generator():
"""Generator that yields one sample at a time to minimize memory usage"""
sample_count = 0
for batch_file in tqdm(batch_files, desc="Processing batch files"):
try:
df = pd.read_parquet(batch_file)
for _, row in df.iterrows():
sample_count += 1
yield {
'audio': row['audio'],
'file_name': row['file_name'],
'sentence': row['sentence'],
'speaker': row['speaker'],
'duration': row['duration'],
'sample_rate': row['sample_rate']
}
# Clean up processed batch file to save disk space
batch_file.unlink()
print(f" 🧹 Cleaned up {batch_file.name}")
except Exception as e:
print(f" ⚠️ Error processing {batch_file}: {e}")
continue
print(f" βœ… Generated {sample_count} samples")
# Create dataset using generator (memory efficient)
print(f"\nπŸ”„ Creating HuggingFace dataset using generator...")
features = Features({
'audio': Audio(sampling_rate=None),
'file_name': Value('string'),
'sentence': Value('string'),
'speaker': Value('string'),
'duration': Value('float32'),
'sample_rate': Value('int32')
})
dataset = Dataset.from_generator(
audio_sample_generator,
features=features,
cache_dir=str(self.temp_dir / "hf_cache") # Use local cache
)
num_samples = len(dataset)
else:
# Original approach (memory intensive)
print("⚠️ Using original approach - may consume significant memory...")
all_data = []
for batch_file in tqdm(batch_files, desc="Processing batches"):
df = pd.read_parquet(batch_file)
for _, row in df.iterrows():
all_data.append({
'audio': row['audio'],
'file_name': row['file_name'],
'sentence': row['sentence'],
'speaker': row['speaker'],
'duration': row['duration'],
'sample_rate': row['sample_rate']
})
print(f"\nπŸ”„ Creating HuggingFace dataset with {len(all_data)} samples...")
df = pd.DataFrame(all_data)
features = Features({
'audio': Audio(sampling_rate=None),
'file_name': Value('string'),
'sentence': Value('string'),
'speaker': Value('string'),
'duration': Value('float32'),
'sample_rate': Value('int32')
})
dataset = Dataset.from_pandas(df, features=features)
num_samples = len(all_data)
# Save the dataset in HuggingFace format
print(f"πŸ’Ύ Saving dataset to disk...")
dataset.save_to_disk(str(dataset_dir / "dataset"))
# Save metadata for compatibility (using a small sample to avoid memory issues)
print(f"πŸ“‹ Creating metadata files...")
sample_data = []
        for sample in dataset.select(range(min(1000, len(dataset)))):
sample_data.append({
'file_name': sample['file_name'],
'sentence': sample['sentence'],
'speaker': sample['speaker'],
'duration': sample['duration'],
'sample_rate': sample['sample_rate']
})
metadata_df = pd.DataFrame(sample_data)
metadata_df.to_parquet(dataset_dir / "train.parquet", index=False)
metadata_df.to_parquet(dataset_dir / "metadata.parquet", index=False)
# Create dataset card
self._create_dataset_card(dataset_dir, num_samples)
print(f" βœ… Dataset prepared with {num_samples} samples in {dataset_dir}")
return dataset_dir
def _create_dataset_card(self, dataset_dir: Path, num_samples: int):
"""Create a basic dataset card"""
card_content = f"""---
license: mit
task_categories:
- text-to-speech
language:
- fa
tags:
- tts
- persian
- farsi
- speech-synthesis
size_categories:
- {self._get_size_category(num_samples)}
---
# {TARGET_REPO.split('/')[-1]}
This dataset contains {num_samples} Persian TTS samples spoken by "{SPEAKER_NAME}".
## Dataset Structure
- `dataset/`: HuggingFace dataset format with audio arrays
- `train.parquet`: Metadata preview (up to 1,000 samples)
- `metadata.parquet`: Same content as `train.parquet`
**Dataset columns:**
- `audio`: Audio data with array, sampling_rate, and path
- `array`: Audio data as float array
- `sampling_rate`: Sample rate in Hz
- `path`: Relative path to audio file
- `file_name`: Relative path to audio files (e.g., "audio/filename.mp3")
- `sentence`: Transcription text in Persian
- `speaker`: Speaker identifier ("{SPEAKER_NAME}")
- `duration`: Audio duration in seconds
- `sample_rate`: Audio sample rate in Hz
## Usage
```python
from datasets import load_dataset
# Load the dataset
dataset = load_dataset("{self.target_repo}")
# Access audio and transcription
for item in dataset['train']:
audio_data = item['audio'] # Dict with 'array', 'sampling_rate', 'path'
audio_array = audio_data['array'] # Actual audio as numpy array
sample_rate = audio_data['sampling_rate'] # Sample rate
text = item['sentence'] # Transcription
speaker = item['speaker'] # Speaker ID
# You can also load with streaming for large datasets
dataset = load_dataset("{self.target_repo}", streaming=True)
for item in dataset['train']:
audio = item['audio']['array'] # Audio array directly
text = item['sentence'] # Transcription
```
## Speaker
- **Speaker ID**: {SPEAKER_NAME}
- **Language**: Persian (Farsi)
- **Total Samples**: {num_samples}
Generated using the TTS annotation system.
"""
with open(dataset_dir / "README.md", 'w', encoding='utf-8') as f:
f.write(card_content)
def _get_size_category(self, num_samples: int) -> str:
"""Get size category for dataset card"""
if num_samples < 1000:
return "n<1K"
elif num_samples < 10000:
return "1K<n<10K"
elif num_samples < 100000:
return "10K<n<100K"
else:
return "100K<n<1M"
def upload_dataset(self, dataset_dir: Path):
"""Upload dataset using HuggingFace best practices"""
print(f"\nπŸš€ Uploading dataset to {self.target_repo}...")
try:
# Check if dataset directory exists in HF format
hf_dataset_dir = dataset_dir / "dataset"
if hf_dataset_dir.exists():
print("πŸ“¦ Uploading HuggingFace dataset format...")
# Load and push the dataset
dataset = Dataset.load_from_disk(str(hf_dataset_dir))
dataset.push_to_hub(
self.target_repo,
commit_message="Add TTS dataset with audio arrays"
)
print(f"βœ… Dataset upload completed successfully!")
else:
# Fallback to folder upload
print("πŸ“ Uploading as folder...")
self.api.upload_large_folder(
repo_id=self.target_repo,
repo_type="dataset",
folder_path=str(dataset_dir)
)
print(f"βœ… Folder upload completed successfully!")
print(f"Dataset available at: https://huggingface.co/datasets/{self.target_repo}")
except Exception as e:
print(f"❌ Upload failed: {e}")
print("You can retry the upload or use the prepared dataset directory manually.")
print(f"Dataset directory: {dataset_dir}")
# Fallback to regular upload_folder with commit message
print("\nπŸ”„ Trying fallback upload method...")
try:
self.api.upload_folder(
repo_id=self.target_repo,
repo_type="dataset",
folder_path=str(dataset_dir),
commit_message="Add TTS dataset with audio arrays"
)
print(f"βœ… Fallback upload completed successfully!")
print(f"Dataset available at: https://huggingface.co/datasets/{self.target_repo}")
except Exception as fallback_error:
print(f"❌ Fallback upload also failed: {fallback_error}")
print(f"Manual upload required. Dataset directory: {dataset_dir}")
raise
def get_approved_annotations():
"""Get all approved annotations from the database"""
connection = pymysql.connect(**DB_CONFIG)
try:
with connection.cursor(pymysql.cursors.DictCursor) as cursor:
# Query for approved annotations
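            # The inner joins keep only annotations that have at least one validation row with
            # validated = 1; an annotation validated more than once will appear multiple times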
query = """
SELECT
a.annotated_sentence as sentence,
td.filename as audio_file_name
FROM annotations a
JOIN validations v ON a.id = v.annotation_id
JOIN tts_data td ON a.tts_data_id = td.id
WHERE v.validated = 1
"""
cursor.execute(query)
results = cursor.fetchall()
print(f"Found {len(results)} approved annotations")
return results
finally:
connection.close()
def cleanup_temp_files(temp_dir: Path, keep_dataset: bool = True):
"""Clean up temporary files"""
if not keep_dataset and temp_dir.exists():
shutil.rmtree(temp_dir)
print(f"🧹 Cleaned up temporary directory: {temp_dir}")
else:
# Only clean up batch files, keep the dataset
batch_files = list(temp_dir.glob("batch_*.parquet"))
for batch_file in batch_files:
batch_file.unlink()
print(f"🧹 Cleaned up {len(batch_files)} batch files")
def main():
"""Main export function with improved error handling and performance"""
print("πŸš€ Starting optimized TTS data export to Hugging Face...")
print(f"πŸ“Š Configuration:")
print(f" - Target repository: {TARGET_REPO}")
print(f" - Speaker: {SPEAKER_NAME}")
print(f" - Batch size: {BATCH_SIZE}")
print(f" - Cache directory: {CACHE_DIR}")
print(f" - Max concurrent downloads: {MAX_WORKERS}")
if OPTIMIZE_MEMORY:
print(f"🧠 Memory Optimizations Enabled:")
print(f" - Target sample rate: {TARGET_SAMPLE_RATE or 'Original'}")
print(f" - Audio data type: {AUDIO_DTYPE}")
print(f" - Generator-based processing: {USE_GENERATOR}")
else:
print("⚠️ Memory optimizations disabled - may consume significant RAM")
try:
# Initialize components
cache_manager = CacheManager(CACHE_DIR)
downloader = AudioDownloader(AUDIO_BASE_URL, cache_manager, MAX_RETRIES)
processor = BatchProcessor(downloader, TEMP_DIR, BATCH_SIZE)
uploader = DatasetUploader(TEMP_DIR, TARGET_REPO)
# Get approved annotations
print("\nπŸ“‹ Fetching approved annotations from database...")
annotations = get_approved_annotations()
if not annotations:
print("❌ No approved annotations found!")
return
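        # Ceiling division: the last batch may hold fewer than BATCH_SIZE annotations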
total_batches = (len(annotations) + BATCH_SIZE - 1) // BATCH_SIZE
print(f"πŸ“¦ Will process {len(annotations)} annotations in {total_batches} batches")
# Process annotations in batches
batch_files = []
for i in range(0, len(annotations), BATCH_SIZE):
batch_id = i // BATCH_SIZE + 1
batch_annotations = annotations[i:i + BATCH_SIZE]
batch_file = processor.process_batch(batch_annotations, batch_id)
if batch_file:
batch_files.append(batch_file)
if not batch_files:
print("❌ No batches were processed successfully!")
return
print(f"\nβœ… Successfully processed {len(batch_files)} batches")
# Prepare dataset structure
dataset_dir = uploader.prepare_dataset_structure()
# Login to HF
print("\nπŸ”‘ Logging in to Hugging Face...")
try:
login() # Will use HF_TOKEN env var or prompt for token
except Exception as e:
print(f"❌ HF login failed: {e}")
print("Make sure you have HF_TOKEN environment variable set or login manually")
return
# Upload dataset
uploader.upload_dataset(dataset_dir)
# Cleanup
cleanup_temp_files(Path(TEMP_DIR), keep_dataset=True)
print("\nπŸŽ‰ Export completed successfully!")
print(f"πŸ“Š Final stats:")
print(f" - Total annotations processed: {len(annotations)}")
print(f" - Successful batches: {len(batch_files)}")
print(f" - Dataset URL: https://huggingface.co/datasets/{TARGET_REPO}")
print(f" - Local dataset copy: {dataset_dir}")
except KeyboardInterrupt:
print("\n⚠️ Process interrupted by user")
print("πŸ’‘ You can resume by running the script again - cached files will be reused")
except Exception as e:
print(f"\n❌ Error during export: {e}")
print("πŸ’‘ Check the error above and try again - cached files will be reused")
raise
if __name__ == "__main__":
main()