import gradio as gr
from pathlib import Path
import json
import logging
import shutil
from typing import Any, Optional, Dict, List, Union, Tuple

from ..config import (
    STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES,
    DEFAULT_VALIDATION_NB_STEPS,
    DEFAULT_VALIDATION_HEIGHT,
    DEFAULT_VALIDATION_WIDTH,
    DEFAULT_VALIDATION_NB_FRAMES,
    DEFAULT_VALIDATION_FRAMERATE
)

from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file

logger = logging.getLogger(__name__)

def prepare_finetrainers_dataset() -> Tuple[Optional[Path], Optional[Path]]:
    """Prepare a Finetrainers-compatible dataset structure.

    Creates:
        training/
        ├── prompts.txt   # All captions, one per line
        ├── videos.txt    # All video paths, one per line
        └── videos/       # Directory containing all mp4 files
            ├── 00000.mp4
            ├── 00001.mp4
            └── ...

    Returns:
        Tuple of (videos_file_path, prompts_file_path), or (None, None) if no
        valid video/caption pairs were found.
    """
    # Make sure the videos subdirectory exists
    TRAINING_VIDEOS_PATH.mkdir(exist_ok=True)

    # Clear existing training list files
    for f in TRAINING_PATH.glob("*"):
        if f.is_file() and f.name in ["videos.txt", "prompts.txt", "prompt.txt"]:
            f.unlink()

    videos_file = TRAINING_PATH / "videos.txt"
    prompts_file = TRAINING_PATH / "prompts.txt"  # Finetrainers can use either prompts.txt or prompt.txt

    media_files = []
    captions = []

    # Process all video files from the videos subdirectory
    for file in sorted(TRAINING_VIDEOS_PATH.glob("*.mp4")):
        caption_file = file.with_suffix('.txt')
        if caption_file.exists():
            # Normalize the caption to a single line
            caption = caption_file.read_text().strip()
            caption = ' '.join(caption.split())

            # Use a path relative to the training root
            relative_path = f"videos/{file.name}"
            media_files.append(relative_path)
            captions.append(caption)

    # Write the list files only if we have content
    if media_files and captions:
        videos_file.write_text('\n'.join(media_files))
        prompts_file.write_text('\n'.join(captions))
        logger.info(f"Created dataset with {len(media_files)} video/caption pairs")
    else:
        logger.warning("No valid video/caption pairs found in training directory")
        return None, None

    # Verify that the generated files stay line-aligned
    with open(videos_file) as vf:
        video_lines = [l.strip() for l in vf.readlines() if l.strip()]
    with open(prompts_file) as pf:
        prompt_lines = [l.strip() for l in pf.readlines() if l.strip()]

    if len(video_lines) != len(prompt_lines):
        logger.error(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
        return None, None

    return videos_file, prompts_file

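# Illustrative usage sketch (comment only, not executed by this module): callers
# are expected to handle the (None, None) case before launching a training run.
# Assuming the module-level logger defined above:
#
#     videos_file, prompts_file = prepare_finetrainers_dataset()
#     if videos_file is None:
#         logger.error("No dataset generated, aborting training")
#
# Each line N of videos.txt (e.g. "videos/00000.mp4") pairs with line N of
# prompts.txt; the file names shown here are hypothetical.
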
def copy_files_to_training_dir(prompt_prefix: str) -> int:
    """Copy staged media/caption pairs into the training directory, without deleting anything."""
    gr.Info("Copying assets to the training dataset...")

    # Gather all staged videos and images
    video_files = list(STAGING_PATH.glob("*.mp4"))
    image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
    all_files = video_files + image_files

    nb_copied_pairs = 0

    for file_path in all_files:
        caption = ""
        file_caption_path = file_path.with_suffix('.txt')
        if file_caption_path.exists():
            logger.debug(f"Found caption file: {file_caption_path}")
            caption = file_caption_path.read_text()

        # Get the parent scene caption if this file is a clip
        parent_caption = ""
        if "___" in file_path.stem:
            parent_name, _ = extract_scene_info(file_path.stem)
            parent_caption_path = STAGING_PATH / f"{parent_name}.txt"
            if parent_caption_path.exists():
                logger.debug(f"Found parent caption file: {parent_caption_path}")
                parent_caption = parent_caption_path.read_text().strip()

        target_file_path = TRAINING_VIDEOS_PATH / file_path.name
        target_caption_path = target_file_path.with_suffix('.txt')

        if parent_caption and not caption.endswith(parent_caption):
            caption = f"{caption}\n{parent_caption}"

        # Add FPS information for videos, but only if not already present
        if is_video_file(file_path) and caption:
            if not any("FPS, " in line for line in caption.split('\n')):
                fps_info = get_video_fps(file_path)
                if fps_info:
                    caption = f"{fps_info}{caption}"

        if prompt_prefix and not caption.startswith(prompt_prefix):
            caption = f"{prompt_prefix}{caption}"

        # Only copy over valid (media, caption) pairs
        if caption:
            try:
                target_caption_path.write_text(caption)
                shutil.copy2(file_path, target_file_path)
                nb_copied_pairs += 1
            except Exception as e:
                logger.warning(f"Failed to copy one of the pairs: {e}")

    prepare_finetrainers_dataset()

    gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")

    return nb_copied_pairs

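# Caption assembly note (descriptive sketch, assuming get_video_fps() returns a
# prefix string such as "24 FPS, "): for a hypothetical clip "scene___001.mp4",
# the caption written next to the copied video ends up roughly as
#
#     {prompt_prefix}{fps_info}{clip caption}\n{parent caption}
#
# i.e. the prefix and FPS string are prepended, and the parent scene caption is
# appended on a new line when it is not already part of the clip caption.
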
def create_validation_config() -> Optional[Path]:
    """Create a validation configuration JSON file for Finetrainers.

    Creates a validation dataset file with a subset of the training data.

    Returns:
        Path to the validation JSON file, or None if no training files exist
    """
    # Ensure the training dataset exists
    if not TRAINING_VIDEOS_PATH.exists() or not any(TRAINING_VIDEOS_PATH.glob("*.mp4")):
        logger.warning("No training videos found for validation")
        return None

    # Take a subset of the training videos (up to 4) for validation
    training_videos = list(TRAINING_VIDEOS_PATH.glob("*.mp4"))
    validation_videos = training_videos[:min(4, len(training_videos))]

    if not validation_videos:
        logger.warning("No validation videos selected")
        return None

    # Create validation data entries
    validation_data = {"data": []}

    for video_path in validation_videos:
        # Get the caption from the matching text file
        caption_path = video_path.with_suffix('.txt')
        if not caption_path.exists():
            logger.warning(f"Missing caption for {video_path}, skipping for validation")
            continue

        caption = caption_path.read_text().strip()

        # Build the entry with the default validation resolution and sampling settings
        try:
            data_entry = {
                "caption": caption,
                "image_path": "",  # No input image for text-to-video
                "video_path": str(video_path),
                "num_inference_steps": DEFAULT_VALIDATION_NB_STEPS,
                "height": DEFAULT_VALIDATION_HEIGHT,
                "width": DEFAULT_VALIDATION_WIDTH,
                "num_frames": DEFAULT_VALIDATION_NB_FRAMES,
                "frame_rate": DEFAULT_VALIDATION_FRAMERATE
            }
            validation_data["data"].append(data_entry)
        except Exception as e:
            logger.warning(f"Error adding validation entry for {video_path}: {e}")

    if not validation_data["data"]:
        logger.warning("No valid validation entries created")
        return None

    # Write the validation config to disk
    validation_file = OUTPUT_PATH / "validation_config.json"
    with open(validation_file, 'w') as f:
        json.dump(validation_data, f, indent=2)

    logger.info(f"Created validation config with {len(validation_data['data'])} entries")
    return validation_file
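
# Shape of the generated validation_config.json (illustrative sketch; the real
# values come from the DEFAULT_VALIDATION_* constants imported from ..config,
# and the path, caption, and numbers below are placeholders, not project defaults):
#
#     {
#       "data": [
#         {
#           "caption": "<caption text>",
#           "image_path": "",
#           "video_path": "<storage>/training/videos/00000.mp4",
#           "num_inference_steps": 50,
#           "height": 512,
#           "width": 768,
#           "num_frames": 49,
#           "frame_rate": 24
#         }
#       ]
#     }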