File size: 1,395 Bytes
b92e7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig

class BertHierarchicalClassification(nn.Module):
    """BERT encoder with four parallel linear classification heads.

    A single ``BertModel`` produces a pooled representation; after dropout
    that same vector is fed to four independent ``nn.Linear`` heads named
    grade, domain, cluster and standard. The heads do not condition on one
    another.

    Args:
        config: Configuration object passed to ``BertModel``. It must also
            expose ``hidden_size``, ``num_grades``, ``num_domains``,
            ``num_clusters`` and ``num_standards``.
    """

    def __init__(self, config):
        super().__init__()
        self.bert = BertModel(config)
        hidden_size = config.hidden_size

        # Keep the label-space sizes on the module so callers can inspect them.
        self.num_grades = config.num_grades
        self.num_domains = config.num_domains
        self.num_clusters = config.num_clusters
        self.num_standards = config.num_standards

        # One linear head per level; every head reads the same pooled vector.
        self.grade_classifier = nn.Linear(hidden_size, self.num_grades)
        self.domain_classifier = nn.Linear(hidden_size, self.num_domains)
        self.cluster_classifier = nn.Linear(hidden_size, self.num_clusters)
        self.standard_classifier = nn.Linear(hidden_size, self.num_standards)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode the inputs and return the logits of all four heads.

        Args:
            input_ids: Token ids forwarded to the encoder
                (presumably shaped ``(batch, seq_len)`` -- confirm with caller).
            attention_mask: Attention mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
                Defaults to ``None``, which matches calling the encoder
                without this argument.

        Returns:
            A 4-tuple ``(grade_logits, domain_logits, cluster_logits,
            standard_logits)``. The last dimension of each tensor is
            ``num_grades``, ``num_domains``, ``num_clusters`` and
            ``num_standards`` respectively.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Dropout is applied once, so all heads see the same masked features.
        features = self.dropout(encoded.pooler_output)

        grade_logits = self.grade_classifier(features)
        domain_logits = self.domain_classifier(features)
        cluster_logits = self.cluster_classifier(features)
        standard_logits = self.standard_classifier(features)

        return grade_logits, domain_logits, cluster_logits, standard_logits