LatteTransformer3DModel
A Diffusion Transformer model for 3D data from Latte.
LatteTransformer3DModel
class diffusers.LatteTransformer3DModel
( num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional = None, out_channels: Optional = None, num_layers: int = 1, dropout: float = 0.0, cross_attention_dim: Optional = None, attention_bias: bool = False, sample_size: int = 64, patch_size: Optional = None, activation_fn: str = 'geglu', num_embeds_ada_norm: Optional = None, norm_type: str = 'layer_norm', norm_elementwise_affine: bool = True, norm_eps: float = 1e-05, caption_channels: int = None, video_length: int = 16 )
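As a rough aid for reading the constructor arguments, the sketch below computes a few sizes that follow from them. It assumes the usual Diffusion Transformer conventions: the hidden width is num_attention_heads * attention_head_dim, and each frame is cut into non-overlapping patch_size x patch_size patches. The helper name latte_derived_sizes and the value patch_size=2 are hypothetical, chosen only for illustration (the signature above defaults patch_size to None).

```python
def latte_derived_sizes(
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    sample_size: int = 64,
    patch_size: int = 2,  # hypothetical value; the signature defaults to None
    video_length: int = 16,
) -> dict:
    """Return sizes implied by the constructor arguments (illustrative only)."""
    if sample_size % patch_size != 0:
        raise ValueError("sample_size must be divisible by patch_size")

    # Assumed convention: transformer width = heads * per-head dimension.
    inner_dim = num_attention_heads * attention_head_dim
    # Assumed convention: a square frame is split into square patches.
    patches_per_side = sample_size // patch_size
    tokens_per_frame = patches_per_side ** 2

    return {
        "inner_dim": inner_dim,
        "tokens_per_frame": tokens_per_frame,
        "tokens_per_video": tokens_per_frame * video_length,
    }


sizes = latte_derived_sizes()
print(sizes)
```

With the defaults above and the assumed patch size, the spatial attention layers would see tokens_per_frame tokens per frame, while the temporal layers attend across the video_length frames.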
forward
( hidden_states: Tensor, timestep: Optional = None, encoder_hidden_states: Optional = None, encoder_attention_mask: Optional = None, enable_temporal_attentions: bool = True, return_dict: bool = True )
Parameters
- hidden_states (torch.Tensor of shape (batch size, channel, num_frame, height, width)): Input hidden_states.
- timestep (torch.LongTensor, optional): Used to indicate the denoising step. Optional timestep to be applied as an embedding in AdaLayerNorm.
- encoder_hidden_states (torch.FloatTensor of shape (batch size, sequence len, embed dims), optional): Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to self-attention.
- encoder_attention_mask (torch.Tensor, optional): Cross-attention mask applied to encoder_hidden_states. Two formats are supported:
  - Mask (batch, sequence_length): True = keep, False = discard.
  - Bias (batch, 1, sequence_length): 0 = keep, -10000 = discard.
  If ndim == 2, the input is interpreted as a mask and converted into a bias consistent with the format above. This bias is added to the cross-attention scores.
- enable_temporal_attentions (bool, optional, defaults to True): Whether to enable temporal attentions.
- return_dict (bool, optional, defaults to True): Whether or not to return a ~models.unet_2d_condition.UNet2DConditionOutput instead of a plain tuple.
The LatteTransformer3DModel forward method.
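The two encoder_attention_mask formats described in the parameters can be illustrated with a small dependency-free sketch. It uses nested Python lists in place of tensors, and the helper names mask_to_bias and add_bias are hypothetical; it is meant to show the keep/discard convention, not to reproduce the library's implementation.

```python
from typing import List, Sequence

KEEP_BIAS = 0.0          # added to scores of positions that are kept
DISCARD_BIAS = -10000.0  # added to scores of positions that are discarded


def mask_to_bias(mask: Sequence[Sequence[bool]]) -> List[List[List[float]]]:
    """Turn a (batch, sequence_length) keep-mask into a
    (batch, 1, sequence_length) additive bias."""
    return [[[KEEP_BIAS if keep else DISCARD_BIAS for keep in row]] for row in mask]


def add_bias(scores, bias):
    """Add a (batch, 1, sequence_length) bias to scores of shape
    (batch, query_len, sequence_length), broadcasting over the query axis."""
    return [
        [[s + b for s, b in zip(query_row, bias_rows[0])] for query_row in batch_scores]
        for batch_scores, bias_rows in zip(scores, bias)
    ]


# Two samples, three encoder tokens each; trailing tokens are padding.
mask = [[True, True, False], [True, False, False]]
bias = mask_to_bias(mask)

# Toy cross-attention scores: (batch=2, query_len=2, sequence_length=3).
scores = [
    [[0.5, 1.0, 2.0], [0.1, 0.2, 0.3]],
    [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]],
]
biased = add_bias(scores, bias)
print(bias)
print(biased)
```

After a softmax over the last axis, positions that received the -10000 bias end up with weights that are effectively zero, which is why the additive form behaves like a mask.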