File size: 1,163 Bytes
cc10c23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from pdb import set_trace
from typing import List, Optional

import torch
from transformers import PretrainedConfig
class MultiLabelClassifierConfig(PretrainedConfig):
    """Configuration object for a multi-label classifier.

    Stores the hyperparameters passed to ``__init__`` as attributes, derives
    ``num_classes`` from ``labels``, and supplies ``id2label`` / ``label2id``
    mappings to ``PretrainedConfig`` unless the caller passes them explicitly
    in ``kwargs``.

    Args:
        embedding_dim: Stored as ``self.embedding_dim``. Presumably the width
            of the encoder output (768 matches the default encoder) --
            confirm against the model that consumes this config.
        labels: Label names in class-index order. ``None`` (the default)
            means no labels. The sequence is copied, so later changes to the
            caller's list do not affect the config.
        transformer_name: Stored as ``self.transformer_name``. Presumably the
            name of the pretrained encoder to load -- confirm against the
            model.
        hidden_dim: Stored as ``self.hidden_dim``.
        num_layers: Stored as ``self.num_layers``.
        bidirectional: Stored as ``self.bidirectional``.
        dropout: Stored as ``self.dropout``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. Any
            ``id2label`` / ``label2id`` given here take precedence over the
            ones derived from ``labels``.
    """

    model_type = "multi_label_classification"
    # NOTE(review): PretrainedConfig.__init__ appears to assign
    # self.problem_type from kwargs (defaulting to None), which would shadow
    # this class attribute on every instance -- confirm, and if so pass
    # problem_type through kwargs instead.
    problem_type = "multi_label_classification"

    def __init__(
        self,
        embedding_dim: int = 768,
        labels: Optional[List[str]] = None,
        transformer_name: str = "bert-base-uncased",
        hidden_dim: int = 256,
        num_layers: int = 2,
        bidirectional: bool = True,
        dropout: float = 0.3,
        **kwargs,
    ):
        # A fresh list per instance: avoids the shared mutable default and
        # decouples the config from the caller's list object.
        labels = [] if labels is None else list(labels)

        self.transformer_name = transformer_name
        self.hidden_dim = hidden_dim
        self.labels = labels
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.num_classes = len(labels)
        self.embedding_dim = embedding_dim

        # Only derive the label maps when the caller (or a reloaded config)
        # has not already provided them.
        kwargs.setdefault("id2label", dict(enumerate(labels)))
        kwargs.setdefault(
            "label2id", {label: idx for idx, label in enumerate(labels)}
        )
        super().__init__(**kwargs)
|