from typing import List, Optional

from transformers import PretrainedConfig


class MultiLabelClassifierConfig(PretrainedConfig):
    """Configuration for a multi-label classification model built on a
    pretrained transformer encoder (see `transformer_name`)."""

    model_type = "multi_label_classification"
    problem_type = "multi_label_classification"

    def __init__(
        self,
        embedding_dim: int = 768,
        labels: Optional[List[str]] = None,
        transformer_name: str = "bert-base-uncased",
        hidden_dim: int = 256,
        num_layers: int = 2,
        bidirectional: bool = True,
        dropout: float = 0.3,
        **kwargs,
    ):
        # Guard against the shared-mutable-default pitfall: fall back to a
        # fresh empty list when no labels are given.
        labels = labels if labels is not None else []

        self.transformer_name = transformer_name
        self.hidden_dim = hidden_dim
        self.labels = labels
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.num_classes = len(labels)
        self.embedding_dim = embedding_dim

        # Build the standard id2label/label2id mappings from `labels` unless
        # the caller supplied them explicitly.
        kwargs.setdefault("id2label", {idx: label for idx, label in enumerate(labels)})
        kwargs.setdefault("label2id", {label: idx for idx, label in enumerate(labels)})

        # Forward problem_type so PretrainedConfig.__init__, which pops it
        # from kwargs, keeps the class-level value on the instance.
        kwargs.setdefault("problem_type", self.problem_type)

        super().__init__(**kwargs)
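

# A minimal usage sketch (an illustration, not part of the original module):
# it shows how the config derives its label mappings from `labels` and
# round-trips through the standard save/load API.
if __name__ == "__main__":
    config = MultiLabelClassifierConfig(labels=["sports", "politics", "tech"])
    print(config.num_classes)  # 3
    print(config.id2label)     # {0: 'sports', 1: 'politics', 2: 'tech'}

    # Serialize and reload; `./tmp_config` is a hypothetical local path.
    config.save_pretrained("./tmp_config")
    reloaded = MultiLabelClassifierConfig.from_pretrained("./tmp_config")
    assert reloaded.labels == config.labels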