import gc
import os
from abc import ABC
from typing import Any

import numpy as np
import torch

from .presets import presets as guardrail_presets
from .t5_text_encoder import CosmosT5TextEncoder


class BaseWorldGenerationPipeline(ABC): |
|
def __init__( |
|
self, |
|
inference_type: str | None = None, |
|
checkpoint_dir: str | None = None, |
|
checkpoint_name: str | None = None, |
|
enable_text_guardrail: bool = False, |
|
enable_video_guardrail: bool = False, |
|
offload_network: bool = False, |
|
offload_tokenizer: bool = False, |
|
offload_text_encoder_model: bool = False, |
|
offload_guardrail_models: bool = False, |
|
): |
|
"""Initialize base world generation pipeline. |
|
|
|
This abstract base class provides core functionality for world generation models including: |
|
- Model loading and initialization |
|
- Text encoding and embedding |
|
- Safety checks and content filtering |
|
- Memory management through model offloading |
|
|
|
Args: |
|
inference_type: The type of inference pipeline ("text2world" or "video2world") |
|
checkpoint_dir: Root directory containing model checkpoints |
|
checkpoint_name: Name of the specific checkpoint file to load |
|
enable_text_guardrail: If True, validates input prompts for safety |
|
enable_video_guardrail: If True, validates generated videos for safety |
|
offload_network: If True, moves main model to CPU after inference |
|
offload_tokenizer: If True, moves tokenizer to CPU after use |
|
offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding |
|
offload_guardrail_models: If True, moves safety models to CPU after checks |
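
        Example:
            A minimal sketch of constructing a hypothetical concrete subclass
            (the subclass and its loading logic are illustrative, not part of
            this module)::

                class MyWorldPipeline(BaseWorldGenerationPipeline):
                    def _load_model(self) -> Any:
                        self.model = ...  # load weights from self.checkpoint_dir

                pipeline = MyWorldPipeline(
                    inference_type="text2world",
                    checkpoint_dir="checkpoints",
                    checkpoint_name="model.pt",
                    enable_text_guardrail=True,
                    offload_network=True,
                )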
|
""" |
|
        self.inference_type = inference_type
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_name = checkpoint_name
        self.guardrail_dir = "Cosmos-1.0-Guardrail"
        self.enable_text_guardrail = enable_text_guardrail
        self.enable_video_guardrail = enable_video_guardrail

        self.offload_network = offload_network
        self.offload_tokenizer = offload_tokenizer
        self.offload_text_encoder_model = offload_text_encoder_model
        self.offload_guardrail_models = offload_guardrail_models

        self.text_guardrail = None
        self.video_guardrail = None
        self.text_encoder = None
        self.model = None

        self._load_model()

        # Eagerly load any components that are not configured for offloading.
        if not self.offload_text_encoder_model:
            self._load_text_encoder_model()
        if not self.offload_guardrail_models:
            if self.enable_text_guardrail:
                self._load_text_guardrail()
            if self.enable_video_guardrail:
                self._load_video_guardrail()
        if not self.offload_network:
            self._load_network()
        if not self.offload_tokenizer:
            self._load_tokenizer()

    def _load_tokenizer(self):
        """Load the tokenizer. Hook for subclasses; the base implementation is a no-op."""
        pass

    def _load_network(self):
        """Load the main generation network. Hook for subclasses; the base implementation is a no-op."""
        pass

    def _load_model(self) -> Any:
        """Load the world generation model from a checkpoint.

        This method must be implemented by subclasses to load their specific
        model architecture and weights, using `self.checkpoint_dir` and
        `self.checkpoint_name` to locate the checkpoint.

        Returns:
            The loaded model instance

        Raises:
            NotImplementedError: Must be implemented by subclasses
        """
        raise NotImplementedError("Subclasses must implement _load_model")

    def _load_text_encoder_model(self):
        """Load the T5 text encoder model.

        Initializes the T5 encoder used for converting text prompts into
        embeddings that condition the world generation model, and stores it
        on `self.text_encoder`.
        """
        self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir)

    def _load_text_guardrail(self):
        """Load text safety classifier models.

        Initializes models used for checking input prompts against safety policies.
        Models are loaded from the specified guardrail directory.
        """
        self.text_guardrail = guardrail_presets.create_text_guardrail_runner(
            checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
        )

    def _load_video_guardrail(self):
        """Load video safety classifier models.

        Initializes models used for validating generated video content against
        safety policies. Models are loaded from the specified guardrail directory.
        """
        self.video_guardrail = guardrail_presets.create_video_guardrail_runner(
            checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
        )

    def _offload_network(self):
        """Delete the main network and reclaim GPU memory."""
        if self.model.model is not None:
            del self.model.model
            self.model.model = None
            gc.collect()
            torch.cuda.empty_cache()

    def _offload_tokenizer(self):
        """Delete the tokenizer and reclaim GPU memory."""
        if self.model.tokenizer is not None:
            del self.model.tokenizer
            self.model.tokenizer = None
            gc.collect()
            torch.cuda.empty_cache()

    def _offload_guardrail_models(self):
        """Offload safety classifier models to reduce memory usage.

        Deletes the guardrail models and clears GPU memory once they are no
        longer needed. This helps manage memory when processing multiple inputs
        sequentially.
        """
        if self.text_guardrail is not None:
            del self.text_guardrail
            self.text_guardrail = None
        if self.video_guardrail is not None:
            del self.video_guardrail
            self.video_guardrail = None
        gc.collect()
        torch.cuda.empty_cache()

    def _offload_text_encoder_model(self):
        """Offload the T5 text encoder to reduce memory usage.

        Deletes the encoder and clears GPU memory after text encoding is
        complete. This helps manage memory when processing multiple inputs
        sequentially.
        """
        if self.text_encoder is not None:
            del self.text_encoder
            self.text_encoder = None
            gc.collect()
            torch.cuda.empty_cache()

    def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Generate world latents using the model.

        This abstract method must be implemented by subclasses to define their
        specific generation process.

        Args:
            *args: Variable positional arguments for model inference
            **kwargs: Variable keyword arguments for model inference

        Returns:
            torch.Tensor: Generated world representation tensor
        """
        pass

    def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> np.ndarray:
        """Generate world representation with memory management.

        This abstract method must be implemented by subclasses. Implementations
        should load the model before inference and offload it afterward if
        enabled, to minimize GPU memory usage.

        Args:
            *args: Arguments passed to _run_model
            **kwargs: Keyword arguments passed to _run_model

        Returns:
np.ndarray: Generated world representation as numpy array |
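
        Example:
            A sketch of the load/run/offload pattern subclasses typically
            implement (the exact body is illustrative)::

                if self.offload_network:
                    self._load_network()
                sample = self._run_model(*args, **kwargs)
                if self.offload_network:
                    self._offload_network()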
|
""" |
|
pass |
|
|
|
    def _run_guardrail_on_prompt(self, prompt: str) -> bool:
        """Check if a prompt meets safety requirements.

        Validates the input prompt against safety policies using the loaded
        guardrail models.

        Args:
            prompt: Raw text prompt to validate

        Returns:
            bool: True if the prompt passes all safety checks, False otherwise
        """
        return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail)

    def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool:
        """Check prompt safety with memory management.

        Validates prompt safety while handling guardrail model loading and
        offloading to manage memory.

        Args:
            prompt: Raw text prompt to validate

        Returns:
bool: True if prompt passes all safety checks, False otherwise |
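
        Example:
            A minimal sketch of typical use inside a generation call (the
            surrounding control flow is illustrative)::

                if self.enable_text_guardrail:
                    if not self._run_guardrail_on_prompt_with_offload(prompt):
                        return None  # prompt rejected by the guardrail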
|
""" |
|
if self.offload_guardrail_models: |
|
self._load_text_guardrail() |
|
|
|
is_safe = self._run_guardrail_on_prompt(prompt) |
|
|
|
if self.offload_guardrail_models: |
|
self._offload_guardrail_models() |
|
|
|
return is_safe |
|
|
|
    def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None:
        """Check if a video meets safety requirements.

        Validates generated video content against safety policies using the
        loaded guardrail models.

        Args:
            video: Video frames to validate

        Returns:
            np.ndarray: Processed video if safe, None if unsafe
        """
        return guardrail_presets.run_video_guardrail(video, self.video_guardrail)

    def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None:
        """Check generated video safety with memory management.

        Args:
            video: Video frames to validate

        Returns:
            np.ndarray: Processed video frames if safe, None otherwise

        Note:
            Guardrail models are offloaded after the check if enabled.
        """
        if self.offload_guardrail_models:
            self._load_video_guardrail()

        video = self._run_guardrail_on_video(video)

        if self.offload_guardrail_models:
            self._offload_guardrail_models()
        return video

    def _run_text_embedding_on_prompt(
        self, prompts: list[str], **kwargs: Any
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Convert text prompts to embeddings.

        Processes text prompts into embedding tensors that condition the
        generation model.

        Args:
            prompts: List of text prompts to encode
            **kwargs: Additional arguments for text encoding

        Returns:
            tuple containing:
                - List of text embedding tensors for each prompt
- List of attention masks for each embedding |
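
        Example:
            A minimal usage sketch, assuming the text encoder is already
            loaded (the prompt text is illustrative)::

                embeddings, masks = self._run_text_embedding_on_prompt(
                    ["a robot arm stacking wooden blocks"]
                )
                # embeddings[0] and masks[0] correspond to the first prompt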
|
""" |
|
|
|
embeddings = [] |
|
masks = [] |
|
for prompt in prompts: |
|
embedding, mask = self.text_encoder.encode_prompts( |
|
[prompt], |
|
**kwargs, |
|
) |
|
embeddings.append(embedding) |
|
masks.append(mask) |
|
|
|
return embeddings, masks |
|
|
|
    def _run_text_embedding_on_prompt_with_offload(
        self, prompts: list[str], **kwargs: Any
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Convert text prompts into embeddings using the T5 encoder.

        Args:
            prompts: List of processed and validated text prompts
            **kwargs: Additional arguments for text encoding

        Returns:
            tuple containing the text embedding tensors and attention masks
            used to condition the generation model

        Note:
            The T5 encoder is offloaded after encoding if enabled.
        """
        if self.offload_text_encoder_model:
            self._load_text_encoder_model()

        embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs)

        if self.offload_text_encoder_model:
            self._offload_text_encoder_model()
        return embeddings, masks

    def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray:
        """Decode model outputs into the final world representation.

        This abstract method must be implemented by subclasses to convert raw
        model outputs into their specific world representation format.

        Args:
            samples: Raw output tensor from the generation model

        Returns:
            np.ndarray: Decoded world representation
        """
        pass

    def generate(self, *args: Any, **kwargs: Any):
        """Generate world representation.

        This abstract method must be implemented by subclasses to define the
        complete generation pipeline: prompt safety checks, text encoding,
        model inference, and decoding into the final world representation.

        Args:
            *args: Variable positional arguments for model inference
**kwargs: Variable keyword arguments for model inference |
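
        Example:
            A sketch of invoking a hypothetical concrete subclass (the
            argument name is illustrative)::

                video = pipeline.generate(prompt="a drone flying over a canyon")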
|
""" |
|
pass |
|
|