from typing import List, Optional

from transformers import PretrainedConfig


class MultiLabelClassifierConfig(PretrainedConfig):
    """Configuration for a multi-label classifier built on top of a
    transformer encoder."""

    model_type = "multi_label_classification"

    def __init__(
        self,
        embedding_dim: int = 768,
        labels: Optional[List[str]] = None,
        transformer_name: str = "bert-base-uncased",
        hidden_dim: int = 256,
        num_layers: int = 2,
        bidirectional: bool = True,
        dropout: float = 0.3,
        **kwargs,
    ):
        # Avoid a mutable default argument; fall back to an empty label set.
        labels = labels if labels is not None else []

        self.transformer_name = transformer_name
        self.hidden_dim = hidden_dim
        self.labels = labels
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.num_classes = len(labels)
        self.embedding_dim = embedding_dim

        # Derive the id <-> label mappings from the label list unless the
        # caller supplies their own.
        if "id2label" not in kwargs:
            kwargs["id2label"] = {idx: label for idx, label in enumerate(labels)}
        if "label2id" not in kwargs:
            kwargs["label2id"] = {label: idx for idx, label in enumerate(labels)}

        # PretrainedConfig.__init__ pops problem_type from kwargs, so pass it
        # there; a bare class attribute would be shadowed after super().__init__().
        kwargs.setdefault("problem_type", "multi_label_classification")

        super().__init__(**kwargs)
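
# Minimal usage sketch (the label names below are illustrative, not from the
# original source): build a config, then round-trip it through the standard
# PretrainedConfig save/load machinery.
if __name__ == "__main__":
    config = MultiLabelClassifierConfig(labels=["toxic", "spam", "off_topic"])
    print(config.num_classes)  # 3
    print(config.id2label)     # {0: 'toxic', 1: 'spam', 2: 'off_topic'}

    # Serialize to config.json and reload; custom fields such as `labels` and
    # `hidden_dim` survive the round trip because they are instance attributes.
    config.save_pretrained("multi_label_classifier")
    reloaded = MultiLabelClassifierConfig.from_pretrained("multi_label_classifier")
    assert reloaded.labels == config.labels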