from pathlib import Path
from typing import Any, Dict, List

import torch
from pydantic import BaseModel


class State(BaseModel):
    """Container for mutable training state and validation inputs."""

    model_config = {"arbitrary_types_allowed": True}

    # Temporal and spatial resolution used for training.
    train_frames: int
    train_height: int
    train_width: int

    # Transformer model configuration.
    transformer_config: Dict[str, Any] | None = None

    # Mixed-precision dtype and training-loop bookkeeping.
    weight_dtype: torch.dtype = torch.float32
    num_trainable_parameters: int = 0
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0
    total_batch_size_count: int = 0

    generator: torch.Generator | None = None

    # Validation prompts and optional image/video inputs.
    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    # Paths to validation prompt embeddings, latents, and masks.
    validation_prompt_embeddings: List[Path | None] = []
    validation_video_latents: List[Path | None] = []
    validation_flow_latents: List[Path | None] = []
    validation_valid_masks: List[Path | None] = []
    validation_valid_masks_interp: List[Path | None] = []

    using_deepspeed: bool = False
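

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical values (the frame count, resolution,
    # and sample path below are illustrative only); the actual training script is
    # expected to populate State from its parsed arguments and datasets.
    state = State(train_frames=49, train_height=480, train_width=720)
    state.weight_dtype = torch.bfloat16
    state.validation_prompts = ["a short validation prompt"]
    state.validation_videos = [Path("samples/example.mp4"), None]
    print(state.train_frames, state.weight_dtype, len(state.validation_prompts))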