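"""Training service for the vms app.

Launches and monitors finetrainers training runs, persists UI state and
session data, and manages checkpoints and LoRA weight cleanup.
"""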
import os
import sys
import json
import time
import shutil
import gradio as gr
from pathlib import Path
from datetime import datetime
import subprocess
import signal
import psutil
import tempfile
import zipfile
import logging
import traceback
import threading
import fcntl
import select

from typing import Any, Optional, Dict, List, Union, Tuple

from huggingface_hub import upload_folder, create_repo

from vms.config import (
    TrainingConfig, RESOLUTION_OPTIONS, SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS,
    STORAGE_PATH, HF_API_TOKEN,
    MODEL_TYPES, TRAINING_TYPES, MODEL_VERSIONS,
    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
    DEFAULT_LEARNING_RATE,
    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR,
    DEFAULT_SEED, DEFAULT_RESHAPE_MODE,
    DEFAULT_REMOVE_COMMON_LLM_CAPTION_PREFIXES,
    DEFAULT_DATASET_TYPE, DEFAULT_PROMPT_PREFIX,
    DEFAULT_MIXED_PRECISION, DEFAULT_TRAINING_TYPE,
    DEFAULT_NUM_GPUS,
    DEFAULT_MAX_GPUS,
    DEFAULT_PRECOMPUTATION_ITEMS,
    DEFAULT_NB_LR_WARMUP_STEPS,
    DEFAULT_AUTO_RESUME,
    DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
    DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
    DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK,
    generate_model_project_id
)
from vms.utils import (
    get_available_gpu_count,
    make_archive,
    parse_training_log,
    is_image_file,
    is_video_file,
    prepare_finetrainers_dataset,
    copy_files_to_training_dir
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class TrainingService:
    def __init__(self, app=None):
        """Initialize the training service

        Args:
            app: Reference to main application
        """
        self.app = app
        self.file_lock = threading.Lock()
        self.file_handler = None
        self.setup_logging()
        self.ensure_valid_ui_state_file()

        # Start background cleanup task
        self._cleanup_stop_event = threading.Event()
        self._cleanup_thread = threading.Thread(target=self._background_cleanup_task, daemon=True)
        self._cleanup_thread.start()

        logger.info("Training service initialized")
    def setup_logging(self):
        """Set up logging with proper handler management"""
        global logger
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)

        # Remove any existing handlers to avoid duplicates
        logger.handlers.clear()

        # Add stdout handler
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ))
        logger.addHandler(stdout_handler)

        # Add file handler if log file is accessible
        try:
            # Close existing file handler if it exists
            if self.file_handler:
                self.file_handler.close()
                logger.removeHandler(self.file_handler)

            self.file_handler = logging.FileHandler(str(self.app.log_file_path))
            self.file_handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ))
            logger.addHandler(self.file_handler)
        except Exception as e:
            logger.warning(f"Could not set up log file: {e}")
    def clear_logs(self) -> None:
        """Clear log file with proper handler cleanup"""
        try:
            # Remove and close the file handler
            if self.file_handler:
                logger.removeHandler(self.file_handler)
                self.file_handler.close()
                self.file_handler = None

            # Delete the file if it exists
            if self.app.log_file_path.exists():
                self.app.log_file_path.unlink()

            # Recreate logging setup
            self.setup_logging()
            self.append_log("Log file cleared and recreated")
        except Exception as e:
            logger.error(f"Error clearing logs: {e}")
            raise
    def __del__(self):
        """Cleanup when the service is destroyed"""
        if self.file_handler:
            self.file_handler.close()
    def update_project_state(self, state_updates: Dict[str, Any]) -> None:
        """Update project state in UI state file

        Args:
            state_updates: Dict of state values to update
        """
        current_state = self.load_ui_state()
        current_state.update(state_updates)
        self.save_ui_state(current_state)
        logger.info(f"Updated project state: {state_updates}")
    def save_ui_state(self, values: Dict[str, Any]) -> None:
        """Save current UI state to file with validation"""
        # Use a lock to prevent concurrent writes
        with self.file_lock:
            # Validate values before saving
            validated_values = {}
            default_state = self.get_default_ui_state()

            # Copy default values first
            validated_values = default_state.copy()

            # Update with provided values, converting types as needed
            for key, value in values.items():
                if key in default_state:
                    if key == "train_steps":
                        try:
                            validated_values[key] = int(value)
                        except (ValueError, TypeError):
                            validated_values[key] = default_state[key]
                    elif key == "batch_size":
                        try:
                            validated_values[key] = int(value)
                        except (ValueError, TypeError):
                            validated_values[key] = default_state[key]
                    elif key == "learning_rate":
                        try:
                            validated_values[key] = float(value)
                        except (ValueError, TypeError):
                            validated_values[key] = default_state[key]
                    elif key == "save_iterations":
                        try:
                            validated_values[key] = int(value)
                        except (ValueError, TypeError):
                            validated_values[key] = default_state[key]
                    elif key == "lora_rank" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
                        validated_values[key] = default_state[key]
                    elif key == "lora_alpha" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
                        validated_values[key] = default_state[key]
                    else:
                        validated_values[key] = value

            try:
                # First verify we can serialize to JSON
                json_data = json.dumps(validated_values, indent=2)

                # Write to the file
                with open(self.app.output_ui_state_file, 'w') as f:
                    f.write(json_data)

                logger.debug("UI state saved successfully")
            except Exception as e:
                logger.error(f"Error saving UI state: {str(e)}")
    def _backup_and_recreate_ui_state(self, default_state):
        """Backup the corrupted UI state file and create a new one with defaults"""
        try:
            # Create a backup with timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_file = self.app.output_ui_state_file.with_suffix(f'.json.bak_{timestamp}')

            # Copy the corrupted file
            shutil.copy2(self.app.output_ui_state_file, backup_file)
            logger.info(f"Backed up corrupted UI state file to {backup_file}")
        except Exception as backup_error:
            logger.error(f"Failed to backup corrupted UI state file: {str(backup_error)}")

        # Create a new file with default values
        self.save_ui_state(default_state)
        logger.info("Created new UI state file with default values after error")
    def get_default_ui_state(self) -> Dict[str, Any]:
        """Get a default UI state with robust error handling"""
        default_state = {
            "model_project_id": self.app.current_model_project_id if self.app.current_model_project_id else generate_model_project_id(),
            "project_status": self.app.current_model_project_status if self.app.current_model_project_status else "draft",
            "model_type": list(MODEL_TYPES.keys())[0],
            "model_version": "",
            "training_type": list(TRAINING_TYPES.keys())[0],
            "lora_rank": DEFAULT_LORA_RANK_STR,
            "lora_alpha": DEFAULT_LORA_ALPHA_STR,
            "train_steps": DEFAULT_NB_TRAINING_STEPS,
            "batch_size": DEFAULT_BATCH_SIZE,
            "learning_rate": DEFAULT_LEARNING_RATE,
            "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
            "resolution": list(RESOLUTION_OPTIONS.keys())[0],
            "num_gpus": DEFAULT_NUM_GPUS,
            "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
            "auto_resume": DEFAULT_AUTO_RESUME,

            # Control parameters
            "control_type": DEFAULT_CONTROL_TYPE,
            "train_qk_norm": DEFAULT_TRAIN_QK_NORM,
            "frame_conditioning_type": DEFAULT_FRAME_CONDITIONING_TYPE,
            "frame_conditioning_index": DEFAULT_FRAME_CONDITIONING_INDEX,
            "frame_conditioning_concatenate_mask": DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK
        }
        return default_state
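    # The UI state is persisted as plain JSON in self.app.output_ui_state_file.
    # Illustrative shape (actual values come from the defaults above and from the UI):
    #     {
    #       "model_project_id": "<generated id>",
    #       "project_status": "draft",
    #       "model_type": "<first key of MODEL_TYPES>",
    #       "lora_rank": "<DEFAULT_LORA_RANK_STR>",
    #       "train_steps": <DEFAULT_NB_TRAINING_STEPS>,
    #       ...
    #     }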
    def load_ui_state(self) -> Dict[str, Any]:
        """Load saved UI state with robust error handling"""
        default_state = self.get_default_ui_state()

        # Use lock for reading too to avoid reading during a write
        with self.file_lock:
            if not self.app.output_ui_state_file.exists():
                logger.info("UI state file does not exist, using default values")
                return default_state

            try:
                # First check if the file is empty
                file_size = self.app.output_ui_state_file.stat().st_size
                if file_size == 0:
                    logger.warning("UI state file exists but is empty, using default values")
                    return default_state

                with open(self.app.output_ui_state_file, 'r') as f:
                    file_content = f.read().strip()

                    if not file_content:
                        logger.warning("UI state file is empty or contains only whitespace, using default values")
                        return default_state

                    try:
                        saved_state = json.loads(file_content)
                    except json.JSONDecodeError as e:
                        logger.error(f"Error parsing UI state JSON: {str(e)}")
                        # Instead of showing the error, recreate the file with defaults
                        self._backup_and_recreate_ui_state(default_state)
                        return default_state

                # Clean up model type if it contains " (LoRA)" suffix
                if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
                    saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
                    logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")

                # Convert numeric values to appropriate types
                if "train_steps" in saved_state:
                    try:
                        saved_state["train_steps"] = int(saved_state["train_steps"])
                    except (ValueError, TypeError):
                        saved_state["train_steps"] = default_state["train_steps"]
                        logger.warning("Invalid train_steps value, using default")

                if "batch_size" in saved_state:
                    try:
                        saved_state["batch_size"] = int(saved_state["batch_size"])
                    except (ValueError, TypeError):
                        saved_state["batch_size"] = default_state["batch_size"]
                        logger.warning("Invalid batch_size value, using default")

                if "learning_rate" in saved_state:
                    try:
                        saved_state["learning_rate"] = float(saved_state["learning_rate"])
                    except (ValueError, TypeError):
                        saved_state["learning_rate"] = default_state["learning_rate"]
                        logger.warning("Invalid learning_rate value, using default")

                if "save_iterations" in saved_state:
                    try:
                        saved_state["save_iterations"] = int(saved_state["save_iterations"])
                    except (ValueError, TypeError):
                        saved_state["save_iterations"] = default_state["save_iterations"]
                        logger.warning("Invalid save_iterations value, using default")

                # Make sure we have all keys (in case structure changed)
                merged_state = default_state.copy()
                merged_state.update({k: v for k, v in saved_state.items() if v is not None})

                # Validate model_type is in available choices
                if merged_state["model_type"] not in MODEL_TYPES:
                    # Try to map from internal name
                    model_found = False
                    for display_name, internal_name in MODEL_TYPES.items():
                        if internal_name == merged_state["model_type"]:
                            merged_state["model_type"] = display_name
                            model_found = True
                            break

                    # If still not found, use default
                    if not model_found:
                        merged_state["model_type"] = default_state["model_type"]
                        logger.warning("Invalid model type in saved state, using default")

                # Validate model_version is appropriate for model_type
                if "model_type" in merged_state and "model_version" in merged_state:
                    model_internal_type = MODEL_TYPES.get(merged_state["model_type"])
                    if model_internal_type:
                        valid_versions = MODEL_VERSIONS.get(model_internal_type, {}).keys()
                        if merged_state["model_version"] not in valid_versions:
                            # Set to default for this model type
                            from vms.ui.project.tabs.train_tab import TrainTab
                            train_tab = TrainTab(None)  # Temporary instance just for the helper method
                            merged_state["model_version"] = train_tab.get_default_model_version(saved_state["model_type"])
                            logger.warning(f"Invalid model version for {merged_state['model_type']}, using default")

                # Validate training_type is in available choices
                if merged_state["training_type"] not in TRAINING_TYPES:
                    # Try to map from internal name
                    training_found = False
                    for display_name, internal_name in TRAINING_TYPES.items():
                        if internal_name == merged_state["training_type"]:
                            merged_state["training_type"] = display_name
                            training_found = True
                            break

                    # If still not found, use default
                    if not training_found:
                        merged_state["training_type"] = default_state["training_type"]
                        logger.warning("Invalid training type in saved state, using default")

                # Validate resolution is in available choices
                if "resolution" in merged_state and merged_state["resolution"] not in RESOLUTION_OPTIONS:
                    merged_state["resolution"] = default_state["resolution"]
                    logger.warning("Invalid resolution in saved state, using default")

                # Validate lora_rank is in allowed values
                if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
                    merged_state["lora_rank"] = default_state["lora_rank"]
                    logger.warning("Invalid lora_rank in saved state, using default")

                # Validate lora_alpha is in allowed values
                if merged_state.get("lora_alpha") not in ["16", "32", "64", "128", "256", "512", "1024"]:
                    merged_state["lora_alpha"] = default_state["lora_alpha"]
                    logger.warning("Invalid lora_alpha in saved state, using default")

                return merged_state

            except Exception as e:
                logger.error(f"Error loading UI state: {str(e)}")
                # If anything goes wrong, backup and recreate
                self._backup_and_recreate_ui_state(default_state)
                return default_state
    def ensure_valid_ui_state_file(self):
        """Ensure UI state file exists and is valid JSON"""
        default_state = self.get_default_ui_state()

        # If file doesn't exist, create it with default values
        if not self.app.output_ui_state_file.exists():
            logger.info("Creating new UI state file with default values")
            self.save_ui_state(default_state)
            return

        # Check if file is valid JSON
        try:
            # First check if the file is empty
            file_size = self.app.output_ui_state_file.stat().st_size
            if file_size == 0:
                logger.warning("UI state file exists but is empty, recreating with default values")
                self.save_ui_state(default_state)
                return

            with open(self.app.output_ui_state_file, 'r') as f:
                file_content = f.read().strip()

                if not file_content:
                    logger.warning("UI state file is empty or contains only whitespace, recreating with default values")
                    self.save_ui_state(default_state)
                    return

                # Try to parse the JSON content
                try:
                    saved_state = json.loads(file_content)
                    logger.debug("UI state file validation successful")
                except json.JSONDecodeError as e:
                    # JSON parsing failed, backup and recreate
                    logger.error(f"Error parsing UI state JSON: {str(e)}")
                    self._backup_and_recreate_ui_state(default_state)
                    return
        except Exception as e:
            # Any other error (file access, etc)
            logger.error(f"Error checking UI state file: {str(e)}")
            self._backup_and_recreate_ui_state(default_state)
            return
    # Note: save_session also stores the UI state at training start
    def save_session(self, params: Dict) -> None:
        """Save training session parameters"""
        session_data = {
            "timestamp": datetime.now().isoformat(),
            "params": params,
            "status": self.get_status(),
            # Add UI state at the time training started
            "initial_ui_state": self.load_ui_state()
        }
        with open(self.app.output_session_file, 'w') as f:
            json.dump(session_data, f, indent=2)

    def load_session(self) -> Optional[Dict]:
        """Load saved training session"""
        if self.app.output_session_file.exists():
            try:
                with open(self.app.output_session_file, 'r') as f:
                    return json.load(f)
            except json.JSONDecodeError:
                return None
        return None
    def get_status(self) -> Dict:
        """Get current training status"""
        default_status = {'status': 'stopped', 'message': 'No training in progress'}

        if not self.app.output_status_file.exists():
            return default_status

        try:
            with open(self.app.output_status_file, 'r') as f:
                status = json.load(f)

            # Check if process is actually running
            if self.app.output_pid_file.exists():
                with open(self.app.output_pid_file, 'r') as f:
                    pid = int(f.read().strip())
                if not psutil.pid_exists(pid):
                    # Process died unexpectedly
                    if status['status'] == 'training':
                        # Only log this once by checking if we've already updated the status
                        if not hasattr(self, '_process_terminated_logged') or not self._process_terminated_logged:
                            self.append_log("Training process terminated unexpectedly")
                            self._process_terminated_logged = True
                        status['status'] = 'error'
                        status['message'] = 'Training process terminated unexpectedly'
                        # Update the status file to avoid repeated logging
                        with open(self.app.output_status_file, 'w') as f:
                            json.dump(status, f, indent=2)
                    else:
                        status['status'] = 'stopped'
                        status['message'] = 'Training process not found'
            return status
        except (json.JSONDecodeError, ValueError):
            return default_status
    def get_logs(self, max_lines: int = 100) -> str:
        """Get training logs with line limit"""
        if self.app.output_log_file.exists():
            with open(self.app.output_log_file, 'r') as f:
                lines = f.readlines()
                return ''.join(lines[-max_lines:])
        return ""

    def append_log(self, message: str) -> None:
        """Append message to log file and logger"""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(self.app.output_log_file, 'a') as f:
            f.write(f"[{timestamp}] {message}\n")
        logger.info(message)
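    # Log lines written by append_log look like (timestamp illustrative):
    #     [2025-01-01 12:00:00] Started lora training for ltx_video model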
    def validate_training_config(self, config: TrainingConfig, model_type: str) -> Optional[str]:
        """Validate training configuration"""
        logger.info(f"Validating config for {model_type}")

        try:
            # Basic validation
            if not config.output_dir:
                return "Output directory not specified"

            # For the dataset_config validation, we now expect it to be a JSON file
            dataset_config_path = Path(config.data_root)
            if not dataset_config_path.exists():
                return f"Dataset config file does not exist: {dataset_config_path}"

            # Check the JSON file is valid
            try:
                with open(dataset_config_path, 'r') as f:
                    dataset_json = json.load(f)

                # Basic validation of the JSON structure
                if "datasets" not in dataset_json or not isinstance(dataset_json["datasets"], list) or len(dataset_json["datasets"]) == 0:
                    return "Invalid dataset config JSON: missing or empty 'datasets' array"
            except json.JSONDecodeError:
                return f"Invalid JSON in dataset config file: {dataset_config_path}"
            except Exception as e:
                return f"Error reading dataset config file: {str(e)}"

            # Check training videos directory exists
            if not self.app.training_videos_path.exists():
                return f"Training videos directory does not exist: {self.app.training_videos_path}"

            # Validate file counts
            video_count = len(list(self.app.training_videos_path.glob('*.mp4')))
            if video_count == 0:
                return "No training files found"

            # Model-specific validation
            if model_type == "hunyuan_video":
                if config.batch_size > 2:
                    return "Hunyuan model recommended batch size is 1-2"
                if not config.gradient_checkpointing:
                    return "Gradient checkpointing is required for Hunyuan model"
            elif model_type == "ltx_video":
                if config.batch_size > 4:
                    return "LTX model recommended batch size is 1-4"
            elif model_type == "wan":
                if config.batch_size > 4:
                    return "Wan model recommended batch size is 1-4"

            logger.info(f"Config validation passed with {video_count} training files")
            return None

        except Exception as e:
            logger.error(f"Error during config validation: {str(e)}")
            return f"Configuration validation failed: {str(e)}"
    def start_training(
        self,
        model_type: str,
        lora_rank: str,
        lora_alpha: str,
        train_steps: int,
        batch_size: int,
        learning_rate: float,
        save_iterations: int,
        repo_id: str,
        training_type: str = DEFAULT_TRAINING_TYPE,
        model_version: str = "",
        resume_from_checkpoint: Optional[str] = None,
        num_gpus: int = DEFAULT_NUM_GPUS,
        precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
        lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
        progress: Optional[gr.Progress] = None,
        custom_prompt_prefix: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Start training with finetrainers"""
        self.clear_logs()

        if not model_type:
            raise ValueError("model_type cannot be empty")
        if model_type not in MODEL_TYPES.values():
            raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
        if training_type not in TRAINING_TYPES.values():
            raise ValueError(f"Invalid training_type: {training_type}. Must be one of {list(TRAINING_TYPES.values())}")

        # Check if we're resuming or starting new
        is_resuming = resume_from_checkpoint is not None
        log_prefix = "Resuming" if is_resuming else "Initializing"
        logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")

        # Update progress if available
        #if progress:
        #    progress(0.15, desc="Setting up training configuration")
        try:
            current_dir = Path(__file__).parent.parent.parent.absolute()  # Go up to project root
            train_script = current_dir / "train.py"

            if not train_script.exists():
                # Try alternative locations
                alt_locations = [
                    current_dir.parent / "train.py",  # One level up from project root
                    Path("/home/user/app/train.py"),  # Absolute path
                    Path("train.py")  # Current working directory
                ]

                for alt_path in alt_locations:
                    if alt_path.exists():
                        train_script = alt_path
                        logger.info(f"Found train.py at alternative location: {train_script}")
                        break

                if not train_script.exists():
                    error_msg = f"Training script not found at {train_script} or any alternative locations"
                    logger.error(error_msg)
                    return error_msg, "Training script not found"

            # Log paths for debugging
            logger.info("Current working directory: %s", current_dir)
            logger.info("Training script path: %s", train_script)
            logger.info("Training data path: %s", self.app.training_path)

            # Update progress
            #if progress:
            #    progress(0.2, desc="Preparing training dataset")

            videos_file, prompts_file = prepare_finetrainers_dataset()
            if videos_file is None or prompts_file is None:
                error_msg = "Failed to generate training lists"
                logger.error(error_msg)
                return error_msg, "Training preparation failed"

            video_count = sum(1 for _ in open(videos_file))
            logger.info(f"Generated training lists with {video_count} files")

            if video_count == 0:
                error_msg = "No training files found"
                logger.error(error_msg)
                return error_msg, "No training data available"

            # Update progress
            #if progress:
            #    progress(0.25, desc="Creating dataset configuration")

            # Get resolution configuration from UI state
            ui_state = self.load_ui_state()
            resolution_option = ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0])
            training_buckets_name = RESOLUTION_OPTIONS.get(resolution_option, "SD_TRAINING_BUCKETS")

            # Determine which buckets to use based on the selected resolution
            if training_buckets_name == "SD_TRAINING_BUCKETS":
                training_buckets = SD_TRAINING_BUCKETS
            elif training_buckets_name == "MD_TRAINING_BUCKETS":
                training_buckets = MD_TRAINING_BUCKETS
            else:
                training_buckets = SD_TRAINING_BUCKETS  # Default fallback

            # Determine flow weighting scheme based on model type
            if model_type == "hunyuan_video":
                flow_weighting_scheme = "none"
            else:
                flow_weighting_scheme = "logit_normal"

            # Use the custom prompt prefix passed as parameter.
            # Clean the prefix: remove trailing comma, space or comma+space
            if custom_prompt_prefix:
                custom_prompt_prefix = custom_prompt_prefix.rstrip(', ')

            # Create a proper dataset configuration JSON file
            dataset_config_file = self.app.output_path / "dataset_config.json"

            # Determine appropriate ID token based on model type and custom prefix
            id_token = custom_prompt_prefix  # Use custom prefix as the primary id_token

            # Only use default ID tokens if no custom prefix is provided
            if not id_token:
                id_token = DEFAULT_PROMPT_PREFIX

            dataset_config = {
                "datasets": [
                    {
                        "data_root": str(self.app.training_path),
                        "dataset_type": DEFAULT_DATASET_TYPE,
                        "id_token": id_token,
                        "video_resolution_buckets": [[f, h, w] for f, h, w in training_buckets],
                        "reshape_mode": DEFAULT_RESHAPE_MODE,
                        "remove_common_llm_caption_prefixes": DEFAULT_REMOVE_COMMON_LLM_CAPTION_PREFIXES,
                    }
                ]
            }

            # Write the dataset config to file
            with open(dataset_config_file, 'w') as f:
                json.dump(dataset_config, f, indent=2)

            logger.info(f"Created dataset configuration file at {dataset_config_file}")
            # Get config for selected model type with preset buckets
            if model_type == "hunyuan_video":
                if training_type == "lora":
                    config = TrainingConfig.hunyuan_video_lora(
                        data_path=str(self.app.training_path),
                        output_path=str(self.app.output_path),
                        buckets=training_buckets
                    )
                else:
                    # Hunyuan doesn't support full finetune in our UI yet
                    error_msg = "Full finetune is not supported for Hunyuan Video due to memory limitations"
                    logger.error(error_msg)
                    return error_msg, "Training configuration error"
            elif model_type == "ltx_video":
                if training_type == "lora":
                    config = TrainingConfig.ltx_video_lora(
                        data_path=str(self.app.training_path),
                        output_path=str(self.app.output_path),
                        buckets=training_buckets
                    )
                else:
                    config = TrainingConfig.ltx_video_full_finetune(
                        data_path=str(self.app.training_path),
                        output_path=str(self.app.output_path),
                        buckets=training_buckets
                    )
            elif model_type == "wan":
                if training_type == "lora":
                    config = TrainingConfig.wan_lora(
                        data_path=str(self.app.training_path),
                        output_path=str(self.app.output_path),
                        buckets=training_buckets
                    )
                else:
                    error_msg = "Full finetune for Wan is not yet supported in this UI"
                    logger.error(error_msg)
                    return error_msg, "Training configuration error"
            else:
                error_msg = f"Unsupported model type: {model_type}"
                logger.error(error_msg)
                return error_msg, "Unsupported model"

            # Create validation dataset if needed
            validation_file = None
            #if enable_validation:  # Add a parameter to control this
            #    validation_file = create_validation_config(self.app.training_videos_path, self.app.output_path)
            #    if validation_file:
            #        config_args.extend([
            #            "--validation_dataset_file", str(validation_file),
            #            "--validation_steps", "500"  # Set this to a suitable value
            #        ])
            # Update with UI parameters
            config.train_steps = int(train_steps)
            config.batch_size = int(batch_size)
            config.lr = float(learning_rate)
            config.checkpointing_steps = int(save_iterations)
            config.training_type = training_type
            config.flow_weighting_scheme = flow_weighting_scheme
            config.lr_warmup_steps = int(lr_warmup_steps)

            # Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
            num_gpus = min(num_gpus, get_available_gpu_count())
            if num_gpus <= 0:
                num_gpus = 1

            # Generate CUDA_VISIBLE_DEVICES string
            visible_devices = ",".join([str(i) for i in range(num_gpus)])

            config.data_root = str(dataset_config_file)

            # Update LoRA parameters if using LoRA training type
            if training_type == "lora" or training_type == "control-lora":
                config.lora_rank = int(lora_rank)
                config.lora_alpha = int(lora_alpha)

            # Collect control-specific command line flags here; they are appended
            # to config_args once it has been generated below
            control_args = []

            # Update Control parameters if using control training types
            if training_type in ["control-lora", "control-full-finetune"]:
                # Get control parameters from UI state
                current_state = self.load_ui_state()

                # Add control-specific parameters
                control_type = current_state.get("control_type", DEFAULT_CONTROL_TYPE)
                train_qk_norm = current_state.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM)
                frame_conditioning_type = current_state.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE)
                frame_conditioning_index = current_state.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX)
                frame_conditioning_concatenate_mask = current_state.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)

                # Map boolean from UI state to command line args
                control_args.extend([
                    "--control_type", control_type,
                ])
                if train_qk_norm:
                    control_args.append("--train_qk_norm")

                control_args.extend([
                    "--frame_conditioning_type", frame_conditioning_type,
                    "--frame_conditioning_index", str(frame_conditioning_index)
                ])
                if frame_conditioning_concatenate_mask:
                    control_args.append("--frame_conditioning_concatenate_mask")

            # Update with resume_from_checkpoint if provided
            if resume_from_checkpoint:
                # Validate checkpoints and find a valid one to resume from
                valid_checkpoint = self.validate_and_find_valid_checkpoint()

                if valid_checkpoint:
                    config.resume_from_checkpoint = "latest"
                    checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1])
                    self.append_log(f"Resuming from validated checkpoint at step {checkpoint_step}")
                    logger.info(f"Resuming from validated checkpoint: {valid_checkpoint}")
                else:
                    error_msg = "No valid checkpoints found to resume from"
                    logger.error(error_msg)
                    self.append_log(error_msg)
                    return error_msg, "No valid checkpoints available"

            # Common settings for both models
            config.mixed_precision = DEFAULT_MIXED_PRECISION
            config.seed = DEFAULT_SEED
            config.gradient_checkpointing = True
            config.enable_slicing = True
            config.enable_tiling = True
            config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
            config.precomputation_items = precomputation_items

            validation_error = self.validate_training_config(config, model_type)
            if validation_error:
                error_msg = f"Configuration validation failed: {validation_error}"
                logger.error(error_msg)
                return "Error: Invalid configuration", error_msg

            # Convert config to command line arguments for all launcher types,
            # then append any control-specific flags collected above
            config_args = config.to_args_list()
            config_args.extend(control_args)

            logger.debug("Generated args list: %s", config_args)
            # Use different launch commands based on model type.
            # For Wan models, use torchrun instead of accelerate launch
            if model_type == "wan":
                # Configure torchrun parameters
                torchrun_args = [
                    "torchrun",
                    "--standalone",
                    "--nproc_per_node=" + str(num_gpus),
                    "--nnodes=1",
                    "--rdzv_backend=c10d",
                    "--rdzv_endpoint=localhost:0",
                    str(train_script)
                ]

                # Additional args needed for torchrun
                config_args.extend([
                    "--parallel_backend", "ptd",
                    "--pp_degree", "1",
                    "--dp_degree", "1",
                    "--dp_shards", "1",
                    "--cp_degree", "1",
                    "--tp_degree", "1"
                ])

                # Log the full command for debugging
                command_str = ' '.join(torchrun_args + config_args)
                self.append_log(f"Command: {command_str}")
                logger.info(f"Executing command: {command_str}")

                launch_args = torchrun_args
            else:
                # For other models, use accelerate launch as before.
                # Determine the appropriate accelerate config file based on num_gpus
                accelerate_config = None
                if num_gpus == 1:
                    accelerate_config = "accelerate_configs/uncompiled_1.yaml"
                elif num_gpus == 2:
                    accelerate_config = "accelerate_configs/uncompiled_2.yaml"
                elif num_gpus == 4:
                    accelerate_config = "accelerate_configs/uncompiled_4.yaml"
                elif num_gpus == 8:
                    accelerate_config = "accelerate_configs/uncompiled_8.yaml"
                else:
                    # Default to 1 GPU config if no matching config is found
                    accelerate_config = "accelerate_configs/uncompiled_1.yaml"
                    num_gpus = 1
                    visible_devices = "0"

                # Configure accelerate parameters
                accelerate_args = [
                    "accelerate", "launch",
                    "--config_file", accelerate_config,
                    "--gpu_ids", visible_devices,
                    "--mixed_precision=bf16",
                    "--num_processes=" + str(num_gpus),
                    "--num_machines=1",
                    "--dynamo_backend=no",
                    str(train_script)
                ]

                # Log the full command for debugging
                command_str = ' '.join(accelerate_args + config_args)
                self.append_log(f"Command: {command_str}")
                logger.info(f"Executing command: {command_str}")

                launch_args = accelerate_args
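            # Illustrative resulting commands (arguments depend on the config object):
            #     torchrun --standalone --nproc_per_node=<N> --nnodes=1 \
            #         --rdzv_backend=c10d --rdzv_endpoint=localhost:0 train.py <config args> --parallel_backend ptd ...
            #     accelerate launch --config_file accelerate_configs/uncompiled_1.yaml --gpu_ids 0 \
            #         --mixed_precision=bf16 --num_processes=1 --num_machines=1 --dynamo_backend=no train.py <config args>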
            # Set environment variables
            env = os.environ.copy()
            env["NCCL_P2P_DISABLE"] = "1"
            env["TORCH_NCCL_ENABLE_MONITORING"] = "0"
            env["WANDB_MODE"] = "offline"
            env["HF_API_TOKEN"] = HF_API_TOKEN
            env["FINETRAINERS_LOG_LEVEL"] = "DEBUG"  # Added for better debugging
            env["CUDA_VISIBLE_DEVICES"] = visible_devices

            #if progress:
            #    progress(0.9, desc="Launching training process")

            # Start the training process
            process = subprocess.Popen(
                launch_args + config_args,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                start_new_session=True,
                env=env,
                cwd=str(current_dir),
                bufsize=1,
                universal_newlines=True
            )
| logger.info(f"Started process with PID: {process.pid}") | |
| with open(self.app.output_pid_file, 'w') as f: | |
| f.write(str(process.pid)) | |
| # Get current UI state for all parameters | |
| current_state = self.load_ui_state() | |
| # Build session data | |
| session_data = { | |
| "model_type": model_type, | |
| "model_version": model_version, | |
| "training_type": training_type, | |
| "lora_rank": lora_rank, | |
| "lora_alpha": lora_alpha, | |
| "train_steps": train_steps, | |
| "batch_size": batch_size, | |
| "learning_rate": learning_rate, | |
| "save_iterations": save_iterations, | |
| "num_gpus": num_gpus, | |
| "precomputation_items": precomputation_items, | |
| "lr_warmup_steps": lr_warmup_steps, | |
| "repo_id": repo_id, | |
| "start_time": datetime.now().isoformat() | |
| } | |
| # Add control parameters if relevant | |
| if training_type in ["control-lora", "control-full-finetune"]: | |
| session_data.update({ | |
| "control_type": current_state.get("control_type", DEFAULT_CONTROL_TYPE), | |
| "train_qk_norm": current_state.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM), | |
| "frame_conditioning_type": current_state.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE), | |
| "frame_conditioning_index": current_state.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX), | |
| "frame_conditioning_concatenate_mask": current_state.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK) | |
| }) | |
| # Save session | |
| self.save_session(session_data) | |
| # Update initial training status | |
| total_steps = int(train_steps) | |
| self.save_status( | |
| state='training', | |
| step=0, | |
| total_steps=total_steps, | |
| loss=0.0, | |
| message='Training started', | |
| repo_id=repo_id, | |
| model_type=model_type, | |
| training_type=training_type | |
| ) | |
| # Start monitoring process output | |
| self._start_log_monitor(process) | |
| success_msg = f"Started {training_type} training for {model_type} model" | |
| self.append_log(success_msg) | |
| logger.info(success_msg) | |
| # Final progress update - now we'll track it through the log monitor | |
| #if progress: | |
| # progress(1.0, desc="Training started successfully") | |
| return success_msg, self.get_logs() | |
| except Exception as e: | |
| error_msg = f"Error {'resuming' if is_resuming else 'starting'} training: {str(e)}" | |
| self.append_log(error_msg) | |
| logger.exception("Training startup failed") | |
| traceback.print_exc() | |
| return f"Error {'resuming' if is_resuming else 'starting'} training", error_msg | |
    def stop_training(self) -> Tuple[str, str]:
        """Stop training process"""
        if not self.app.output_pid_file.exists():
            return "No training process found", self.get_logs()

        try:
            with open(self.app.output_pid_file, 'r') as f:
                pid = int(f.read().strip())

            if psutil.pid_exists(pid):
                os.killpg(os.getpgid(pid), signal.SIGTERM)

            if self.app.output_pid_file.exists():
                self.app.output_pid_file.unlink()

            self.append_log("Training process stopped")
            self.save_status(state='stopped', message='Training stopped')
            return "Training stopped successfully", self.get_logs()

        except Exception as e:
            error_msg = f"Error stopping training: {str(e)}"
            self.append_log(error_msg)
            if self.app.output_pid_file.exists():
                self.app.output_pid_file.unlink()
            return "Error stopping training", error_msg
    def pause_training(self) -> Tuple[str, str]:
        """Pause training process by sending SIGUSR1"""
        if not self.is_training_running():
            return "No training process found", self.get_logs()

        try:
            with open(self.app.output_pid_file, 'r') as f:
                pid = int(f.read().strip())

            if psutil.pid_exists(pid):
                os.kill(pid, signal.SIGUSR1)  # Signal to pause
                self.save_status(state='paused', message='Training paused')

            self.append_log("Training paused")
            return "Training paused", self.get_logs()

        except Exception as e:
            error_msg = f"Error pausing training: {str(e)}"
            self.append_log(error_msg)
            return "Error pausing training", error_msg

    def resume_training(self) -> Tuple[str, str]:
        """Resume training process by sending SIGUSR2"""
        if not self.is_training_running():
            return "No training process found", self.get_logs()

        try:
            with open(self.app.output_pid_file, 'r') as f:
                pid = int(f.read().strip())

            if psutil.pid_exists(pid):
                os.kill(pid, signal.SIGUSR2)  # Signal to resume
                self.save_status(state='training', message='Training resumed')

            self.append_log("Training resumed")
            return "Training resumed", self.get_logs()

        except Exception as e:
            error_msg = f"Error resuming training: {str(e)}"
            self.append_log(error_msg)
            return "Error resuming training", error_msg
    def is_training_running(self) -> bool:
        """Check if training is currently running"""
        if not self.app.output_pid_file.exists():
            return False

        try:
            with open(self.app.output_pid_file, 'r') as f:
                pid = int(f.read().strip())

            # Check if process exists AND is a Python process running train.py
            if psutil.pid_exists(pid):
                try:
                    process = psutil.Process(pid)
                    cmdline = process.cmdline()
                    # Check if it's a Python process running train.py
                    return any('train.py' in cmd for cmd in cmdline)
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    return False
            return False
        except Exception:
            return False
    def validate_and_find_valid_checkpoint(self) -> Optional[str]:
        """Validate checkpoint directories and find the most recent valid one

        Returns:
            Path to valid checkpoint directory or None if no valid checkpoint found
        """
        # Find all checkpoint directories
        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))

        if not checkpoints:
            logger.info("No checkpoint directories found")
            return None

        # Sort by step number in descending order (latest first)
        sorted_checkpoints = sorted(checkpoints, key=lambda x: int(x.name.split("_")[-1]), reverse=True)

        corrupted_checkpoints = []
        for checkpoint_dir in sorted_checkpoints:
            step_num = int(checkpoint_dir.name.split("_")[-1])
            logger.info(f"Validating checkpoint at step {step_num}: {checkpoint_dir}")

            # Check if the .metadata file exists (indicator of complete checkpoint)
            metadata_file = checkpoint_dir / ".metadata"
            if not metadata_file.exists():
                logger.warning(f"Checkpoint {checkpoint_dir.name} is corrupted: missing .metadata file")
                corrupted_checkpoints.append(checkpoint_dir)
                continue

            # .metadata file exists, checkpoint is considered valid.
            # We don't read the file contents to avoid encoding/parsing issues
            logger.info(f"Checkpoint {checkpoint_dir.name} is valid")

            # Clean up any corrupted checkpoints we found before this valid one
            if corrupted_checkpoints:
                self.cleanup_corrupted_checkpoints(corrupted_checkpoints)

            return str(checkpoint_dir)

        # If we reach here, all checkpoints are corrupted
        if corrupted_checkpoints:
            logger.error("All checkpoint directories are corrupted")
            self.cleanup_corrupted_checkpoints(corrupted_checkpoints)

        return None
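    # Checkpoints are expected under output_path with the naming matched by the
    # glob above, e.g. (step numbers illustrative):
    #     output_path/finetrainers_step_500/.metadata
    #     output_path/finetrainers_step_1000/.metadata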
    def cleanup_corrupted_checkpoints(self, corrupted_checkpoints: List[Path]) -> None:
        """Remove corrupted checkpoint directories

        Args:
            corrupted_checkpoints: List of corrupted checkpoint directory paths
        """
        for checkpoint_dir in corrupted_checkpoints:
            try:
                step_num = int(checkpoint_dir.name.split("_")[-1])
                logger.info(f"Removing corrupted checkpoint at step {step_num}: {checkpoint_dir}")
                shutil.rmtree(checkpoint_dir)
                self.append_log(f"Removed corrupted checkpoint: {checkpoint_dir.name}")
            except Exception as e:
                logger.error(f"Failed to remove corrupted checkpoint {checkpoint_dir}: {e}")
                self.append_log(f"Failed to remove corrupted checkpoint {checkpoint_dir.name}: {e}")
    def cleanup_old_lora_weights(self, max_to_keep: int = 2) -> None:
        """Remove old LoRA weight directories, keeping only the most recent ones

        Args:
            max_to_keep: Maximum number of LoRA weight directories to keep (default: 2)
        """
        lora_weights_path = self.app.output_path / "lora_weights"

        if not lora_weights_path.exists():
            logger.debug("LoRA weights directory does not exist, nothing to clean up")
            return

        # Find all LoRA weight directories (should be named with step numbers)
        lora_dirs = []
        for item in lora_weights_path.iterdir():
            if item.is_dir() and item.name.isdigit():
                lora_dirs.append(item)

        if len(lora_dirs) <= max_to_keep:
            logger.debug(f"Found {len(lora_dirs)} LoRA weight directories, no cleanup needed (keeping {max_to_keep})")
            return

        # Sort by step number (directory name) in descending order (newest first)
        lora_dirs_sorted = sorted(lora_dirs, key=lambda x: int(x.name), reverse=True)

        # Keep the most recent max_to_keep directories, remove the rest
        dirs_to_keep = lora_dirs_sorted[:max_to_keep]
        dirs_to_remove = lora_dirs_sorted[max_to_keep:]

        logger.info(f"Cleaning up old LoRA weights: keeping {len(dirs_to_keep)}, removing {len(dirs_to_remove)}")
        self.append_log(f"Cleaning up old LoRA weights: keeping latest {max_to_keep} directories")

        for lora_dir in dirs_to_remove:
            try:
                step_num = int(lora_dir.name)
                logger.info(f"Removing old LoRA weights at step {step_num}: {lora_dir}")
                shutil.rmtree(lora_dir)
                self.append_log(f"Removed old LoRA weights: step {step_num}")
            except Exception as e:
                logger.error(f"Failed to remove old LoRA weights {lora_dir}: {e}")
                self.append_log(f"Failed to remove old LoRA weights {lora_dir.name}: {e}")

        # Log what we kept
        kept_steps = [int(d.name) for d in dirs_to_keep]
        kept_steps.sort(reverse=True)
        logger.info(f"Kept LoRA weights for steps: {kept_steps}")
        self.append_log(f"Kept LoRA weights for steps: {kept_steps}")
    def _background_cleanup_task(self) -> None:
        """Background task that runs every 10 minutes to clean up old LoRA weights"""
        cleanup_interval = 600  # 10 minutes in seconds

        logger.info("Started background LoRA cleanup task (runs every 10 minutes)")

        while not self._cleanup_stop_event.is_set():
            try:
                # Wait for 10 minutes or until stop event is set
                if self._cleanup_stop_event.wait(timeout=cleanup_interval):
                    break  # Stop event was set

                # Only run cleanup if we have an output path
                if hasattr(self.app, 'output_path') and self.app.output_path:
                    lora_weights_path = self.app.output_path / "lora_weights"

                    # Only cleanup if the directory exists and has content
                    if lora_weights_path.exists():
                        lora_dirs = [d for d in lora_weights_path.iterdir() if d.is_dir() and d.name.isdigit()]
                        if len(lora_dirs) > 2:
                            logger.info(f"Background cleanup: Found {len(lora_dirs)} LoRA weight directories, cleaning up old ones")
                            self.cleanup_old_lora_weights(max_to_keep=2)
                        else:
                            logger.debug(f"Background cleanup: Found {len(lora_dirs)} LoRA weight directories, no cleanup needed")
            except Exception as e:
                logger.error(f"Background LoRA cleanup task error: {e}")
                # Continue running despite errors

        logger.info("Background LoRA cleanup task stopped")

    def stop_background_cleanup(self) -> None:
        """Stop the background cleanup task"""
        if hasattr(self, '_cleanup_stop_event'):
            self._cleanup_stop_event.set()
        if hasattr(self, '_cleanup_thread') and self._cleanup_thread.is_alive():
            self._cleanup_thread.join(timeout=5)
            logger.info("Background cleanup task stopped")
| def recover_interrupted_training(self) -> Dict[str, Any]: | |
| """Attempt to recover interrupted training | |
| Returns: | |
| Dict with recovery status and UI updates | |
| """ | |
| status = self.get_status() | |
| ui_updates = {} | |
| # Check for any valid checkpoints, even if status doesn't indicate training | |
| valid_checkpoint = self.validate_and_find_valid_checkpoint() | |
| has_checkpoints = valid_checkpoint is not None | |
| # If status indicates training but process isn't running, or if we have checkpoints | |
| # and no active training process, try to recover | |
| if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \ | |
| (has_checkpoints and not self.is_training_running()): | |
| logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...") | |
| # Get the latest checkpoint | |
| last_session = self.load_session() | |
| if not last_session: | |
| logger.warning("No session data found for recovery, but will check for checkpoints") | |
| # Try to create a default session based on UI state if we have checkpoints | |
| if has_checkpoints: | |
| ui_state = self.load_ui_state() | |
| # Create a default session using UI state values | |
| last_session = { | |
| "params": { | |
| "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])), | |
| "model_version": ui_state.get("model_version", ""), | |
| "training_type": TRAINING_TYPES.get(ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])), | |
| "lora_rank": ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR), | |
| "lora_alpha": ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR), | |
| "train_steps": ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS), | |
| "batch_size": ui_state.get("batch_size", DEFAULT_BATCH_SIZE), | |
| "learning_rate": ui_state.get("learning_rate", DEFAULT_LEARNING_RATE), | |
| "save_iterations": ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS), | |
| "resolution": ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0]), | |
| "repo_id": "", # Default empty repo ID, | |
| "auto_resume": ui_state.get("auto_resume", DEFAULT_AUTO_RESUME) | |
| } | |
| } | |
| logger.info("Created default session from UI state for recovery") | |
| else: | |
| logger.warning(f"No checkpoints found for recovery") | |
| # Set buttons for no active training | |
| ui_updates = { | |
| "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"}, | |
| "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}, | |
| "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}, | |
| "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False} | |
| } | |
| return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates} | |
| # Use the valid checkpoint we found | |
| latest_checkpoint = None | |
| checkpoint_step = 0 | |
| if has_checkpoints and valid_checkpoint: | |
| checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1]) | |
| logger.info(f"Found valid checkpoint at step {checkpoint_step}") | |
| # both options are valid, but imho it is easier to just return "latest" | |
| # under the hood Finetrainers will convert ("latest") to (-1) | |
| #latest_checkpoint = int(checkpoint_step) | |
| latest_checkpoint = "latest" | |
| else: | |
| logger.warning("No checkpoints found for recovery") | |
| # Set buttons for no active training | |
| ui_updates = { | |
| "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"}, | |
| "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}, | |
| "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}, | |
| "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False} | |
| } | |
| return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates} | |
| # Extract parameters from the saved session (not current UI state) | |
| # This ensures we use the original training parameters | |
| params = last_session.get('params', {}) | |
| # Map internal model type back to display name for UI | |
| model_type_internal = params.get('model_type') | |
| model_type_display = model_type_internal | |
| # Find the display name that maps to our internal model type | |
| for display_name, internal_name in MODEL_TYPES.items(): | |
| if internal_name == model_type_internal: | |
| model_type_display = display_name | |
| logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'") | |
| break | |
| # Get training type (default to LoRA if not present in saved session) | |
| training_type_internal = params.get('training_type', 'lora') | |
| training_type_display = next((disp for disp, val in TRAINING_TYPES.items() if val == training_type_internal), list(TRAINING_TYPES.keys())[0]) | |
| # Add UI updates to restore the training parameters in the UI | |
| # This shows the user what values are being used for the resumed training | |
| ui_updates.update({ | |
| "model_type": model_type_display, | |
| "model_version": params.get('model_version', ''), | |
| "training_type": training_type_display, | |
| "lora_rank": params.get('lora_rank', DEFAULT_LORA_RANK_STR), | |
| "lora_alpha": params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR), | |
| "train_steps": params.get('train_steps', DEFAULT_NB_TRAINING_STEPS), | |
| "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE), | |
| "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE), | |
| "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS), | |
| "resolution": params.get('resolution', list(RESOLUTION_OPTIONS.keys())[0]), | |
| "auto_resume": params.get("auto_resume", DEFAULT_AUTO_RESUME) | |
| }) | |
| # Check if we should auto-recover (immediate restart) | |
| ui_state = self.load_ui_state() | |
| auto_recover = ui_state.get("auto_resume", DEFAULT_AUTO_RESUME) | |
| logger.info(f"Auto-resume is {'enabled' if auto_recover else 'disabled'}") | |
| if auto_recover: | |
| try: | |
| result = self.start_training( | |
| model_type=model_type_internal, | |
| lora_rank=params.get('lora_rank', DEFAULT_LORA_RANK_STR), | |
| lora_alpha=params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR), | |
| train_steps=params.get('train_steps', DEFAULT_NB_TRAINING_STEPS), | |
| batch_size=params.get('batch_size', DEFAULT_BATCH_SIZE), | |
| learning_rate=params.get('learning_rate', DEFAULT_LEARNING_RATE), | |
| save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS), | |
| model_version=params.get('model_version', ''), | |
| repo_id=params.get('repo_id', ''), | |
| training_type=training_type_internal, | |
| resume_from_checkpoint="latest" | |
| ) | |
| # Set buttons for active training | |
| ui_updates.update({ | |
| "start_btn": {"interactive": False, "variant": "secondary", "value": "Start over a new training"}, | |
| "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}, | |
| "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}, | |
| "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False} | |
| }) | |
| return { | |
| "status": "recovered", | |
| "message": f"Training resumed from checkpoint {checkpoint_step}", | |
| "result": result, | |
| "ui_updates": ui_updates | |
| } | |
| except Exception as e: | |
| logger.error(f"Failed to auto-resume training: {str(e)}") | |
| # Set buttons for manual recovery | |
| ui_updates.update({ | |
| "start_btn": {"interactive": True, "variant": "primary", "value": "Start over a new training"}, | |
| "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}, | |
| "delete_checkpoints_btn": {"interactive": True, "variant": "stop", "value": "Delete All Checkpoints"}, | |
| "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False} | |
| }) | |
| return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates} | |
| else: | |
| # Set up UI for manual recovery | |
| ui_updates.update({ | |
| "start_btn": {"interactive": True, "variant": "primary", "value": "Start over a new training"}, | |
| "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}, | |
| "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False} | |
| }) | |
| return {"status": "ready_to_recover", "message": f"Ready to resume from checkpoint {checkpoint_step}", "ui_updates": ui_updates} | |
| elif self.is_training_running(): | |
| # Process is still running, set buttons accordingly | |
| ui_updates = { | |
| "start_btn": {"interactive": False, "variant": "secondary", "value": "Start over a new training" if has_checkpoints else "Start Training"}, | |
| "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}, | |
| "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}, | |
| "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"} | |
| } | |
| return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates} | |
| else: | |
| # No training process, set buttons to default state | |
| button_text = "Start over a new training" if has_checkpoints else "Start Training" | |
| ui_updates = { | |
| "start_btn": {"interactive": True, "variant": "primary", "value": button_text}, | |
| "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}, | |
| "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}, | |
| "delete_checkpoints_btn": {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"} | |
| } | |
| return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates} | |
| def delete_all_checkpoints(self) -> str: | |
| """Delete all checkpoints in the output directory. | |
| Returns: | |
| Status message | |
| """ | |
| if self.is_training_running(): | |
| return "Cannot delete checkpoints while training is running. Stop training first." | |
| try: | |
| # Find all checkpoint directories | |
| checkpoints = list(self.app.output_path.glob("finetrainers_step_*")) | |
| if not checkpoints: | |
| return "No checkpoints found to delete." | |
| # Delete each checkpoint directory | |
| for checkpoint in checkpoints: | |
| if checkpoint.is_dir(): | |
| shutil.rmtree(checkpoint) | |
| # Also delete session.json which contains previous training info | |
| if self.app.output_session_file.exists(): | |
| self.app.output_session_file.unlink() | |
| # Reset status file to idle | |
| self.save_status(state='idle', message='No training in progress') | |
| self.append_log(f"Deleted {len(checkpoints)} checkpoint(s)") | |
| return f"Successfully deleted {len(checkpoints)} checkpoint(s)" | |
| except Exception as e: | |
| error_msg = f"Error deleting checkpoints: {str(e)}" | |
| self.append_log(error_msg) | |
| return error_msg | |
| def clear_training_data(self) -> str: | |
| """Clear all files from the training data directories | |
| Returns: | |
| Status message | |
| """ | |
| if self.is_training_running(): | |
| return "Cannot clear training data while training is running. Stop training first." | |
| try: | |
| for file in self.app.training_videos_path.glob("*.*"): | |
| file.unlink() | |
| for file in self.app.training_path.glob("*.*"): | |
| file.unlink() | |
| self.append_log("Cleared all training data") | |
| return "Training data cleared successfully" | |
| except Exception as e: | |
| error_msg = f"Error clearing training data: {str(e)}" | |
| self.append_log(error_msg) | |
| return error_msg | |
| def save_status(self, state: str, **kwargs) -> None: | |
| """Save current training status""" | |
| status = { | |
| 'status': state, | |
| 'timestamp': datetime.now().isoformat(), | |
| **kwargs | |
| } | |
| if state == "Training started" or state == "initializing": | |
| gr.Info("Initializing model and dataset..") | |
| elif state == "training": | |
| #gr.Info("Training started!") | |
| # Training is in progress | |
| pass | |
| elif state == "completed": | |
| gr.Info("Training completed!") | |
| with open(self.app.output_status_file, 'w') as f: | |
| json.dump(status, f, indent=2) | |
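| # For reference, the JSON written by save_status() has roughly this shape; the | |
| # extra keys come from **kwargs and from parse_training_log(), so the metric | |
| # names below are illustrative only: | |
| # | |
| #   { | |
| #     "status": "training", | |
| #     "timestamp": "2024-01-01T12:00:00.000000", | |
| #     "message": "...", | |
| #     "step": 100 | |
| #   } | |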
| def _start_log_monitor(self, process: subprocess.Popen) -> None: | |
| """Start monitoring process output for logs""" | |
| def monitor(): | |
| self.append_log("Starting log monitor thread") | |
| def read_stream(stream, is_error=False): | |
| if stream: | |
| output = stream.readline() | |
| if output: | |
| # Output is already a str because the process was started with universal_newlines=True | |
| line = output.strip() | |
| self.append_log(line) | |
| if is_error: | |
| # stderr lines are captured in the training log above but not escalated to logger.error | |
| pass | |
| # Parse metrics only from stdout | |
| metrics = parse_training_log(line) | |
| if metrics: | |
| # Merge the new metrics into the current status | |
| current_status = self.get_status() | |
| current_status.update(metrics) | |
| # save_status() takes the state as a positional argument, so pull the | |
| # 'status' key out of the dict, defaulting to 'training' when it is absent | |
| state = current_status.pop('status', 'training') | |
| self.save_status(state, **current_status) | |
| return True | |
| return False | |
| # Create separate threads to monitor stdout and stderr | |
| def monitor_stream(stream, is_error=False): | |
| while process.poll() is None: | |
| if not read_stream(stream, is_error): | |
| time.sleep(0.1) # Short sleep to avoid CPU thrashing | |
| # Start threads to monitor each stream | |
| stdout_thread = threading.Thread(target=monitor_stream, args=(process.stdout, False)) | |
| stderr_thread = threading.Thread(target=monitor_stream, args=(process.stderr, True)) | |
| stdout_thread.daemon = True | |
| stderr_thread.daemon = True | |
| stdout_thread.start() | |
| stderr_thread.start() | |
| # Wait for process to complete | |
| process.wait() | |
| # Wait for threads to finish reading any remaining output | |
| stdout_thread.join(timeout=2) | |
| stderr_thread.join(timeout=2) | |
| # Process any remaining output after process ends | |
| while read_stream(process.stdout): | |
| pass | |
| while read_stream(process.stderr, True): | |
| pass | |
| # Process finished | |
| return_code = process.poll() | |
| if return_code == 0: | |
| success_msg = "Training completed successfully" | |
| self.append_log(success_msg) | |
| gr.Info(success_msg) | |
| self.save_status(state='completed', message=success_msg) | |
| # Clean up old LoRA weights to save disk space | |
| try: | |
| self.cleanup_old_lora_weights(max_to_keep=2) | |
| except Exception as e: | |
| logger.warning(f"Failed to cleanup old LoRA weights: {e}") | |
| self.append_log(f"Warning: Failed to cleanup old LoRA weights: {e}") | |
| # Upload final model if repository was specified | |
| session = self.load_session() | |
| if session and session.get('params', {}).get('repo_id'): | |
| repo_id = session['params']['repo_id'] | |
| # Use the most recently modified entry in the output directory as the run to upload | |
| run_entries = list(Path(self.app.output_path).glob('*')) | |
| latest_run = max(run_entries, key=os.path.getmtime) if run_entries else Path(self.app.output_path) | |
| if self.upload_to_hub(latest_run, repo_id): | |
| self.append_log(f"Model uploaded to {repo_id}") | |
| else: | |
| self.append_log("Failed to upload model to hub") | |
| else: | |
| error_msg = f"Training failed with return code {return_code}" | |
| self.append_log(error_msg) | |
| logger.error(error_msg) | |
| self.save_status(state='error', message=error_msg) | |
| # Clean up PID file | |
| if self.app.output_pid_file.exists(): | |
| self.app.output_pid_file.unlink() | |
| monitor_thread = threading.Thread(target=monitor) | |
| monitor_thread.daemon = True | |
| monitor_thread.start() | |
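| # Minimal standalone sketch of the same "tail a subprocess" pattern used above, | |
| # kept as commented reference only (`cmd` is a placeholder command list): | |
| # | |
| #   proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, | |
| #                           universal_newlines=True, bufsize=1) | |
| #   def tail(stream): | |
| #       for line in iter(stream.readline, ''): | |
| #           logger.info(line.rstrip()) | |
| #   threading.Thread(target=tail, args=(proc.stdout,), daemon=True).start() | |
| #   threading.Thread(target=tail, args=(proc.stderr,), daemon=True).start() | |
| #   proc.wait() | |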
| def upload_to_hub(self, model_path: Path, repo_id: str) -> bool: | |
| """Upload model to Hugging Face Hub | |
| Args: | |
| model_path: Path to model files | |
| repo_id: Repository ID (username/model-name) | |
| Returns: | |
| bool: Whether upload was successful | |
| """ | |
| try: | |
| token = os.getenv("HF_API_TOKEN") | |
| if not token: | |
| self.append_log("Error: HF_API_TOKEN not set") | |
| return False | |
| # Create or get repo | |
| create_repo(repo_id, token=token, repo_type="model", exist_ok=True) | |
| # Upload the requested model directory | |
| upload_folder( | |
| folder_path=str(model_path), | |
| repo_id=repo_id, | |
| repo_type="model", | |
| token=token, | |
| commit_message="Training completed" | |
| ) | |
| return True | |
| except Exception as e: | |
| self.append_log(f"Error uploading to hub: {str(e)}") | |
| return False | |
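| # Usage note (values are hypothetical): upload_to_hub(Path("lora_weights/500"), | |
| # "username/my-lora") creates the repo if needed and pushes the given folder; it | |
| # requires the HF_API_TOKEN environment variable to be set. | |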
| def get_model_output_info(self) -> Dict[str, Any]: | |
| """Return info about the model safetensors including path and step count | |
| Returns: | |
| Dict with 'path' (str or None) and 'steps' (int or None) | |
| """ | |
| result = {"path": None, "steps": None} | |
| # Check if the root level file exists (this should be the primary location) | |
| model_output_safetensors_path = self.app.output_path / "pytorch_lora_weights.safetensors" | |
| if model_output_safetensors_path.exists(): | |
| result["path"] = str(model_output_safetensors_path) | |
| # For root level, we can't determine steps easily, so return None | |
| return result | |
| # Check in lora_weights directory | |
| lora_weights_dir = self.app.output_path / "lora_weights" | |
| if lora_weights_dir.exists(): | |
| logger.info(f"Found lora_weights directory: {lora_weights_dir}") | |
| # Look for the latest checkpoint directory in lora_weights | |
| lora_checkpoints = [d for d in lora_weights_dir.glob("*") if d.is_dir() and d.name.isdigit()] | |
| if lora_checkpoints: | |
| latest_lora_checkpoint = max(lora_checkpoints, key=lambda x: int(x.name)) | |
| logger.info(f"Found latest LoRA checkpoint: {latest_lora_checkpoint}") | |
| # Extract step count from directory name | |
| result["steps"] = int(latest_lora_checkpoint.name) | |
| # List contents of the latest checkpoint directory | |
| checkpoint_contents = list(latest_lora_checkpoint.glob("*")) | |
| logger.info(f"Contents of LoRA checkpoint {latest_lora_checkpoint.name}: {checkpoint_contents}") | |
| # Check for weights in the latest LoRA checkpoint | |
| lora_safetensors = latest_lora_checkpoint / "pytorch_lora_weights.safetensors" | |
| if lora_safetensors.exists(): | |
| logger.info(f"Found weights in latest LoRA checkpoint: {lora_safetensors}") | |
| result["path"] = str(lora_safetensors) | |
| return result | |
| # Also check for other common weight file names (pytorch_lora_weights.safetensors | |
| # was already checked above) | |
| possible_weight_files = [ | |
| "adapter_model.safetensors", | |
| "pytorch_model.safetensors", | |
| "model.safetensors" | |
| ] | |
| for weight_file in possible_weight_files: | |
| weight_path = latest_lora_checkpoint / weight_file | |
| if weight_path.exists(): | |
| logger.info(f"Found weights file {weight_file} in latest LoRA checkpoint: {weight_path}") | |
| result["path"] = str(weight_path) | |
| return result | |
| # Check if any .safetensors files exist | |
| safetensors_files = list(latest_lora_checkpoint.glob("*.safetensors")) | |
| if safetensors_files: | |
| logger.info(f"Found .safetensors files in LoRA checkpoint: {safetensors_files}") | |
| # Return the first .safetensors file found | |
| result["path"] = str(safetensors_files[0]) | |
| return result | |
| # Fallback: check for direct safetensors file in lora_weights root | |
| lora_safetensors = lora_weights_dir / "pytorch_lora_weights.safetensors" | |
| if lora_safetensors.exists(): | |
| logger.info(f"Found weights in lora_weights directory: {lora_safetensors}") | |
| result["path"] = str(lora_safetensors) | |
| return result | |
| else: | |
| logger.info(f"pytorch_lora_weights.safetensors not found in lora_weights directory") | |
| # If not found in root or lora_weights, log the issue and check fallback | |
| logger.warning(f"Model weights not found at expected location: {model_output_safetensors_path}") | |
| logger.info(f"Checking output directory contents: {list(self.app.output_path.glob('*'))}") | |
| # Check if there are any checkpoint directories as a fallback | |
| checkpoints = list(self.app.output_path.glob("finetrainers_step_*")) | |
| if checkpoints: | |
| logger.info(f"Found {len(checkpoints)} checkpoint directories, but main weights file is missing") | |
| latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1])) | |
| logger.info(f"Latest checkpoint directory: {latest_checkpoint}") | |
| # Extract step count from checkpoint directory name | |
| result["steps"] = int(latest_checkpoint.name.split("_")[-1]) | |
| # Log contents of latest checkpoint | |
| checkpoint_contents = list(latest_checkpoint.glob("*")) | |
| logger.info(f"Contents of latest checkpoint {latest_checkpoint.name}: {checkpoint_contents}") | |
| checkpoint_weights = latest_checkpoint / "pytorch_lora_weights.safetensors" | |
| if checkpoint_weights.exists(): | |
| logger.info(f"Found weights in latest checkpoint: {checkpoint_weights}") | |
| result["path"] = str(checkpoint_weights) | |
| return result | |
| else: | |
| logger.info(f"pytorch_lora_weights.safetensors not found in checkpoint directory") | |
| return result | |
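| # Example return values (paths and step numbers are illustrative): | |
| #   {"path": ".../lora_weights/500/pytorch_lora_weights.safetensors", "steps": 500} | |
| #   {"path": None, "steps": None}  # nothing trained yet | |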
| def get_model_output_safetensors(self) -> Optional[str]: | |
| """Return the path to the model safetensors | |
| Returns: | |
| Path to safetensors file or None if not found | |
| """ | |
| return self.get_model_output_info()["path"] | |
| def create_training_dataset_zip(self) -> str: | |
| """Create a ZIP file containing all training data | |
| Returns: | |
| Path to created ZIP file | |
| """ | |
| # Create temporary zip file | |
| with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip: | |
| temp_zip_path = str(temp_zip.name) | |
| print(f"Creating zip file for {self.app.training_path}..") | |
| try: | |
| make_archive(self.app.training_path, temp_zip_path) | |
| print(f"Zip file created!") | |
| return temp_zip_path | |
| except Exception as e: | |
| print(f"Failed to create zip: {str(e)}") | |
| raise gr.Error(f"Failed to create zip: {str(e)}") | |
| def create_output_directory_zip(self) -> str: | |
| """Create a ZIP file containing all output data (checkpoints, models, etc.) | |
| Returns: | |
| Path to created ZIP file | |
| """ | |
| # Create temporary zip file | |
| with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip: | |
| temp_zip_path = str(temp_zip.name) | |
| print(f"Creating zip file for {self.app.output_path}..") | |
| try: | |
| make_archive(self.app.output_path, temp_zip_path) | |
| print(f"Output zip file created!") | |
| return temp_zip_path | |
| except Exception as e: | |
| print(f"Failed to create output zip: {str(e)}") | |
| raise gr.Error(f"Failed to create output zip: {str(e)}") | |
| def create_checkpoint_zip(self) -> str: | |
| """Create a ZIP file containing the latest finetrainers checkpoint | |
| Returns: | |
| Path to created ZIP file | |
| Raises: | |
| gr.Error: If no checkpoint directories are found or the archive cannot be created | |
| """ | |
| # Find all checkpoint directories | |
| checkpoints = list(self.app.output_path.glob("finetrainers_step_*")) | |
| if not checkpoints: | |
| logger.info("No checkpoint directories found") | |
| raise gr.Error("No checkpoint directories found") | |
| # Get the latest checkpoint by step number | |
| latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1])) | |
| step_num = int(latest_checkpoint.name.split("_")[-1]) | |
| # Create temporary zip file | |
| with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip: | |
| temp_zip_path = str(temp_zip.name) | |
| print(f"Creating zip file for checkpoint {latest_checkpoint.name}..") | |
| try: | |
| make_archive(latest_checkpoint, temp_zip_path) | |
| print(f"Checkpoint zip file created for step {step_num}!") | |
| return temp_zip_path | |
| except Exception as e: | |
| print(f"Failed to create checkpoint zip: {str(e)}") | |
| raise gr.Error(f"Failed to create checkpoint zip: {str(e)}") | |
| def get_checkpoint_button_text(self) -> str: | |
| """Get the dynamic text for the download checkpoint button based on available checkpoints""" | |
| try: | |
| checkpoints = list(self.app.output_path.glob("finetrainers_step_*")) | |
| if not checkpoints: | |
| return "📥 Download checkpoints (not available)" | |
| # Get the latest checkpoint by step number | |
| latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1])) | |
| step_num = int(latest_checkpoint.name.split("_")[-1]) | |
| return f"📥 Download checkpoints (step {step_num})" | |
| except Exception as e: | |
| logger.warning(f"Error getting checkpoint info for button text: {e}") | |
| return "📥 Download checkpoints (not available)" |