# Uploaded by iolimat482: BERT hierarchical classification model for grades 1, 2 and 3
# (commit b92e7e5, verified)
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig
class BertHierarchicalClassification(nn.Module):
    """BERT encoder with four parallel classification heads.

    The pooled BERT representation is passed through dropout and then into
    four independent linear heads that score grade, domain, cluster and
    standard labels. The heads do not condition on one another, so any
    consistency between the levels of the label hierarchy must be enforced
    by the caller (NOTE(review): confirm this is intended).

    Args:
        config: A ``transformers`` BERT configuration. Besides the standard
            BERT fields it must carry the integer attributes ``num_grades``,
            ``num_domains``, ``num_clusters`` and ``num_standards`` (the
            output size of each head). If it defines a non-``None``
            ``classifier_dropout``, that value is used as the dropout
            probability before the heads; otherwise 0.1 is used.
    """

    def __init__(self, config):
        super().__init__()
        # Built from the config alone, so the encoder weights are freshly
        # initialised here; trained weights presumably come from a later
        # load_state_dict call (verify against the loading code).
        self.bert = BertModel(config)
        hidden_size = config.hidden_size
        self.num_grades = config.num_grades
        self.num_domains = config.num_domains
        self.num_clusters = config.num_clusters
        self.num_standards = config.num_standards
        # Creation order is kept stable so parameter initialisation and
        # state_dict layout match previously saved checkpoints.
        self.grade_classifier = nn.Linear(hidden_size, self.num_grades)
        self.domain_classifier = nn.Linear(hidden_size, self.num_domains)
        self.cluster_classifier = nn.Linear(hidden_size, self.num_clusters)
        self.standard_classifier = nn.Linear(hidden_size, self.num_standards)
        self.dropout = nn.Dropout(self._classifier_dropout_prob(config))

    @staticmethod
    def _classifier_dropout_prob(config, default=0.1):
        """Return the dropout probability to apply before the heads.

        Uses ``config.classifier_dropout`` when present and not ``None``
        (0.0 is respected); otherwise returns ``default``.
        """
        prob = getattr(config, "classifier_dropout", None)
        return default if prob is None else prob

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode the inputs and score them with every head.

        Args:
            input_ids: Token id tensor, presumably ``(batch, seq_len)`` as
                produced by a BERT tokenizer.
            attention_mask: Optional padding mask with the same shape as
                ``input_ids``; ``None`` lets BERT attend to every position.
            token_type_ids: Optional segment ids with the same shape as
                ``input_ids``; ``None`` lets BERT use all-zero segments.
                Accepting it allows ``model(**tokenizer_output)``.

        Returns:
            A tuple ``(grade_logits, domain_logits, cluster_logits,
            standard_logits)``. Each tensor's last dimension is the
            corresponding ``num_*`` size.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(outputs.pooler_output)
        return (
            self.grade_classifier(pooled_output),
            self.domain_classifier(pooled_output),
            self.cluster_classifier(pooled_output),
            self.standard_classifier(pooled_output),
        )