from typing import List, Optional

from transformers import PretrainedConfig

class MultiLabelClassifierConfig(PretrainedConfig):
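    """Configuration for a multi-label text classifier built on a pretrained
    transformer encoder (defaults to bert-base-uncased). `labels` drives
    `num_classes` and the default `id2label`/`label2id` mappings."""
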
    model_type = "multi_label_classification"
    problem_type = "multi_label_classification"

    def __init__(
        self,
        embedding_dim: int = 768,
        labels: Optional[List[str]] = None,
        transformer_name: str = "bert-base-uncased",
        hidden_dim: int = 256,
        num_layers: int = 2,
        bidirectional: bool = True,
        dropout: float = 0.3,
        **kwargs,
    ):
        # Avoid a shared mutable default argument: each instance gets its own list.
        labels = labels if labels is not None else []

        self.transformer_name = transformer_name
        self.hidden_dim = hidden_dim
        self.labels = labels
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.num_classes = len(labels)
        self.embedding_dim = embedding_dim

        # Derive the id<->label mappings from `labels` unless the caller
        # supplied them explicitly.
        kwargs.setdefault("id2label", {idx: label for idx, label in enumerate(labels)})
        kwargs.setdefault("label2id", {label: idx for idx, label in enumerate(labels)})
        super().__init__(**kwargs)
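

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# The label names below are assumptions; the save/load round-trip uses the
# standard PretrainedConfig API (save_pretrained / from_pretrained).
if __name__ == "__main__":
    import tempfile

    config = MultiLabelClassifierConfig(labels=["toxic", "spam", "off_topic"])
    print(config.num_classes)  # 3
    print(config.id2label)     # {0: 'toxic', 1: 'spam', 2: 'off_topic'}

    # Configs serialize to config.json and reload with all custom fields intact.
    with tempfile.TemporaryDirectory() as tmp:
        config.save_pretrained(tmp)
        reloaded = MultiLabelClassifierConfig.from_pretrained(tmp)
        assert reloaded.labels == config.labels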