|
|
|
import torch |
|
import torch.nn as nn |
|
from transformers import BertModel, BertConfig |
|
|
|
class BertHierarchicalClassification(nn.Module):
    """BERT encoder with four parallel classification heads.

    A single ``BertModel`` encodes the input. Its pooled output, after
    dropout, is fed to four independent linear heads that produce logits
    for grade, domain, cluster and standard labels.

    Note: despite the class name, no hierarchical constraint between the
    heads is enforced here. Every head reads the same pooled representation
    and none of them conditions on another head's prediction.

    The ``config`` must expose ``hidden_size``. It must also carry the
    integer attributes ``num_grades``, ``num_domains``, ``num_clusters`` and
    ``num_standards``. These are not standard ``BertConfig`` fields and are
    presumably set by the caller when building the config.
    """

    def __init__(self, config):
        """Build the encoder, the dropout layer and the four heads.

        Args:
            config: BERT configuration passed straight to ``BertModel``.
                Besides ``hidden_size`` it must provide the four ``num_*``
                label counts described on the class. If it defines a
                non-``None`` ``classifier_dropout``, that rate is used for
                the dropout layer; otherwise the rate is 0.1.
        """
        super().__init__()

        self.bert = BertModel(config)
        hidden_size = config.hidden_size

        # Number of output classes per head, read from custom config fields.
        self.num_grades = config.num_grades
        self.num_domains = config.num_domains
        self.num_clusters = config.num_clusters
        self.num_standards = config.num_standards

        # One linear head per label family, all over the same pooled vector.
        self.grade_classifier = nn.Linear(hidden_size, self.num_grades)
        self.domain_classifier = nn.Linear(hidden_size, self.num_domains)
        self.cluster_classifier = nn.Linear(hidden_size, self.num_clusters)
        self.standard_classifier = nn.Linear(hidden_size, self.num_standards)

        # Honour the standard HF `classifier_dropout` knob when it is set,
        # but keep the historical 0.1 rate when it is absent or None.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        dropout_rate = 0.1 if classifier_dropout is None else classifier_dropout
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode the input and return logits from all four heads.

        Args:
            input_ids: Token ids forwarded to ``BertModel``.
            attention_mask: Optional attention mask forwarded to
                ``BertModel``. ``None`` lets the encoder use its own default.
            token_type_ids: Optional segment ids forwarded to ``BertModel``.

        Returns:
            Tuple ``(grade_logits, domain_logits, cluster_logits,
            standard_logits)``. Each element is the output of the matching
            ``nn.Linear`` head, so its last dimension equals that head's
            ``num_*`` class count.
        """
        # Force a ModelOutput object so `.pooler_output` is available even
        # when the config sets `return_dict=False`, which would otherwise
        # make BertModel return a plain tuple.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = self.dropout(outputs.pooler_output)

        grade_logits = self.grade_classifier(pooled_output)
        domain_logits = self.domain_classifier(pooled_output)
        cluster_logits = self.cluster_classifier(pooled_output)
        standard_logits = self.standard_classifier(pooled_output)

        return grade_logits, domain_logits, cluster_logits, standard_logits
|
|