import torch


class CLIPTextModel_(torch.nn.Module):
    """#### The CLIPTextModel_ module."""

    def __init__(
        self,
        config_dict: dict,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPTextModel_ module.

        #### Args:
            - `config_dict` (dict): The configuration dictionary.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        num_positions = config_dict["max_position_embeddings"]
        self.eos_token_id = config_dict["eos_token_id"]
        super().__init__()
        from modules.clip.Clip import CLIPEmbeddings, CLIPEncoder

        self.embeddings = CLIPEmbeddings(
            embed_dim,
            num_positions=num_positions,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.encoder = CLIPEncoder(
            num_layers,
            embed_dim,
            heads,
            intermediate_size,
            intermediate_activation,
            dtype,
            device,
            operations,
        )
        self.final_layer_norm = operations.LayerNorm(
            embed_dim, dtype=dtype, device=device
        )
    def forward(
        self,
        input_tokens: torch.Tensor,
        attention_mask: torch.Tensor = None,
        intermediate_output: int = None,
        final_layer_norm_intermediate: bool = True,
        dtype: torch.dtype = torch.float32,
    ) -> tuple:
        """#### Forward pass for the CLIPTextModel_ module.

        #### Args:
            - `input_tokens` (torch.Tensor): The input tokens.
            - `attention_mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `intermediate_output` (int, optional): Index of the encoder layer whose hidden states are also returned. Defaults to None.
            - `final_layer_norm_intermediate` (bool, optional): Whether to apply the final layer normalization to the intermediate output. Defaults to True.
            - `dtype` (torch.dtype, optional): The dtype used for the embeddings. Defaults to torch.float32.

        #### Returns:
            - `tuple`: The output tensor, the intermediate output tensor, and the pooled output tensor.
        """
        x = self.embeddings(input_tokens, dtype=dtype)
        mask = None
        if attention_mask is not None:
            # Expand the padding mask to (batch, 1, seq, seq) and turn masked
            # positions into -inf so the attention softmax ignores them.
            mask = 1.0 - attention_mask.to(x.dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ).expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        # Causal mask: each token may attend only to itself and earlier tokens.
        causal_mask = (
            torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
            .fill_(float("-inf"))
            .triu_(1)
        )
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)

        # Pool the hidden state at the first EOS token of each sequence (CLIP-style pooling).
        pooled_output = x[
            torch.arange(x.shape[0], device=x.device),
            (
                torch.round(input_tokens).to(dtype=torch.int, device=x.device)
                == self.eos_token_id
            )
            .int()
            .argmax(dim=-1),
        ]
        return x, i, pooled_output
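

# Note (added for clarity, not part of the original class): the pooled output above is
# gathered at the first EOS position of each sequence. The minimal sketch below shows
# that indexing in plain torch; `_first_eos_index` is a hypothetical helper name
# introduced here purely for illustration.
def _first_eos_index(input_tokens: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    """Return, per row, the index of the first occurrence of `eos_token_id`.

    `argmax` over the int-cast boolean mask returns the first position where the
    mask equals 1, which is the index used to gather the pooled hidden state.
    """
    return (torch.round(input_tokens).to(torch.int) == eos_token_id).int().argmax(dim=-1)


# Example: with EOS id 49407 (the end-of-text token in the standard CLIP BPE vocabulary),
# _first_eos_index(torch.tensor([[49406, 320, 49407, 49407]]), 49407) -> tensor([2])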
class CLIPTextModel(torch.nn.Module):
    """#### The CLIPTextModel module."""

    def __init__(
        self,
        config_dict: dict,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPTextModel module.

        #### Args:
            - `config_dict` (dict): The configuration dictionary.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = operations.Linear(
            embed_dim, embed_dim, bias=False, dtype=dtype, device=device
        )
        self.dtype = dtype

    def get_input_embeddings(self) -> torch.nn.Embedding:
        """#### Get the input embeddings.

        #### Returns:
            - `torch.nn.Embedding`: The input embeddings.
        """
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings: torch.nn.Embedding) -> None:
        """#### Set the input embeddings.

        #### Args:
            - `embeddings` (torch.nn.Embedding): The input embeddings.
        """
        self.text_model.embeddings.token_embedding = embeddings
    def forward(self, *args, **kwargs) -> tuple:
        """#### Forward pass for the CLIPTextModel module.

        #### Args:
            - `*args`: Variable length argument list.
            - `**kwargs`: Arbitrary keyword arguments.

        #### Returns:
            - `tuple`: The output tensors.
        """
        x = self.text_model(*args, **kwargs)
        # Project the pooled EOS hidden state; the unprojected pooled output is also returned.
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])
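

# Usage sketch (illustrative only, not part of the original file). It assumes the
# repository's `modules.clip.Clip` package is importable and that `operations` mirrors
# the torch.nn namespace (Linear / LayerNorm / Embedding constructors that accept
# `dtype` and `device`), which is why plain `torch.nn` is passed below. The config
# values follow a typical CLIP ViT-L/14 text encoder and may differ in practice.
if __name__ == "__main__":
    config = {
        "num_hidden_layers": 12,
        "hidden_size": 768,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_act": "quick_gelu",
        "max_position_embeddings": 77,
        "eos_token_id": 49407,
    }
    model = CLIPTextModel(
        config, dtype=torch.float32, device=torch.device("cpu"), operations=torch.nn
    )
    tokens = torch.randint(0, 49408, (1, 77))  # dummy token ids, context length 77
    hidden, intermediate, projected_pooled, pooled = model(tokens)
    print(hidden.shape, projected_pooled.shape)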