#!/usr/bin/env python3
"""
Optimized TTS Data Export to Hugging Face

This script exports approved TTS annotations directly from the database to Hugging Face.

Features:
- Local caching for audio files to avoid re-downloading
- Batch processing to handle large datasets without memory issues
- Resume capability for interrupted uploads
- Better error handling and retry mechanisms
- HuggingFace best practices for large dataset uploads
"""
import os
import sys
import json
import hashlib
import time
import shutil
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional, Tuple

import pymysql
import requests
import pandas as pd
from huggingface_hub import HfApi, login
from datasets import Dataset, Audio, Features, Value
import librosa
import numpy as np
from tqdm import tqdm

# Configuration
TARGET_REPO = "navidved/approved-tts-dataset"
SPEAKER_NAME = "ali_bandari"
BATCH_SIZE = 100             # Process annotations in batches
CACHE_DIR = "./audio_cache"  # Local cache directory
TEMP_DIR = "./temp_dataset"  # Temporary directory for dataset preparation
MAX_WORKERS = 4              # Concurrent downloads
MAX_RETRIES = 3              # Max retries for failed downloads

# Memory optimization settings
OPTIMIZE_MEMORY = True       # Enable memory optimizations
TARGET_SAMPLE_RATE = 22050   # Resample to this rate to save memory (None keeps the original rate)
AUDIO_DTYPE = 'int16'        # Use int16 instead of float32 to halve memory usage
USE_GENERATOR = True         # Use generator-based dataset creation (recommended for large datasets)
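# Back-of-the-envelope arithmetic behind these defaults (illustrative, not measured):
# a 10-second clip at 22050 Hz is 220500 samples, i.e. ~882 KB as float32 (4 bytes/sample)
# versus ~441 KB as int16 (2 bytes/sample); keeping a 44.1 kHz source roughly doubles both.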
# Database configuration (edit these if needed)
DB_CONFIG = {
    'host': 'annotation-db.apps.teh2.abrhapaas.com',
    'port': 32107,
    'user': os.getenv('DB_USER', 'navid'),
    'password': os.getenv('DB_PASSWORD', 'ZUJSK!1V!PF4ZEnIaylX'),
    'database': os.getenv('DB_NAME', 'tts'),
    'charset': 'utf8mb4'
}

# Audio server base URL
AUDIO_BASE_URL = "http://hubbit.ir/hf_dataset/tts"

class CacheManager:
    """Handles local caching of audio files."""

    def __init__(self, cache_dir: str):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        self.index_file = self.cache_dir / "cache_index.json"
        self.index = self._load_index()

    def _load_index(self) -> Dict:
        """Load the cache index from disk."""
        if self.index_file.exists():
            try:
                with open(self.index_file, 'r') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError):
                return {}
        return {}

    def _save_index(self):
        """Save the cache index to disk."""
        with open(self.index_file, 'w') as f:
            json.dump(self.index, f)

    def _get_cache_key(self, filename: str) -> str:
        """Generate a cache key for a filename."""
        return hashlib.md5(filename.encode()).hexdigest()

    def get_cached_file(self, filename: str) -> Optional[Path]:
        """Return the cached file path if it exists and is valid."""
        cache_key = self._get_cache_key(filename)
        if cache_key in self.index:
            cached_path = Path(self.index[cache_key])
            if cached_path.exists():
                return cached_path
            else:
                # Remove the stale entry
                del self.index[cache_key]
                self._save_index()
        return None

    def cache_file(self, filename: str, file_data: bytes) -> Path:
        """Cache file data and return the cached path."""
        cache_key = self._get_cache_key(filename)
        # Keep the original extension if available
        ext = Path(filename).suffix or '.mp3'
        cached_path = self.cache_dir / f"{cache_key}{ext}"
        with open(cached_path, 'wb') as f:
            f.write(file_data)
        self.index[cache_key] = str(cached_path)
        self._save_index()
        return cached_path

class AudioDownloader:
    """Handles audio downloads with retry logic."""

    def __init__(self, base_url: str, cache_manager: CacheManager, max_retries: int = 3):
        self.base_url = base_url
        self.cache_manager = cache_manager
        self.max_retries = max_retries

    def download_audio(self, filename: str) -> Optional[Tuple[Path, Dict]]:
        """Download and process an audio file, returning (path, audio_info)."""
        # Check the cache first
        cached_path = self.cache_manager.get_cached_file(filename)
        if cached_path:
            return self._load_audio_info(cached_path, filename)

        # Download the file
        url = f"{self.base_url}/{filename}"
        for attempt in range(self.max_retries):
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                # Cache the file
                cached_path = self.cache_manager.cache_file(filename, response.content)
                return self._load_audio_info(cached_path, filename)
            except Exception as e:
                if attempt < self.max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff
                    continue
                else:
                    print(f"  ❌ Failed to download {filename} after {self.max_retries} attempts: {e}")
                    return None
    def _load_audio_info(self, file_path: Path, filename: str) -> Optional[Tuple[Path, Dict]]:
        """Load audio data and metadata, with optional memory optimization."""
        try:
            # Load audio data with librosa
            sr = TARGET_SAMPLE_RATE if OPTIMIZE_MEMORY else None
            audio_data, sample_rate = librosa.load(str(file_path), sr=sr, mono=True)
            # Optimize the audio data type for memory efficiency
            if OPTIMIZE_MEMORY and AUDIO_DTYPE == 'int16':
                # Convert float32 to int16 to halve memory usage (clip to avoid overflow)
                audio_data = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
            return file_path, {
                'filename': filename,
                'path': str(file_path),
                'audio_array': audio_data,  # Optimized audio array
                'duration': len(audio_data) / sample_rate,
                'sample_rate': sample_rate,
                'channels': 1,
                'dtype': str(audio_data.dtype)
            }
        except Exception as e:
            # Fall back to soundfile if librosa fails
            try:
                import soundfile as sf
                audio_data, sample_rate = sf.read(str(file_path))
                if len(audio_data.shape) > 1:
                    audio_data = np.mean(audio_data, axis=1)  # Convert to mono
                # Apply the sample-rate optimization
                if OPTIMIZE_MEMORY and TARGET_SAMPLE_RATE and sample_rate != TARGET_SAMPLE_RATE:
                    import scipy.signal
                    num_samples = int(len(audio_data) * TARGET_SAMPLE_RATE / sample_rate)
                    audio_data = scipy.signal.resample(audio_data, num_samples)
                    sample_rate = TARGET_SAMPLE_RATE
                # Optimize the data type
                if OPTIMIZE_MEMORY and AUDIO_DTYPE == 'int16':
                    audio_data = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
                return file_path, {
                    'filename': filename,
                    'path': str(file_path),
                    'audio_array': audio_data,
                    'duration': len(audio_data) / sample_rate,
                    'sample_rate': sample_rate,
                    'channels': 1,
                    'dtype': str(audio_data.dtype)
                }
            except Exception:
                print(f"  ❌ Error loading audio {filename}: {e}")
                return None

class BatchProcessor:
    """Processes annotations in batches to avoid memory issues."""

    def __init__(self, downloader: AudioDownloader, temp_dir: str, batch_size: int = 100):
        self.downloader = downloader
        self.temp_dir = Path(temp_dir)
        self.temp_dir.mkdir(exist_ok=True)
        self.batch_size = batch_size

    def process_batch(self, annotations: List[Dict], batch_id: int) -> Optional[Path]:
        """Process a batch of annotations and save it to a parquet file."""
        print(f"\n📦 Processing batch {batch_id} with {len(annotations)} annotations...")
        batch_data = []

        # Use a ThreadPoolExecutor for concurrent downloads
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            # Submit all download tasks
            future_to_annotation = {
                executor.submit(self.downloader.download_audio, ann['audio_file_name']): ann
                for ann in annotations
            }
            # Process completed downloads
            for future in tqdm(as_completed(future_to_annotation),
                               total=len(annotations),
                               desc=f"Batch {batch_id}"):
                annotation = future_to_annotation[future]
                try:
                    result = future.result()
                    if result:
                        file_path, audio_info = result
                        # Structure the audio data for HuggingFace compatibility
                        audio_array = audio_info['audio_array']
                        # Convert to a list for serialization, handling different dtypes
                        if audio_info.get('dtype') == 'int16':
                            # For int16, convert back to float32 for better compatibility with HF Audio
                            array_list = (audio_array.astype(np.float32) / 32767.0).tolist()
                        else:
                            array_list = audio_array.astype(np.float32).tolist()
                        audio_data = {
                            'array': array_list,
                            'sampling_rate': int(audio_info['sample_rate']),
                            'path': f"audio/{annotation['audio_file_name']}"
                        }
                        batch_data.append({
                            'audio': audio_data,  # HuggingFace standard audio column
                            'file_name': f"audio/{annotation['audio_file_name']}",  # Kept for compatibility
                            'sentence': annotation['sentence'],
                            'speaker': SPEAKER_NAME,
                            'duration': audio_info['duration'],
                            'sample_rate': audio_info['sample_rate']
                        })
                except Exception as e:
                    print(f"  ⚠️ Error processing {annotation['audio_file_name']}: {e}")

        if not batch_data:
            print(f"  ❌ No valid audio files in batch {batch_id}")
            return None

        # Save the batch to parquet
        batch_file = self.temp_dir / f"batch_{batch_id:04d}.parquet"
        df = pd.DataFrame(batch_data)
        df.to_parquet(batch_file, index=False)
        print(f"  ✅ Saved {len(batch_data)} files to {batch_file}")
        return batch_file

class DatasetUploader:
    """Handles the HuggingFace dataset upload using best practices."""

    def __init__(self, temp_dir: str, target_repo: str):
        self.temp_dir = Path(temp_dir)
        self.target_repo = target_repo
        self.api = HfApi()

    def prepare_dataset_structure(self) -> Path:
        """Prepare the dataset structure for upload."""
        dataset_dir = self.temp_dir / "dataset"
        dataset_dir.mkdir(exist_ok=True)
        # Create the audio directory
        audio_dir = dataset_dir / "audio"
        audio_dir.mkdir(exist_ok=True)

        batch_files = list(self.temp_dir.glob("batch_*.parquet"))
        print(f"\n📁 Preparing dataset structure from {len(batch_files)} batch files...")

        if USE_GENERATOR:
            # Memory-efficient generator-based approach
            print("🔧 Using memory-efficient generator approach...")

            def audio_sample_generator():
                """Yield one sample at a time to minimize memory usage."""
                sample_count = 0
                for batch_file in tqdm(batch_files, desc="Processing batch files"):
                    try:
                        df = pd.read_parquet(batch_file)
                        for _, row in df.iterrows():
                            sample_count += 1
                            yield {
                                'audio': row['audio'],
                                'file_name': row['file_name'],
                                'sentence': row['sentence'],
                                'speaker': row['speaker'],
                                'duration': row['duration'],
                                'sample_rate': row['sample_rate']
                            }
                        # Clean up the processed batch file to save disk space
                        batch_file.unlink()
                        print(f"  🧹 Cleaned up {batch_file.name}")
                    except Exception as e:
                        print(f"  ⚠️ Error processing {batch_file}: {e}")
                        continue
                print(f"  ✅ Generated {sample_count} samples")

            # Create the dataset using the generator (memory efficient)
            print("\n📊 Creating HuggingFace dataset using generator...")
            features = Features({
                'audio': Audio(sampling_rate=None),
                'file_name': Value('string'),
                'sentence': Value('string'),
                'speaker': Value('string'),
                'duration': Value('float32'),
                'sample_rate': Value('int32')
            })
            dataset = Dataset.from_generator(
                audio_sample_generator,
                features=features,
                cache_dir=str(self.temp_dir / "hf_cache")  # Use a local cache
            )
            num_samples = len(dataset)
        else:
            # Original approach (memory intensive)
            print("⚠️ Using original approach - may consume significant memory...")
            all_data = []
            for batch_file in tqdm(batch_files, desc="Processing batches"):
                df = pd.read_parquet(batch_file)
                for _, row in df.iterrows():
                    all_data.append({
                        'audio': row['audio'],
                        'file_name': row['file_name'],
                        'sentence': row['sentence'],
                        'speaker': row['speaker'],
                        'duration': row['duration'],
                        'sample_rate': row['sample_rate']
                    })

            print(f"\n📊 Creating HuggingFace dataset with {len(all_data)} samples...")
            df = pd.DataFrame(all_data)
            features = Features({
                'audio': Audio(sampling_rate=None),
                'file_name': Value('string'),
                'sentence': Value('string'),
                'speaker': Value('string'),
                'duration': Value('float32'),
                'sample_rate': Value('int32')
            })
            dataset = Dataset.from_pandas(df, features=features)
            num_samples = len(all_data)

        # Save the dataset in HuggingFace format
        print("💾 Saving dataset to disk...")
        dataset.save_to_disk(str(dataset_dir / "dataset"))

        # Save metadata for compatibility (using a small sample to avoid memory issues)
        print("📝 Creating metadata files...")
        sample_data = []
        for sample in dataset.select(range(min(1000, len(dataset)))):
            sample_data.append({
                'file_name': sample['file_name'],
                'sentence': sample['sentence'],
                'speaker': sample['speaker'],
                'duration': sample['duration'],
                'sample_rate': sample['sample_rate']
            })
        metadata_df = pd.DataFrame(sample_data)
        metadata_df.to_parquet(dataset_dir / "train.parquet", index=False)
        metadata_df.to_parquet(dataset_dir / "metadata.parquet", index=False)

        # Create the dataset card
        self._create_dataset_card(dataset_dir, num_samples)

        print(f"  ✅ Dataset prepared with {num_samples} samples in {dataset_dir}")
        return dataset_dir
    def _create_dataset_card(self, dataset_dir: Path, num_samples: int):
        """Create a basic dataset card."""
        card_content = f"""---
license: mit
task_categories:
- text-to-speech
language:
- fa
tags:
- tts
- persian
- farsi
- speech-synthesis
size_categories:
- {self._get_size_category(num_samples)}
---

# {TARGET_REPO.split('/')[-1]}

This dataset contains {num_samples} Persian TTS samples from the speaker "{SPEAKER_NAME}".

## Dataset Structure

- `dataset/`: HuggingFace dataset format with audio arrays
- `train.parquet`: Training split metadata
- `metadata.parquet`: General metadata file (same content as train.parquet)

**Dataset columns:**
- `audio`: Audio data with array, sampling_rate, and path
  - `array`: Audio data as a float array
  - `sampling_rate`: Sample rate in Hz
  - `path`: Relative path to the audio file
- `file_name`: Relative path to the audio file (e.g., "audio/filename.mp3")
- `sentence`: Transcription text in Persian
- `speaker`: Speaker identifier ("{SPEAKER_NAME}")
- `duration`: Audio duration in seconds
- `sample_rate`: Audio sample rate in Hz

## Usage

```python
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("{self.target_repo}")

# Access audio and transcription
for item in dataset['train']:
    audio_data = item['audio']                 # Dict with 'array', 'sampling_rate', 'path'
    audio_array = audio_data['array']          # Audio as a numpy array
    sample_rate = audio_data['sampling_rate']  # Sample rate
    text = item['sentence']                    # Transcription
    speaker = item['speaker']                  # Speaker ID

# You can also load with streaming for large datasets
dataset = load_dataset("{self.target_repo}", streaming=True)
for item in dataset['train']:
    audio = item['audio']['array']  # Audio array directly
    text = item['sentence']         # Transcription
```

## Speaker

- **Speaker ID**: {SPEAKER_NAME}
- **Language**: Persian (Farsi)
- **Total Samples**: {num_samples}

Generated using the TTS annotation system.
"""
        with open(dataset_dir / "README.md", 'w', encoding='utf-8') as f:
            f.write(card_content)
    def _get_size_category(self, num_samples: int) -> str:
        """Get the size category for the dataset card."""
        if num_samples < 1000:
            return "n<1K"
        elif num_samples < 10000:
            return "1K<n<10K"
        elif num_samples < 100000:
            return "10K<n<100K"
        else:
            return "100K<n<1M"
    def upload_dataset(self, dataset_dir: Path):
        """Upload the dataset using HuggingFace best practices."""
        print(f"\n🚀 Uploading dataset to {self.target_repo}...")
        try:
            # Check whether the dataset directory exists in HF format
            hf_dataset_dir = dataset_dir / "dataset"
            if hf_dataset_dir.exists():
                print("📦 Uploading HuggingFace dataset format...")
                # Load and push the dataset
                dataset = Dataset.load_from_disk(str(hf_dataset_dir))
                dataset.push_to_hub(
                    self.target_repo,
                    commit_message="Add TTS dataset with audio arrays"
                )
                print("✅ Dataset upload completed successfully!")
            else:
                # Fall back to a folder upload
                print("📁 Uploading as folder...")
                self.api.upload_large_folder(
                    repo_id=self.target_repo,
                    repo_type="dataset",
                    folder_path=str(dataset_dir)
                )
                print("✅ Folder upload completed successfully!")
            print(f"Dataset available at: https://huggingface.co/datasets/{self.target_repo}")
        except Exception as e:
            print(f"❌ Upload failed: {e}")
            print("You can retry the upload or use the prepared dataset directory manually.")
            print(f"Dataset directory: {dataset_dir}")
            # Fall back to a regular upload_folder with a commit message
            print("\n🔄 Trying fallback upload method...")
            try:
                self.api.upload_folder(
                    repo_id=self.target_repo,
                    repo_type="dataset",
                    folder_path=str(dataset_dir),
                    commit_message="Add TTS dataset with audio arrays"
                )
                print("✅ Fallback upload completed successfully!")
                print(f"Dataset available at: https://huggingface.co/datasets/{self.target_repo}")
            except Exception as fallback_error:
                print(f"❌ Fallback upload also failed: {fallback_error}")
                print(f"Manual upload required. Dataset directory: {dataset_dir}")
                raise

def get_approved_annotations():
    """Get all approved annotations from the database."""
    connection = pymysql.connect(**DB_CONFIG)
    try:
        with connection.cursor(pymysql.cursors.DictCursor) as cursor:
            # Query for approved annotations
            query = """
                SELECT
                    a.annotated_sentence AS sentence,
                    td.filename AS audio_file_name
                FROM annotations a
                JOIN validations v ON a.id = v.annotation_id
                JOIN tts_data td ON a.tts_data_id = td.id
                WHERE v.validated = 1
            """
            cursor.execute(query)
            results = cursor.fetchall()
            print(f"Found {len(results)} approved annotations")
            return results
    finally:
        connection.close()

def cleanup_temp_files(temp_dir: Path, keep_dataset: bool = True):
    """Clean up temporary files."""
    if not keep_dataset and temp_dir.exists():
        shutil.rmtree(temp_dir)
        print(f"🧹 Cleaned up temporary directory: {temp_dir}")
    elif temp_dir.exists():
        # Only clean up the batch files and keep the dataset
        batch_files = list(temp_dir.glob("batch_*.parquet"))
        for batch_file in batch_files:
            batch_file.unlink()
        print(f"🧹 Cleaned up {len(batch_files)} batch files")

def main():
    """Main export function with improved error handling and performance."""
    print("🚀 Starting optimized TTS data export to Hugging Face...")
    print("📊 Configuration:")
    print(f"   - Target repository: {TARGET_REPO}")
    print(f"   - Speaker: {SPEAKER_NAME}")
    print(f"   - Batch size: {BATCH_SIZE}")
    print(f"   - Cache directory: {CACHE_DIR}")
    print(f"   - Max concurrent downloads: {MAX_WORKERS}")
    if OPTIMIZE_MEMORY:
        print("🔧 Memory optimizations enabled:")
        print(f"   - Target sample rate: {TARGET_SAMPLE_RATE or 'Original'}")
        print(f"   - Audio data type: {AUDIO_DTYPE}")
        print(f"   - Generator-based processing: {USE_GENERATOR}")
    else:
        print("⚠️ Memory optimizations disabled - may consume significant RAM")

    try:
        # Initialize components
        cache_manager = CacheManager(CACHE_DIR)
        downloader = AudioDownloader(AUDIO_BASE_URL, cache_manager, MAX_RETRIES)
        processor = BatchProcessor(downloader, TEMP_DIR, BATCH_SIZE)
        uploader = DatasetUploader(TEMP_DIR, TARGET_REPO)

        # Get the approved annotations
        print("\n🔍 Fetching approved annotations from database...")
        annotations = get_approved_annotations()
        if not annotations:
            print("❌ No approved annotations found!")
            return

        total_batches = (len(annotations) + BATCH_SIZE - 1) // BATCH_SIZE
        print(f"📦 Will process {len(annotations)} annotations in {total_batches} batches")

        # Process the annotations in batches
        batch_files = []
        for i in range(0, len(annotations), BATCH_SIZE):
            batch_id = i // BATCH_SIZE + 1
            batch_annotations = annotations[i:i + BATCH_SIZE]
            batch_file = processor.process_batch(batch_annotations, batch_id)
            if batch_file:
                batch_files.append(batch_file)

        if not batch_files:
            print("❌ No batches were processed successfully!")
            return
        print(f"\n✅ Successfully processed {len(batch_files)} batches")

        # Prepare the dataset structure
        dataset_dir = uploader.prepare_dataset_structure()

        # Log in to Hugging Face
        print("\n🔐 Logging in to Hugging Face...")
        try:
            login()  # Uses the HF_TOKEN environment variable or prompts for a token
        except Exception as e:
            print(f"❌ HF login failed: {e}")
            print("Make sure the HF_TOKEN environment variable is set or log in manually")
            return

        # Upload the dataset
        uploader.upload_dataset(dataset_dir)

        # Cleanup
        cleanup_temp_files(Path(TEMP_DIR), keep_dataset=True)

        print("\n🎉 Export completed successfully!")
        print("📊 Final stats:")
        print(f"   - Total annotations processed: {len(annotations)}")
        print(f"   - Successful batches: {len(batch_files)}")
        print(f"   - Dataset URL: https://huggingface.co/datasets/{TARGET_REPO}")
        print(f"   - Local dataset copy: {dataset_dir}")

    except KeyboardInterrupt:
        print("\n⚠️ Process interrupted by user")
        print("💡 You can resume by running the script again - cached files will be reused")
    except Exception as e:
        print(f"\n❌ Error during export: {e}")
        print("💡 Check the error above and try again - cached files will be reused")
        raise


if __name__ == "__main__":
    main()