|
""" |
|
WebDataset format handling for Video Model Studio |
|
""" |
|
|
|
import os |
|
import tarfile |
|
import tempfile |
|
import logging |
|
from pathlib import Path |
|
from typing import List, Dict, Tuple, Optional |
|
|
|
from ..utils import is_image_file, is_video_file, extract_scene_info |
|
|
|
# Module-level logger, named after this module so handlers/levels can be
# configured per module through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
def is_webdataset_file(file_path: Path) -> bool:
    """Tell whether a path looks like a WebDataset shard.

    Only the final suffix is inspected, case-insensitively. The file is
    neither opened nor validated, so e.g. ``shard.tar.gz`` is rejected
    (its last suffix is ``.gz``).

    Args:
        file_path: Path to check

    Returns:
        bool: True if the last suffix is ``.tar`` in any letter case
    """
    extension = file_path.suffix
    return extension.lower() == '.tar'
|
|
|
# Media extensions recognised in a sample, ranked so that a video beats an
# image when a sample carries both (e.g. an .mp4 plus a .jpg preview).
# Equal ranks fall back to tar member order (the first one wins).
_VIDEO_EXTENSIONS = ('.mp4', '.webm')
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic')
_MEDIA_RANK = {
    **{ext: 0 for ext in _VIDEO_EXTENSIONS},
    **{ext: 1 for ext in _IMAGE_EXTENSIONS},
}

# Caption extensions in order of preference: a plain-text caption beats JSON
# metadata or a class label when a sample carries several of them.
_CAPTION_RANK = {
    ext: rank for rank, ext in enumerate(('.txt', '.caption', '.json', '.cls'))
}


def _group_shard_members(
    tar: tarfile.TarFile,
) -> Dict[str, List[Tuple[tarfile.TarInfo, str]]]:
    """Group the regular-file members of a shard by WebDataset sample key.

    The key is the member's directory plus its basename up to the first dot.
    Everything from that dot on is the extension, so ``x/a.jpg`` yields key
    ``x/a`` and extension ``.jpg``.

    Args:
        tar: Open tar archive to scan

    Returns:
        Mapping of sample key to ``(member, extension)`` pairs, in tar order
    """
    grouped: Dict[str, List[Tuple[tarfile.TarInfo, str]]] = {}
    for member in tar.getmembers():
        # Only regular files can be streamed with extractfile(). Directories,
        # links (possibly dangling) and special files are skipped rather than
        # crashing the whole shard during extraction.
        if not member.isfile():
            continue

        member_path = Path(member.name)
        file_name = member_path.name

        # Hidden files (e.g. macOS "._" resource forks) are not sample data.
        if file_name.startswith('.'):
            continue

        stem, dot, rest = file_name.partition('.')
        if not dot:
            # No extension: there is no way to tell what this member is.
            continue

        parent = member_path.parent
        key = str(parent / stem) if parent != Path('.') else stem
        grouped.setdefault(key, []).append((member, dot + rest))
    return grouped


def _pick_member(
    candidates: List[Tuple[tarfile.TarInfo, str]],
    ranks: Dict[str, int],
) -> Optional[Tuple[tarfile.TarInfo, str]]:
    """Return the candidate whose lower-cased extension has the lowest rank.

    Candidates whose extension is not in ``ranks`` are ignored. Among equal
    ranks the earliest candidate wins. Returns None when nothing matches.
    """
    eligible = [item for item in candidates if item[1].lower() in ranks]
    if not eligible:
        return None
    return min(eligible, key=lambda item: ranks[item[1].lower()])


def _unique_media_path(target_dir: Path, stem: str, media_ext: str) -> Path:
    """Return a media path in ``target_dir`` that overwrites nothing.

    Both the media file and its ``.txt`` caption sidecar must be unused.
    Otherwise ``___1``, ``___2``, ... is appended to the stem until they are.
    Checking the sidecar too stops two samples that flatten to the same stem
    but use different media extensions (``a.jpg`` / ``a.png``) from sharing,
    and overwriting, ``a.txt``.
    """
    candidate = target_dir / f"{stem}{media_ext}"
    counter = 1
    while candidate.exists() or candidate.with_suffix('.txt').exists():
        candidate = target_dir / f"{stem}___{counter}{media_ext}"
        counter += 1
    return candidate


def _copy_member(tar: tarfile.TarFile, member: tarfile.TarInfo, destination: Path) -> None:
    """Stream a regular-file member to ``destination`` in chunks."""
    import shutil  # stdlib; imported locally to keep this block self-contained

    # Both handles are closed deterministically. copyfileobj avoids holding
    # a whole (possibly large) video in memory.
    with tar.extractfile(member) as source, open(destination, 'wb') as target:
        shutil.copyfileobj(source, target)


def process_webdataset_shard(
    tar_path: Path,
    videos_output_dir: Path,
    staging_output_dir: Path
) -> Tuple[int, int]:
    """Process a WebDataset shard (tar file) extracting video/image and caption pairs

    Regular-file members are grouped into samples by key (directory plus
    basename up to the first dot). Hidden files, extension-less files and
    non-regular members (directories, links, devices) are ignored.

    For each sample:
      * One media file is chosen. A video (.mp4/.webm) is preferred over an
        image (.jpg/.jpeg/.png/.webp/.avif/.heic).
      * The media file is written under the key's basename to
        ``videos_output_dir`` (videos) or ``staging_output_dir`` (images).
      * An optional caption is chosen, preferring .txt, then .caption, .json,
        .cls. It is decoded as UTF-8 with undecodable bytes dropped and
        written next to the media file with a ``.txt`` suffix.

    Name clashes with existing media or caption files get a ``___N`` suffix.
    Samples without media are skipped.

    Args:
        tar_path: Path to the WebDataset tar file
        videos_output_dir: Directory to store videos for splitting
        staging_output_dir: Directory to store images and captions

    Returns:
        Tuple of (video_count, image_count): number of samples whose video or
        image was written

    Raises:
        Exception: Any error while reading the archive or writing outputs is
            logged and re-raised unchanged. Files already written are kept.
    """
    video_count = 0
    image_count = 0

    try:
        # A single open serves both the scan and the extraction.
        with tarfile.open(tar_path, 'r') as tar:
            for key, members in _group_shard_members(tar).items():
                media = _pick_member(members, _MEDIA_RANK)
                if media is None:
                    continue

                media_member, media_ext = media
                is_video = media_ext.lower() in _VIDEO_EXTENSIONS
                target_dir = videos_output_dir if is_video else staging_output_dir

                # Flatten to the key's basename so nested tar paths never
                # create, or escape into, subdirectories of target_dir.
                safe_prefix = Path(key).name
                target_path = _unique_media_path(target_dir, safe_prefix, media_ext)
                _copy_member(tar, media_member, target_path)

                caption = _pick_member(members, _CAPTION_RANK)
                if caption is not None:
                    with tar.extractfile(caption[0]) as source:
                        caption_text = source.read().decode('utf-8', errors='ignore')
                    target_path.with_suffix('.txt').write_text(caption_text, encoding='utf-8')

                if is_video:
                    video_count += 1
                else:
                    image_count += 1

    except Exception as e:
        logger.error("Error processing WebDataset file %s: %s", tar_path, e)
        raise

    return video_count, image_count