Spaces:
Running
Running
| """ | |
| WebDataset format handling for Video Model Studio | |
| """ | |
import logging
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import List, Dict, Tuple, Optional

from ..utils import is_image_file, is_video_file, extract_scene_info
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
def is_webdataset_file(file_path: Path) -> bool:
    """Report whether *file_path* names a WebDataset tar shard.

    Only the final suffix is inspected, case-insensitively; the file itself
    is never opened, so e.g. ``shard.tar.gz`` is not recognised.

    Args:
        file_path: Path to check

    Returns:
        bool: True if the last extension is ``.tar``, otherwise False
    """
    final_suffix = file_path.suffix.lower()
    return final_suffix == '.tar'
# Recognised sample fields, compared case-insensitively (leading dot included).
_WDS_VIDEO_EXTENSIONS = ('.mp4', '.webm')
_WDS_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic')
# Caption fields in order of preference: a plain-text caption wins over JSON
# metadata or a class label when a sample carries several of them.
_WDS_CAPTION_EXTENSIONS = ('.txt', '.caption', '.json', '.cls')


def _group_shard_members(
    tar: tarfile.TarFile
) -> Dict[str, List[Tuple[tarfile.TarInfo, str]]]:
    """Group the shard's non-directory members by WebDataset sample key.

    The key is the member's directory joined with the part of its basename
    before the first dot; the remainder (leading dot included) is the field
    extension, e.g. ``a/0001.seg.png`` -> key ``a/0001``, extension
    ``.seg.png``. Hidden members (basename starting with ``.``) and members
    whose basename contains no dot are skipped.
    """
    grouped: Dict[str, List[Tuple[tarfile.TarInfo, str]]] = {}
    for member in tar.getmembers():
        if member.isdir():
            continue
        member_path = Path(member.name)
        if member_path.name.startswith('.'):
            continue
        stem, dot, rest = member_path.name.partition('.')
        if not dot:
            continue
        # Keeping the directory in the key stops same-named samples stored in
        # different folders from being merged into one sample.
        key = str(member_path.parent / stem)
        grouped.setdefault(key, []).append((member, dot + rest))
    return grouped


def _select_sample_files(
    members: List[Tuple[tarfile.TarInfo, str]]
) -> Tuple[Optional[Tuple[tarfile.TarInfo, str]], Optional[tarfile.TarInfo]]:
    """Choose the media member (with its extension) and caption member of one sample.

    Videos are preferred over images, and captions follow
    ``_WDS_CAPTION_EXTENSIONS`` order, so the choice does not depend on the
    order in which members were written to the tar. Either item is ``None``
    when the sample has no matching field.
    """
    # If two members share an extension differing only in case, the later one wins.
    by_ext = {ext.lower(): (member, ext) for member, ext in members}
    media = next(
        (by_ext[ext] for ext in _WDS_VIDEO_EXTENSIONS + _WDS_IMAGE_EXTENSIONS
         if ext in by_ext),
        None,
    )
    caption = next(
        (by_ext[ext][0] for ext in _WDS_CAPTION_EXTENSIONS if ext in by_ext),
        None,
    )
    return media, caption


def _unique_target_path(target_dir: Path, stem: str, media_ext: str) -> Path:
    """Return ``target_dir/stem<media_ext>``, or a ``stem___N<media_ext>`` variant,
    such that neither the media path nor its ``.txt`` caption path exists yet.
    """
    candidate = target_dir / f"{stem}{media_ext}"
    counter = 1
    # The caption path is checked too: two samples with the same key basename
    # but different media extensions would otherwise share one caption file.
    while candidate.exists() or candidate.with_suffix('.txt').exists():
        candidate = target_dir / f"{stem}___{counter}{media_ext}"
        counter += 1
    return candidate


def process_webdataset_shard(
    tar_path: Path,
    videos_output_dir: Path,
    staging_output_dir: Path
) -> Tuple[int, int]:
    """Process a WebDataset shard (tar file), extracting video/image and caption pairs.

    Members are grouped into samples by key: the member's directory plus the
    part of its basename before the first dot (``dir/0001.jpg`` and
    ``dir/0001.txt`` form sample ``dir/0001``). For every sample that has a
    recognised media field:

    * a video (``.mp4``/``.webm``) is chosen over an image if both exist;
    * the media is streamed into ``videos_output_dir`` (videos) or
      ``staging_output_dir`` (images) under the key's basename, with a
      ``___N`` counter appended when that name or its ``.txt`` caption is
      already taken;
    * the caption, if any, is written next to it with a ``.txt`` suffix,
      preferring ``.txt`` > ``.caption`` > ``.json`` > ``.cls`` members and
      decoding as UTF-8 with undecodable bytes dropped.

    Members that cannot be read as files (``extractfile`` returns ``None``)
    are skipped with a warning. The output directories are not created here.

    Args:
        tar_path: Path to the WebDataset tar file
        videos_output_dir: Directory to store videos for splitting
        staging_output_dir: Directory to store images and captions

    Returns:
        Tuple of (video_count, image_count) for the media files written.

    Raises:
        Exception: any error while reading the shard or writing outputs is
            logged and re-raised; files written before the failure are kept.
    """
    video_count = 0
    image_count = 0
    logger.debug("videos_output_dir = %s", videos_output_dir)
    logger.debug("staging_output_dir = %s", staging_output_dir)

    try:
        # One open is enough: extractfile() locates each member by the data
        # offset recorded in its TarInfo, so grouping and extraction share it.
        with tarfile.open(tar_path, 'r') as tar:
            for key, members in _group_shard_members(tar).items():
                media, caption_member = _select_sample_files(members)
                if media is None:
                    continue
                media_member, media_ext = media

                media_stream = tar.extractfile(media_member)
                if media_stream is None:
                    logger.warning(
                        "Skipping %s in %s: not a regular file",
                        media_member.name, tar_path,
                    )
                    continue

                is_video = media_ext.lower() in _WDS_VIDEO_EXTENSIONS
                target_dir = videos_output_dir if is_video else staging_output_dir
                # Only the key's basename is used, so directory components in
                # member names never leave target_dir.
                target_path = _unique_target_path(target_dir, Path(key).name, media_ext)

                # Stream rather than read(): video members can be large.
                with media_stream, open(target_path, 'wb') as dst:
                    shutil.copyfileobj(media_stream, dst)

                if caption_member is not None:
                    caption_stream = tar.extractfile(caption_member)
                    if caption_stream is not None:
                        with caption_stream:
                            caption_text = caption_stream.read().decode(
                                'utf-8', errors='ignore'
                            )
                        target_path.with_suffix('.txt').write_text(
                            caption_text, encoding='utf-8'
                        )

                if is_video:
                    video_count += 1
                else:
                    image_count += 1
    except Exception as e:
        logger.error("Error processing WebDataset file %s: %s", tar_path, e)
        raise

    return video_count, image_count