|
import os |
|
import shutil |
|
import subprocess |
|
from huggingface_hub import HfApi, create_repo |
|
from pathlib import Path |
|
import json |
|
import re |
|
import logging |
|
from typing import Any, Optional, Dict, List, Union, Tuple |
|
|
|
# Module-level logger; used to report non-fatal failures (e.g. ffprobe errors
# in get_video_fps) without interrupting the caller.
logger = logging.getLogger(__name__)
|
|
|
def make_archive(source: str | Path, destination: str | Path): |
|
source = str(source) |
|
destination = str(destination) |
|
|
|
base = os.path.basename(destination) |
|
name = base.split('.')[0] |
|
format = base.split('.')[1] |
|
archive_from = os.path.dirname(source) |
|
archive_to = os.path.basename(source.strip(os.sep)) |
|
shutil.make_archive(name, format, archive_from, archive_to) |
|
shutil.move('%s.%s'%(name,format), destination) |
|
|
|
def get_video_fps(video_path: Path) -> Optional[str]:
    """Get the average frame rate of a video's first video stream via ffprobe.

    Args:
        video_path: Path to the video file.

    Returns:
        The frame rate formatted as ``"<fps> FPS, "`` (e.g. ``"24 FPS, "``).
        Rational rates such as ``"30000/1001"`` are rounded to the nearest
        integer. Returns ``None`` in any of these cases:

        * ffprobe is missing or exits with an error;
        * ffprobe prints nothing;
        * the rate cannot be parsed;
        * the rate is not a positive number.
    """
    cmd = [
        'ffprobe',
        '-v', 'error',
        '-select_streams', 'v:0',
        '-show_entries', 'stream=avg_frame_rate',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        str(video_path),
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.warning("Error getting FPS for %s: %s", video_path, result.stderr)
            return None

        lines = result.stdout.strip().splitlines()
        if not lines:
            # ffprobe can exit successfully yet print nothing (e.g. no
            # matching video stream); don't return a bare " FPS, ".
            return None
        fps = lines[0].strip()

        if '/' in fps:
            # avg_frame_rate is reported as a rational "num/den".
            num, den = map(int, fps.split('/'))
            if den == 0:
                return None  # e.g. "0/0" when the rate could not be determined
            rounded = round(num / den)
            if rounded <= 0:
                return None
            fps = str(rounded)

        return f"{fps} FPS, "

    except Exception as e:
        # Best effort: a missing ffprobe binary (FileNotFoundError) or an
        # unparsable rate (ValueError) must not break the caller.
        logger.warning("Failed to get FPS for %s: %s", video_path, e)
        return None
|
|
|
def extract_scene_info(filename: str) -> Tuple[str, Optional[int]]:
    """Split a scene-clip filename into its source name and scene number.

    The extension is discarded. A stem ending in ``___<digits>`` is split at
    that marker, and the digits are parsed as an int.

    Args:
        filename: Name or path such as ``"my_cool_video_1___001.mp4"``.

    Returns:
        ``(base_name, scene_number)``, e.g. ``("my_cool_video_1", 1)``.
        If the stem has no ``___<digits>`` suffix, returns the whole stem and
        ``None``.
    """
    stem = Path(filename).stem
    found = re.search(r'(.+?)___(\d+)$', stem)
    if found is None:
        return stem, None
    base_name, scene_digits = found.groups()
    return base_name, int(scene_digits)
|
|
|
def is_image_file(file_path: Path) -> bool:
    """Tell whether ``file_path`` carries a supported still-image extension.

    Only the final suffix is inspected, case-insensitively; the file is
    not opened.

    Args:
        file_path: Path whose suffix is checked.

    Returns:
        bool: True for .jpg, .jpeg, .png, .webp, .avif or .heic.
    """
    suffix = file_path.suffix.lower()
    return suffix in ('.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic')
|
|
|
def is_video_file(file_path: Path) -> bool:
    """Tell whether ``file_path`` carries a supported video extension.

    Only the final suffix is inspected, case-insensitively; the file is
    not opened.

    Args:
        file_path: Path whose suffix is checked.

    Returns:
        bool: True for .mp4 or .webm.
    """
    suffix = file_path.suffix.lower()
    return suffix == '.mp4' or suffix == '.webm'
|
|
|
def parse_bool_env(env_value: Optional[str]) -> bool:
    """Interpret an environment-variable string as a boolean.

    Truthy spellings, compared case-insensitively: "true", "t", "yes",
    "y", "1". Every other value yields False, including "false", "0",
    the empty string and None.
    """
    truthy = {'true', '1', 't', 'y', 'yes'}
    return bool(env_value) and str(env_value).lower() in truthy
|
|
|
def validate_model_repo(repo_id: str) -> Dict[str, Optional[str]]:
    """Validate a Hugging Face model repository id.

    A valid id contains exactly one ``/``, separating a non-empty owner from
    a non-empty model name. It must contain no whitespace and none of the
    characters in ``<>:"\\|?*``.

    Args:
        repo_id: Repository ID in format "username/model-name".

    Returns:
        ``{"error": <message>}`` describing the first problem found, or
        ``{"error": None}`` when the id is valid. The dict has the same shape
        in both cases, so callers can always read ``result["error"]``.
    """
    if not repo_id:
        return {"error": "Repository name is required"}

    owner, sep, model_name = repo_id.partition("/")
    # Require exactly one separator with text on both sides of it.
    if not sep or not owner or not model_name or "/" in model_name:
        return {"error": "Repository name must be in format username/model-name"}

    # '/' is deliberately absent here: it is the mandatory owner/name
    # separator (checked above). Listing it would reject every valid id.
    invalid_chars = set('<>:"\\|?*')
    if any(c in invalid_chars or c.isspace() for c in repo_id):
        return {"error": "Repository name contains invalid characters"}

    return {"error": None}
|
|
|
def save_to_hub(model_path: Path, repo_id: str, token: str, commit_message: str = "Update model") -> bool:
    """Upload the contents of ``model_path`` to a Hugging Face model repository.

    The repository is created first if needed (``exist_ok=True``, so an
    existing repo is reused). Any failure is logged through the module
    logger and reported as ``False``; nothing is raised.

    Args:
        model_path: Folder whose files are uploaded.
        repo_id: Repository ID (username/model-name).
        token: HuggingFace API token.
        commit_message: Commit message for the upload.

    Returns:
        bool: True if the upload succeeded. False if ``repo_id`` is invalid,
        the repository could not be created, or the upload failed.
    """
    try:
        validation = validate_model_repo(repo_id)
        if validation["error"]:
            # Previously this failed silently; surface the reason.
            logger.error("Invalid repository id %r: %s", repo_id, validation["error"])
            return False

        api = HfApi(token=token)

        try:
            create_repo(repo_id, token=token, repo_type="model", exist_ok=True)
        except Exception:
            logger.exception("Error creating repo %s", repo_id)
            return False

        api.upload_folder(
            folder_path=str(model_path),
            repo_id=repo_id,
            repo_type="model",
            commit_message=commit_message,
        )
        return True

    except Exception:
        # Top-level boundary: callers expect a bool, never an exception.
        logger.exception("Error uploading to hub")
        return False
|
|
|
def parse_training_log(line: str) -> Dict[str, Union[int, float]]:
    """Parse ``key=value`` training metrics out of a single log line.

    Recognised keys and the dict entries they produce:

    * ``step=``  -> ``"step"`` (int)
    * ``epoch=`` -> ``"epoch"`` (int)
    * ``loss=``  -> ``"loss"`` (float)
    * ``lr=``    -> ``"learning_rate"`` (float)

    For each key, the value is the whitespace-delimited token right after its
    first occurrence, with surrounding commas stripped. Keys are matched as
    plain substrings, so ``val_loss=`` also satisfies ``loss=``.

    Each metric is parsed independently. A missing or malformed value (e.g.
    ``epoch=1/10``) omits only that metric; the rest of the line is still
    parsed.

    Args:
        line: Log line from training output.

    Returns:
        Dict with whichever of the metrics above could be parsed (possibly empty).
    """
    metrics: Dict[str, Union[int, float]] = {}
    if not isinstance(line, str) or not line:
        return metrics

    for key, name, convert in (
        ("step=", "step", int),
        ("epoch=", "epoch", int),
        ("loss=", "loss", float),
        ("lr=", "learning_rate", float),
    ):
        if key not in line:
            continue
        tokens = line.split(key)[1].split()
        if not tokens:  # key at the very end of the line, no value
            continue
        try:
            metrics[name] = convert(tokens[0].strip(","))
        except ValueError:
            continue  # malformed value: skip just this metric

    return metrics
|
|
|
def format_size(size_bytes: int) -> str:
    """Render a byte count using 1024-based units.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (Tb) is reached. Plain byte counts are shown as whole numbers;
    scaled values get one decimal place.

    Args:
        size_bytes: Size in bytes

    Returns:
        Formatted string such as "512 bytes" or "1.5 Gb".
    """
    labels = ('bytes', 'Kb', 'Mb', 'Gb', 'Tb')
    top = len(labels) - 1
    value = float(size_bytes)
    level = 0
    while level < top and value >= 1024:
        value /= 1024
        level += 1

    label = labels[level]
    return f"{int(value)} {label}" if level == 0 else f"{value:.1f} {label}"
|
|
|
|
|
def count_media_files(path: Path) -> Tuple[int, int, int]:
    """Count the video and image files directly inside ``path``.

    Only the top level of ``path`` is scanned (no recursion). The following
    entries are skipped:

    * names starting with ``.``;
    * ``.txt`` files;
    * anything that is not a regular file, such as directories and broken
      symlinks.

    Args:
        path: Directory to scan.

    Returns:
        Tuple of (video_count, image_count, total_size). ``total_size`` is
        the combined size in bytes of the counted files.
    """
    video_count = 0
    image_count = 0
    total_size = 0

    for file in path.glob("*"):
        if file.name.startswith('.') or file.suffix.lower() == '.txt':
            continue

        is_video = is_video_file(file)
        if not is_video and not is_image_file(file):
            continue

        # glob("*") also yields directories (a folder named "clip.mp4" would
        # otherwise be counted) and broken symlinks (whose stat() raises).
        if not file.is_file():
            continue

        if is_video:
            video_count += 1
        else:
            image_count += 1
        total_size += file.stat().st_size

    return video_count, image_count, total_size
|
|
|
def format_media_title(action: str, video_count: int, image_count: int, total_size: int) -> str:
    """Build a Markdown H2 heading summarising the media about to be processed.

    Photos are listed before videos. Each positive count is shown with
    thousands separators and pluralised, followed by the formatted total
    size, e.g. ``"## 3 photos and 1 video to caption (12.0 Mb)"``. When
    neither count is positive the heading is
    ``"## 0 files to <action> (0 bytes)"``, whatever ``total_size`` is.

    Args:
        action: Action (eg "split", "caption")
        video_count: Number of videos
        image_count: Number of images
        total_size: Total size in bytes

    Returns:
        Formatted title string
    """
    def _describe(count: int, noun: str) -> str:
        plural = '' if count == 1 else 's'
        return f"{count:,} {noun}{plural}"

    pieces = [
        _describe(count, noun)
        for count, noun in ((image_count, 'photo'), (video_count, 'video'))
        if count > 0
    ]

    if not pieces:
        return f"## 0 files to {action} (0 bytes)"

    summary = ' and '.join(pieces)
    return f"## {summary} to {action} ({format_size(total_size)})"
|
|
|
def add_prefix_to_caption(caption: str, prefix: str) -> str:
    """Prepend ``prefix`` to ``caption`` unless the caption already starts with it.

    An empty or None ``caption`` or ``prefix`` returns ``caption`` unchanged.
    """
    needs_prefix = bool(prefix) and bool(caption) and not caption.startswith(prefix)
    return f"{prefix}{caption}" if needs_prefix else caption
|
|
|
def format_time(seconds: float) -> str:
    """Format a duration in seconds as a compact human-readable string.

    Zero-valued hour/minute/second components are omitted and fractional
    seconds are truncated:

    * ``3725`` -> ``"1h 2m 5s"``
    * ``3600`` -> ``"1h"``
    * ``0.4``  -> ``"0s"``

    Negative durations are clamped to zero and render as ``"0s"``.

    Args:
        seconds: Time in seconds.

    Returns:
        Formatted string (e.g. "2h 30m 45s").
    """
    # Floor division/modulo on a negative value wraps around (-5 would
    # render as "59m 55s"), so treat anything below zero as zero.
    total = int(max(seconds, 0))
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)

    parts = []
    if hours > 0:
        parts.append(f"{hours}h")
    if minutes > 0:
        parts.append(f"{minutes}m")
    if secs > 0 or not parts:
        parts.append(f"{secs}s")

    return " ".join(parts)