import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.utils.checkpoint import checkpoint
from .utils import init_linear
from .attention import ResidualAttentionBlock


class Perceiver(nn.Module):
    """A stack of residual self-attention blocks over a fixed-length token sequence.

    The module chains ``layers`` identically configured ``ResidualAttentionBlock``
    instances, each operating on ``n_ctx`` tokens of dimension ``width``.
    """

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        use_flash: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        # Identically configured residual self-attention blocks, applied in order.
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    use_flash=use_flash,
                    use_checkpoint=use_checkpoint,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        """Apply each residual attention block to ``x`` in sequence."""
        for block in self.resblocks:
            x = block(x)
        return x
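

# Minimal usage sketch (illustrative only, not part of the module). The sizes
# below are assumptions chosen for the example; ResidualAttentionBlock is
# expected to map a (batch, n_ctx, width) tensor to the same shape, so the
# Perceiver output matches its input shape.
#
#   model = Perceiver(n_ctx=256, width=512, layers=4, heads=8)
#   tokens = torch.randn(2, 256, 512)  # (batch, n_ctx, width)
#   out = model(tokens)                # -> same shape as `tokens`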