import torch


class CLIPTextModel_(torch.nn.Module):
    """#### The CLIPTextModel_ module."""

    def __init__(
        self,
        config_dict: dict,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPTextModel_ module.

        #### Args:
            - `config_dict` (dict): The configuration dictionary.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        num_positions = config_dict["max_position_embeddings"]
        self.eos_token_id = config_dict["eos_token_id"]

        # Local import of the embedding/encoder building blocks.
        from modules.clip.Clip import CLIPEmbeddings, CLIPEncoder

        self.embeddings = CLIPEmbeddings(
            embed_dim,
            num_positions=num_positions,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.encoder = CLIPEncoder(
            num_layers,
            embed_dim,
            heads,
            intermediate_size,
            intermediate_activation,
            dtype,
            device,
            operations,
        )
        self.final_layer_norm = operations.LayerNorm(
            embed_dim, dtype=dtype, device=device
        )

    def forward(
        self,
        input_tokens: torch.Tensor,
        attention_mask: torch.Tensor = None,
        intermediate_output: int = None,
        final_layer_norm_intermediate: bool = True,
        dtype: torch.dtype = torch.float32,
    ) -> tuple:
        """#### Forward pass for the CLIPTextModel_ module.

        #### Args:
            - `input_tokens` (torch.Tensor): The input tokens.
            - `attention_mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `intermediate_output` (int, optional): The intermediate output layer. Defaults to None.
            - `final_layer_norm_intermediate` (bool, optional): Whether to apply final layer normalization to the intermediate output. Defaults to True.
            - `dtype` (torch.dtype, optional): The dtype used for the embeddings. Defaults to torch.float32.

        #### Returns:
            - `tuple`: The output tensor, the intermediate output tensor, and the pooled output tensor.
        """
        # Embed the token ids (token + position embeddings).
        x = self.embeddings(input_tokens, dtype=dtype)

        # Convert the padding mask (1 = keep, 0 = pad) into an additive
        # attention bias: padded positions become -inf.
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ).expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        # CLIP's text encoder is causal: each token may only attend to
        # itself and earlier positions.
        causal_mask = (
            torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
            .fill_(float("-inf"))
            .triu_(1)
        )
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)

        # The pooled output is the hidden state at the first EOS token
        # of each sequence.
        pooled_output = x[
            torch.arange(x.shape[0], device=x.device),
            (
                torch.round(input_tokens).to(dtype=torch.int, device=x.device)
                == self.eos_token_id
            )
            .int()
            .argmax(dim=-1),
        ]
        return x, i, pooled_output
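
# Illustrative sketch, not part of the original module: `_causal_mask_sketch`
# is a hypothetical helper showing the additive causal mask built in
# CLIPTextModel_.forward. Entries strictly above the diagonal are -inf, so
# once the mask is added to the attention logits, token j can only attend
# to positions <= j.
def _causal_mask_sketch(
    seq_len: int, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Rebuild the causal attention bias used in CLIPTextModel_.forward."""
    mask = (
        torch.empty(seq_len, seq_len, dtype=dtype)
        .fill_(float("-inf"))
        .triu_(1)
    )
    # For seq_len=4 this yields:
    # [[0., -inf, -inf, -inf],
    #  [0.,   0., -inf, -inf],
    #  [0.,   0.,   0., -inf],
    #  [0.,   0.,   0.,   0.]]
    return mask
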
""" super().__init__() self.num_layers = config_dict["num_hidden_layers"] self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) embed_dim = config_dict["hidden_size"] self.text_projection = operations.Linear( embed_dim, embed_dim, bias=False, dtype=dtype, device=device ) self.dtype = dtype def get_input_embeddings(self) -> torch.nn.Embedding: """#### Get the input embeddings. #### Returns: - `torch.nn.Embedding`: The input embeddings. """ return self.text_model.embeddings.token_embedding def set_input_embeddings(self, embeddings: torch.nn.Embedding) -> None: """#### Set the input embeddings. #### Args: - `embeddings` (torch.nn.Embedding): The input embeddings. """ self.text_model.embeddings.token_embedding = embeddings def forward(self, *args, **kwargs) -> tuple: """#### Forward pass for the CLIPTextModel module. #### Args: - `*args`: Variable length argument list. - `**kwargs`: Arbitrary keyword arguments. #### Returns: - `tuple`: The output tensors. """ x = self.text_model(*args, **kwargs) out = self.text_projection(x[2]) return (x[0], x[1], out, x[2])