from typing import Tuple, Union

import torch


def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
    """
    This function initializes a coordinate grid and generates a 2D positional
    embedding from it using sine and cosine functions. It is a wrapper around
    get_2d_sincos_pos_embed_from_grid.

    Args:
    - embed_dim: The embedding dimension (must be divisible by 4).
    - grid_size: The grid height and width, either a single int for a square
      grid or an (int, int) tuple.

    Returns:
    - pos_embed: The generated 2D positional embedding of shape
      (1, embed_dim, grid_size_h, grid_size_w).
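
    Example (illustrative shape check; the arguments here are arbitrary):
        >>> pos_embed = get_2d_sincos_pos_embed(64, (4, 6))
        >>> pos_embed.shape
        torch.Size([1, 64, 4, 6])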
|
""" |
|
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float)
    grid_w = torch.arange(grid_size_w, dtype=torch.float)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    # Reshape from (1, H*W, embed_dim) to (1, embed_dim, H, W).
    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)


def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from a given grid using
    sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension (must be divisible by 4).
    - grid: A tensor of shape (2, ...) whose two leading slices hold the
      coordinate grids of the two spatial axes.

    Returns:
    - emb: The generated 2D positional embedding of shape (1, M, embed_dim),
      where M is the number of grid positions.
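
    Example (illustrative shape check; the inputs here are arbitrary):
        >>> grid = torch.zeros(2, 1, 4, 6)
        >>> get_2d_sincos_pos_embed_from_grid(8, grid).shape
        torch.Size([1, 24, 8])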
|
""" |
|
    assert embed_dim % 2 == 0

    # Encode each spatial axis with half of the embedding dimensions.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])

    emb = torch.cat([emb_h, emb_w], dim=2)
    return emb


def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from given positions using
    sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: A tensor of positions to encode; it is flattened before encoding.

    Returns:
    - emb: The generated 1D positional embedding of shape (1, M, embed_dim),
      where M is the number of positions.
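
    Example (illustrative shape check; the inputs here are arbitrary):
        >>> pos = torch.arange(10, dtype=torch.float)
        >>> get_1d_sincos_pos_embed_from_grid(16, pos).shape
        torch.Size([1, 10, 16])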
|
""" |
|
    assert embed_dim % 2 == 0
    # Geometric progression of frequencies, one per sin/cos feature pair.
    omega = torch.arange(embed_dim // 2, dtype=torch.double)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    pos = pos.reshape(-1)
    # Outer product: one row of phase values per position.
    out = torch.einsum("m,d->md", pos, omega)

    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)

    emb = torch.cat([emb_sin, emb_cos], dim=1)
    return emb[None].float()


def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from given coordinates using
    sine and cosine functions.

    Args:
    - xy: A tensor of shape (B, N, 2) containing the (x, y) coordinates to encode.
    - C: The embedding size per coordinate (must be even).
    - cat_coords: A flag indicating whether to concatenate the original
      coordinates to the embedding.

    Returns:
    - pe: The generated 2D positional embedding of shape (B, N, 2 * C), or
      (B, N, 2 * C + 2) if cat_coords is True.
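
    Example (illustrative shape check; the inputs here are arbitrary):
        >>> xy = torch.rand(2, 5, 2)
        >>> get_2d_embedding(xy, 64).shape
        torch.Size([2, 5, 130])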
|
""" |
|
    B, N, D = xy.shape
    assert D == 2
    assert C % 2 == 0

    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    # Per-channel frequencies; they grow linearly with the channel index.
    div_term = (
        torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, C // 2)

    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)

    # Interleave sine and cosine features for each coordinate.
    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe = torch.cat([pe_x, pe_y], dim=2)
    if cat_coords:
        # Prepend the raw coordinates to the embedding.
        pe = torch.cat([xy, pe], dim=2)
    return pe