import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .batch_ops import batch_mul
from .lazy_config_init import instantiate
from .log import log
|
|
class BaseConditionEntry(nn.Module):
    def __init__(self):
        super().__init__()

        self._dropout_rate = None
        self._input_key = None
        self._return_dict = False

    @property
    def dropout_rate(self) -> Union[float, torch.Tensor]:
        return self._dropout_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @property
    def is_return_dict(self) -> bool:
        return self._return_dict

    @dropout_rate.setter
    def dropout_rate(self, value: Union[float, torch.Tensor]):
        self._dropout_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_return_dict.setter
    def is_return_dict(self, value: bool):
        self._return_dict = value

    @dropout_rate.deleter
    def dropout_rate(self):
        del self._dropout_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key

    @is_return_dict.deleter
    def is_return_dict(self):
        del self._return_dict

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Zero out whole samples with probability `dropout_rate`.

        A Bernoulli mask of shape (B,) is drawn per batch element and broadcast-multiplied
        against `in_tensor`, so each sample is either kept intact or dropped entirely. This
        is the condition dropout used to train for classifier-free guidance.
        """
        del key
        dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
        return batch_mul(
            torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor),
            in_tensor,
        )

    def summary(self) -> str:
        # Default summary; subclasses can override with a more informative description.
        return self.__class__.__name__
|
|
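# Illustrative sketch (not part of the original module): shows how
# `random_dropout_input` drops whole samples. The helper below is hypothetical
# and only meant as documentation; it assumes `batch_mul` broadcasts a (B,)
# mask over the remaining dimensions of the input tensor.
def _demo_random_dropout() -> None:  # pragma: no cover - illustrative only
    entry = BaseConditionEntry()
    entry.dropout_rate = 0.5
    tokens = torch.randn(4, 77, 1024)  # (B, L, D) dummy text embeddings
    dropped = entry.random_dropout_input(tokens)
    # Each of the 4 samples is either untouched or zeroed out entirely.
    print(dropped.abs().sum(dim=(1, 2)) == 0)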
|
|
class DataType(Enum):
    IMAGE = "image"
    VIDEO = "video"


class TextAttr(BaseConditionEntry):
    def __init__(self):
        super().__init__()

    def forward(self, token: torch.Tensor, mask: torch.Tensor):
        return {"crossattn_emb": token, "crossattn_mask": mask}

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        # Attention masks are never dropped; only the embeddings themselves are.
        if key is not None and "mask" in key:
            return in_tensor
        return super().random_dropout_input(in_tensor, dropout_rate, key)
|
|
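# Illustrative sketch (not part of the original module): TextAttr repackages
# precomputed text embeddings and their mask under the keys GeneralConditioner
# concatenates on; note that random_dropout_input leaves any "mask" input untouched.
# The key name "t5_text_mask" below is only an example.
def _demo_text_attr() -> None:  # pragma: no cover - illustrative only
    attr = TextAttr()
    attr.dropout_rate = 0.2
    token = torch.randn(2, 512, 1024)
    mask = torch.ones(2, 512)
    out = attr(token, mask)  # {"crossattn_emb": token, "crossattn_mask": mask}
    kept_mask = attr.random_dropout_input(mask, key="t5_text_mask")  # returned unchanged
    print(sorted(out.keys()), torch.equal(kept_mask, mask))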
|
|
@dataclass
class BaseVideoCondition:
    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    data_type: DataType = DataType.VIDEO
    padding_mask: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    image_size: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        return {f.name: getattr(self, f.name) for f in fields(self)}
|
|
@dataclass
class VideoExtendCondition(BaseVideoCondition):
    video_cond_bool: Optional[torch.Tensor] = None
    # Ground-truth latents of the conditioning video.
    gt_latent: Optional[torch.Tensor] = None
    # Binary indicator marking which latent frames are conditioning frames.
    condition_video_indicator: Optional[torch.Tensor] = None
    # Mask concatenated to the network input to flag the conditioning region.
    condition_video_input_mask: Optional[torch.Tensor] = None
    # Noise level used to augment the conditioning frames, if enabled.
    condition_video_augment_sigma: Optional[torch.Tensor] = None
|
|
class GeneralConditioner(nn.Module, ABC):
    """
    An abstract module designed to handle various embedding models with conditional and
    unconditional configurations. This abstract base class initializes and manages a collection
    of embedders that can dynamically adjust their dropout rates based on conditioning.

    Attributes:
        KEY2DIM (dict): A mapping from output keys to the dimensions along which outputs
            sharing that key are concatenated.
        embedders (nn.ModuleDict): A dictionary containing all embedder models, initialized
            and configured from the provided configurations.

    Parameters:
        emb_models (Union[List, Any]): Keyword arguments mapping each embedder name to the
            configuration used to instantiate it.
    """

    KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1}

    def __init__(self, **emb_models: Union[List, Any]):
        super().__init__()
        self.embedders = nn.ModuleDict()
        for n, (emb_name, embconfig) in enumerate(emb_models.items()):
            embedder = instantiate(embconfig.obj)
            assert isinstance(
                embedder, BaseConditionEntry
            ), f"embedder model {embedder.__class__.__name__} has to inherit from BaseConditionEntry"
            embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0)

            if hasattr(embconfig, "input_key"):
                embedder.input_key = embconfig.input_key
            elif hasattr(embconfig, "input_keys"):
                embedder.input_keys = embconfig.input_keys
            else:
                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")

            log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}")
            self.embedders[emb_name] = embedder
|
    @abstractmethod
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Any:
        """Should be implemented in subclasses to handle the condition data type."""
        raise NotImplementedError
|
    def _forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Dict:
        """
        Processes the input batch through all configured embedders, applying conditional dropout rates if specified.
        Output tensors for each key are concatenated along the dimensions specified in KEY2DIM.

        Parameters:
            batch (Dict): The input data batch to process.
            override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary overriding the default
                dropout rate per embedder name.

        Returns:
            Dict: A dictionary of output tensors, concatenated along the dimensions specified in KEY2DIM.

        Note:
            If the network is sensitive to the order of concatenation, either control the order via the
            config file or make sure each embedder returns a unique key for each output.
        """
        output = defaultdict(list)
        if override_dropout_rate is None:
            override_dropout_rate = {}

        # Make sure there are no undefined embedder names in the override.
        for emb_name in override_dropout_rate.keys():
            assert emb_name in self.embedders, f"invalid name found {emb_name}"

        for emb_name, embedder in self.embedders.items():
            with torch.no_grad():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    emb_out = embedder(
                        embedder.random_dropout_input(
                            batch[embedder.input_key], override_dropout_rate.get(emb_name, None)
                        )
                    )
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(
                        *[
                            embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k)
                            for k in embedder.input_keys
                        ]
                    )
            for k, v in emb_out.items():
                output[k].append(v)

        return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()}
|
    def get_condition_uncondition(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Processes the provided data batch to generate conditioned and unconditioned outputs.

        This method manipulates dropout rates to simulate two scenarios:
        1. All conditions applied (conditioned)
        2. Conditions removed or reduced to a minimum (unconditioned)

        For the conditioned output, every embedder's dropout rate is set to zero so its effect is
        fully applied. For the unconditioned output, the rate is set to 1 for any embedder whose
        configured dropout rate is non-negligible (and left at 0 otherwise), minimizing its influence.

        Parameters:
            data_batch (Dict): Input data batch containing all necessary information for
                embedding processing.

        Returns:
            Tuple[Any, Any]: A tuple containing:
                - Outputs with all embedders fully applied (conditioned)
                - Outputs with embedders minimized or not applied (unconditioned)
        """
        cond_dropout_rates, dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_dropout_rates[emb_name] = 0.0
            dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates)
        return condition, un_condition
|
    def get_condition_with_negative_prompt(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Same functionality as get_condition_uncondition, but uses negative prompts
        for the unconditioned output.
        """
        cond_dropout_rates, uncond_dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_dropout_rates[emb_name] = 0.0
            if isinstance(embedder, TextAttr):
                # Keep the text embedder active for the unconditioned pass, since it will
                # consume the negative-prompt embeddings instead of being dropped.
                uncond_dropout_rates[emb_name] = 0.0
            else:
                uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        data_batch_neg_prompt = copy.deepcopy(data_batch)
        if "neg_t5_text_embeddings" in data_batch_neg_prompt:
            if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor):
                data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"]
                data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"]

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates)

        return condition, un_condition
|
|
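# Illustrative sketch (not part of the original module): how a concrete conditioner
# built on GeneralConditioner might be used for classifier-free guidance. The
# `guidance_scale` combination shown in the comment is the usual formulation, not
# something this module itself computes; the batch contents are assumed to match
# whatever embedder configs the conditioner was built with.
def _demo_classifier_free_guidance(conditioner: "VideoConditioner", data_batch: Dict) -> None:  # pragma: no cover
    # Conditioned pass uses dropout 0 everywhere; unconditioned pass drops every
    # embedder whose configured dropout rate is non-negligible.
    condition, un_condition = conditioner.get_condition_uncondition(data_batch)
    # Both are BaseVideoCondition instances; a denoiser would combine its two
    # predictions as, e.g., uncond_pred + guidance_scale * (cond_pred - uncond_pred).
    print(type(condition).__name__, type(un_condition).__name__)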
|
|
@dataclass
class CosmosCondition:
    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    padding_mask: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        return {f.name: getattr(self, f.name) for f in fields(self)}
|
|
class VideoConditioner(GeneralConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        output = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**output)
|
|
class VideoExtendConditioner(GeneralConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        output = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**output)
|
|
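# Illustrative sketch (not part of the original module): a VideoConditioner forward
# pass returns a BaseVideoCondition dataclass whose tensors can be handed to the
# network via to_dict(). The batch keys below ("t5_text_embeddings", "t5_text_mask")
# follow the names used in get_condition_with_negative_prompt; the embedder config
# is assumed to map those keys to a TextAttr entry.
def _demo_video_conditioner(conditioner: VideoConditioner) -> None:  # pragma: no cover - illustrative only
    batch = {
        "t5_text_embeddings": torch.randn(2, 512, 1024),  # (B, L, D) dummy T5 features
        "t5_text_mask": torch.ones(2, 512),
    }
    condition = conditioner(batch, override_dropout_rate={})
    net_kwargs = condition.to_dict()  # {"crossattn_emb": ..., "crossattn_mask": ..., ...}
    print(sorted(net_kwargs.keys()))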