import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional, Tuple, Union

from .configuration import MultiLabelClassifierConfig


class MultiLabelClassifierModel(PreTrainedModel):
    """Multi-label classifier: a pretrained transformer backbone followed by a GRU head."""

    config_class = MultiLabelClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained transformer backbone, loaded by name through the
        # huggingface/pytorch-transformers torch.hub entry point.
        self.nlp_model = torch.hub.load(
            'huggingface/pytorch-transformers', 'model', config.transformer_name
        )
        # GRU over the backbone's token embeddings; inter-layer dropout is only
        # meaningful when there is more than one GRU layer.
        self.rnn = nn.GRU(
            config.embedding_dim,
            config.hidden_dim,
            num_layers=config.num_layers,
            bidirectional=config.bidirectional,
            batch_first=True,
            dropout=0 if config.num_layers < 2 else config.dropout,
        )
        self.dropout = nn.Dropout(config.dropout)
        # Classification head; the input size doubles when the GRU is bidirectional
        # because the forward and backward final states are concatenated.
        self.out = nn.Linear(
            config.hidden_dim * 2 if config.bidirectional else config.hidden_dim,
            config.num_classes,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        # Always request a dict-style output from the backbone so that
        # last_hidden_state / hidden_states / attentions can be accessed by name
        # regardless of what the caller passes for return_dict.
        output = self.nlp_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # hidden has shape (num_layers * num_directions, batch, hidden_dim);
        # keep the final layer's state(s) as the sequence representation.
        _, hidden = self.rnn(output.last_hidden_state)
        if self.rnn.bidirectional:
            # Concatenate the last layer's forward and backward final states.
            hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        else:
            hidden = self.dropout(hidden[-1, :, :])

        logits = self.out(hidden)

        return SequenceClassifierOutput(
            logits=logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
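

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's API).
    # It assumes MultiLabelClassifierConfig accepts the fields read in __init__
    # above as keyword arguments, and that transformer_name names a standard Hub
    # checkpoint whose hidden size matches embedding_dim. The checkpoint name and
    # all numeric values below are example assumptions, not values from the repo.
    from transformers import AutoTokenizer

    config = MultiLabelClassifierConfig(
        transformer_name="bert-base-uncased",  # assumed example checkpoint
        embedding_dim=768,                     # must match the backbone's hidden size
        hidden_dim=256,
        num_layers=2,
        bidirectional=True,
        dropout=0.25,
        num_classes=6,
    )
    model = MultiLabelClassifierModel(config)
    tokenizer = AutoTokenizer.from_pretrained(config.transformer_name)

    batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**batch).logits  # shape: (batch_size, num_classes)
    print(logits.shape)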